Compare commits

...

212 Commits
master ... r1.3

Author SHA1 Message Date
  i-robot acc2472c41 !1939 code_sync_0705_inc 4 years ago
  dingpeifei f49a21d293 code_sync_0705_inc 4 years ago
  i-robot 92b997f9f9 !1938 code_sync_0705 4 years ago
  dingpeifei bac5d1b678 Merge remote-tracking branch 'origin/r1.3' into code_sync_0705 4 years ago
  i-robot cc7175217c !1928 cherry-pick fix for dynamic shape V1 4 years ago
  i-robot 230010b770 !1933 parallel group 4 years ago
  陈华 3929578dee fix parallel group 4 years ago
  i-robot d0f986ea46 !1903 fix sc 4 years ago
  i-robot 0684bd48cf !1927 fix safe 4 years ago
  i-robot 2444f46b8d !1926 set size for dynamic input 4 years ago
  i-robot ca24b76141 !1922 FindLastBpFromBpNode c78 4 years ago
  zhangxiaokun 2785670745 fix printf format 4 years ago
  lianghao 2d9e3da649 IsEnterFeedNode 4 years ago
  zhangxiaokun ac1f4eb1c2 DSP: Switch -> TransData -> Cast -> Exit 4 years ago
  zhangxiaokun 61c203619c Remove UT dump env 4 years ago
  zhangxiaokun b23075b62f UT for control flow group 4 years ago
  zhangxiaokun dbc989a4c3 Clear UpdatePersistTensor Warning for first run 4 years ago
  zhangxiaokun 796513222a UpdatePersistTensor from ExecutionEngine 4 years ago
  zhangxiaokun 002583e4ef Fix Set Control flow group for -1 4 years ago
  zhangxiaokun 65cafdd034 Replace MemcpyAsyncNodeTask 4 years ago
  zhangxiaokun 64d312ab12 UT for LaunchKernelCustAicpuSo 4 years ago
  zhangxiaokun ded54e73af Fix Guard for variable release 4 years ago
  zhangxiaokun 7b1331770a Fix multi control from one node 4 years ago
  zhangxiaokun c440980918 Fix BuildPartitionFrame failed 4 years ago
  wangzhengjun 5da987eb3a set size for dynamic input 4 years ago
  lianghao 8c55572e12 FindLastBpFromBpNode c78 4 years ago
  i-robot 656cd3f3d5 !1873 add copy graph 4 years ago
  wangxiaotian22 e5457d5949 fix sc + 4 years ago
  wangxiaotian22 062757756c fix sc 4 years ago
  i-robot 333bbb700a !1885 fix mem leak 4 years ago
  i-robot 9088d5f696 !1871 FillKernel c78 4 years ago
  i-robot 96b0db9bd5 !1877 step info 4 years ago
  lianghao 1e0a3c0bca FillKernel c78 4 years ago
  wuweikang bb2c55fac8 add copy graph 4 years ago
  wuweikang e51ffe2f54 fix mem leak 4 years ago
  i-robot b4b7c454e6 !1882 code_sync_0626_inc 4 years ago
  dingpeifei a8420bd064 code_sync_0626_inc 4 years ago
  i-robot f565c42679 !1880 code_sync_0626 4 years ago
  dingpeifei 3a4112a159 code_sync_0626 4 years ago
  mindspore-ci-bot e0ee4afa7d !1878 update submodule 4 years ago
  zhupuxu be6b3a176f step info 4 years ago
  mindspore-ci-bot ee744e6bae !1872 Create NodeExecute on-demand 4 years ago
  mindspore-ci-bot 17370647b4 !1870 update protobuf to 3.13.0 4 years ago
  李磊 0cf2a59463 update version of protobuf to v3.13.0 4 years ago
  chuxing e2cad9c2ec fixed ad3e707 from https://gitee.com/mindspore/graphengine/pulls/1821 4 years ago
  mindspore-ci-bot dd2ba23718 !1867 add atc_params: check_report for ConvertModelToJson 4 years ago
  mindspore-ci-bot 4d9fe5505a !1862 add op_precision_mode option and support op_debug_level = 4 4 years ago
  mindspore-ci-bot 5c51e07d61 !1864 Fix mem leak and recursive depth protection. 4 years ago
  lichun 6fd9337505 add atc_params: check_report for ConvertModelToJson 4 years ago
  zhaozhixuan 2a0a6eaf2c Fix bug. 4 years ago
  zhaozhixuan cd9869c99d Fix bug. 4 years ago
  mindspore-ci-bot 122a05cda3 !1863 fix opt info 4 years ago
  zhaozhixuan a562b4b6be Fix mem leak and recursive depth protection. 4 years ago
  zhaozhixuan a431199716 Fix mem leak and recursive depth protection. 4 years ago
  陈华 d3bda362d9 fix opt info 4 years ago
  mindspore-ci-bot 83dffb39f6 !1859 skip control flow op when replace node with empty tensor 4 years ago
  mindspore-ci-bot 26fc202953 !1848 add op_precision_mode option and support op_debug_level = 4 4 years ago
  wangzhengjun 49c2eadb7c skip control flow op when replace node with empty tensor 4 years ago
  mindspore-ci-bot 3aec2fa1c8 !1850 fix cmetric 4 years ago
  mindspore-ci-bot f410cde8f2 !1835 move opt to ge_compile 4 years ago
  陈华 310959e5d9 move to ge_compile 4 years ago
  mindspore-ci-bot 802f671006 !1826 opt info 4 years ago
  陈华 bba62ec5f3 opt_info 4 years ago
  wqtshg 0854139134 update submodule 4 years ago
  mindspore-ci-bot 9084063b48 !1818 Revert 'Pull Request !1806 : add header targets for link' 4 years ago
  mindspore-ci-bot 287d6db86f !1817 update submodule file 4 years ago
  王涛 a70ab7c91a Revert 'Pull Request !1806 : add header targets for link' 4 years ago
  王涛 65ba89fb34 update .gitmodules. 4 years ago
  计晨 3e510ffc8d !1815 Revert 'Pull Request !1784 : Create NodeExecute on-demand' 4 years ago
  储星 c2a1076a87 Revert 'Pull Request !1784 : Create NodeExecute on-demand' 4 years ago
  mindspore-ci-bot 3e57016ffa !1806 add header targets for link 4 years ago
  mindspore-ci-bot 2d5ad3a72e !1808 code_sync_0617_update 4 years ago
  mindspore-ci-bot 076e4222ed !1807 Revert 'Pull Request ls : Adaptation rectification of op_tiling.' 4 years ago
  dingpeifei ec89f88f1f code_sync_0617_update 4 years ago
  zhaozhixuan 676ce23b55 Revert 'Pull Request ls : Adaptation rectification of op_tiling.' 4 years ago
  wangkai 6f130e2290 add link header targets 4 years ago
  mindspore-ci-bot 002f4da441 !1803 code_sync_0617_inc 4 years ago
  dingpeifei ac2ce17a9a code_sync_0617_inc 4 years ago
  zhangzhenghai f71c51e18c !1798 code_sync_0617 4 years ago
  mindspore-ci-bot 7610fa5393 !1712 Adaptation rectification of op_tiling. 4 years ago
  dingpeifei da8b5170e7 code_sync_0617_1 4 years ago
  mindspore-ci-bot 1e9558e8dd !1784 Create NodeExecute on-demand 4 years ago
  zhaozhixuan fd51637c46 Fix zip bug. 4 years ago
  zhaozhixuan bd1beee90c Fix zip bug. 4 years ago
  mindspore-ci-bot caba8dbf6b !1796 bugfix for restore context 4 years ago
  mindspore-ci-bot b59fbaca6b !1790 fix sc 4 years ago
  zhaozhixuan 9476853d22 Adaptation rectification of op_tiling. 4 years ago
  mindspore-ci-bot b0a017d406 !1785 Optimize performance of single_op executor. 4 years ago
  mindspore-ci-bot 535b674d1a !1791 modify dump config 4 years ago
  mindspore-ci-bot c173f92bcc !1797 Remove reduplicated useless proto 4 years ago
  zhangxiaokun 1bed26c72e Remove reduplicated useless proto 4 years ago
  mindspore-ci-bot 82d489a5e6 !1794 update submodule metadef 4 years ago
  y00500818 246d7e4fd8 bugfix for restore context 4 years ago
  wq160 5bcb04dfb7 update submodule 4 years ago
  zhaozhixuan 4bc0f6f2af Fix bug. 4 years ago
  zhaozhixuan b17eafe3db Fix bug. 4 years ago
  zhaozhixuan 23c8a0d581 Fix ut. 4 years ago
  mindspore-ci-bot c4610153e6 !1787 ge code for 1981 4 years ago
  zhou_chao1993 6927a8eef3 modif dump config 4 years ago
  zhou_lili 116167dc88 ge code for 1981 4 years ago
  zhaozhixuan 58086ab187 Release mem. 4 years ago
  zhaozhixuan 69da59b6b7 Fix ut. 4 years ago
  zhaozhixuan d0e83c26a7 Merge branch 'my_dev4' of https://gitee.com/zhao_zhixuan/graphengine into my_dev4 4 years ago
  zhaozhixuan d1eba02e1e Fix ut. 4 years ago
  xchu42 b64048a39f Init NodeExecutor on demand 4 years ago
  zhaozhixuan 24eedfa3b4 Fix ut. 4 years ago
  zhaozhixuan 7ce31b2e0e Fix ut. 4 years ago
  zhaozhixuan 0c2d07eb72 Fix ut. 4 years ago
  zhaozhixuan e765135b86 Merge branch 'my_dev4' of https://gitee.com/zhao_zhixuan/graphengine into my_dev4 4 years ago
  zhaozhixuan 492d36b237 Fix ut. 4 years ago
  zhaozhixuan 3a54dc6dd8 Merge https://gitee.com/mindspore/graphengine into my_dev4 4 years ago
  zhaozhixuan ab7334ed78 Release context in execute end. 4 years ago
  mindspore-ci-bot f7370bc074 !1788 enable optimization 4 years ago
  mindspore-ci-bot 78bbbda571 !1783 Fix dynamic shape partition 4 years ago
  zhaozhixuan 181cd5891b Release context in execute end. 4 years ago
  wangzhengjun 367774c5b0 enable optimization 4 years ago
  zhangxiaokun f578e8fff4 Fix NodeState for UT 4 years ago
  zhangxiaokun ab65075326 Add Init to NodeState 4 years ago
  zhangxiaokun 8852766766 Fix hccl_node_executor_unittest 4 years ago
  zhangxiaokun e85bbe2181 Fix dynamic shape partition 4 years ago
  mindspore-ci-bot 5393b67679 !1774 ut fix 4 years ago
  zhaozhixuan b35412f5ea Add ut. 4 years ago
  zhaozhixuan 1ab9ae32dc Add ut. 4 years ago
  zhaozhixuan 13c98395e2 Add ut. 4 years ago
  zhaozhixuan 4c3c819129 Optimize performance of single_op executor. 4 years ago
  mindspore-ci-bot 83421b6c16 !1775 run_flag 4 years ago
  lianghao 1847efba53 run_flag 4 years ago
  mindspore-ci-bot 38bea2e8a6 !1780 set step profiling 4 years ago
  zhupuxu 0be66d0dca set step 4 years ago
  zhengyuanhua d8ba1fb2c0 remove graph ut form ge 4 years ago
  wjm eee4d7b492 fix safe 4 years ago
  wjm 9e33784220 Merge branch 'master' of gitee.com:jiming6/graphengine 4 years ago
  wjm 2abf8be621 fix sc 4 years ago
  wjm f826880f75 fix sc 4 years ago
  mindspore-ci-bot 3442c74c1e !1748 code_sync_0604_rt_ffts.h 4 years ago
  dingpeifei bc93b5b742 code_sync_0604 4 years ago
  mindspore-ci-bot 94b64c3628 !1744 code_sync_0604 4 years ago
  dingpeifei ba20c7d22c code_sync_0604_inc 4 years ago
  mindspore-ci-bot 61ef430599 !1742 code_sync_0604 4 years ago
  dingpeifei 267878bf95 code_sync_0604 4 years ago
  mindspore-ci-bot c5be7f3198 !1677 code_sync_0520_inc 4 years ago
  dingpeifei 2089d50d14 run ge head 4 years ago
  mindspore-ci-bot 840fa69b81 !1674 code_sync_0520 4 years ago
  dingpeifei 55059dae11 Merge remote-tracking branch 'origin/release' into code_sync_0520 4 years ago
  mindspore-ci-bot f1f6a01ffc !1603 code_sync_0428_inc 4 years ago
  dingpeifei 452ec62cb2 code_sync_0428_inc 4 years ago
  mindspore-ci-bot fcf0e310ae !1600 code_sync_0428 4 years ago
  dingpeifei 38621eb9ed code_sync_0428 4 years ago
  dingpeifei 600105f297 Merge remote-tracking branch 'origin/release' into code_sync_0428 4 years ago
  mindspore-ci-bot 708f523f05 !1561 code_sync_0420_inc 4 years ago
  dingpeifei 5b9393e323 code_sync_0420_inc 4 years ago
  mindspore-ci-bot 10b87bb2cf !1556 code_sync_0420 4 years ago
  dingpeifei 4d0b6f9d9a code_sync_0420 4 years ago
  mindspore-ci-bot 8258636925 !1555 code_sync_0420 4 years ago
  dingpeifei 55f1e60969 code_sync_0420 4 years ago
  dingpeifei 656d909b35 Merge remote-tracking branch 'origin/release' into code_sync_0420 4 years ago
  mindspore-ci-bot 40e5c42a12 !1267 workspace of comm op can be reused 4 years ago
  zhoufeng d6308151e0 reuse workspace memory of hccl op 4 years ago
  mindspore-ci-bot f65be61197 !1250 update ge 4 years ago
  shenwei41 6a82dd1947 update ge 4 years ago
  mindspore-ci-bot 6ce82eff9b !1242 update ge Compile Error 4 years ago
  shenwei41 948000fe5b fix ge Compile Error 4 years ago
  mindspore-ci-bot c0f3dcb4f4 !1241 update include header files 0311 4 years ago
  shenwei41 54a48678ae update include headers 0311 4 years ago
  mindspore-ci-bot 8737b1843d !1234 update ge_0311 4 years ago
  shenwei41 10faa7b052 change 4 years ago
  shenwei41 5dd83a9abd Merge remote-tracking branch 'origin/release' into code_sync_0311 4 years ago
  mindspore-ci-bot 92286b21ec !1155 update metadef 4 years ago
  shenwei41 a1795d6554 update metdef 4 years ago
  mindspore-ci-bot f37d94ba21 !1153 update_headerfiles_0225 4 years ago
  shenwei41 3e5f9dac74 update include headers_0225 4 years ago
  mindspore-ci-bot d487b4781d !1150 update graphengine_0225 4 years ago
  shenwei41 0ebcb55f3f Merge remote-tracking branch 'origin/release' into code_sync_0225 4 years ago
  mindspore-ci-bot 210a007d8e !1139 update_headerfiles_0222 4 years ago
  shenwei41 cc456d5803 update header files 0222 4 years ago
  mindspore-ci-bot b9050a8c51 !1133 update graphengine_0222 4 years ago
  shenwei41 028f47826e Merge remote-tracking branch 'origin/release' into code_sync_0222 4 years ago
  mindspore-ci-bot 6ebc4745de !1104 modify pkg lib 4 years ago
  changzherui aa4045b31d modify pkg lib 4 years ago
  mindspore-ci-bot 8dc712ca01 !1096 update 4 years ago
  shenwei41 de4224d3ba update 4 years ago
  mindspore-ci-bot d28090d511 !1095 update metadef 4 years ago
  shenwei41 543fd0a0e8 update metadef 4 years ago
  mindspore-ci-bot 06894dcd48 !1093 update prebuild 4 years ago
  shenwei41 29e380abf6 update prebuild 4 years ago
  mindspore-ci-bot e7036bb984 !1079 update include file 4 years ago
  shenwei41 cffc6b2e2e update include file 4 years ago
  mindspore-ci-bot ea0e2eadad !1071 update commit id 4 years ago
  shenwei41 c332519b2e Merge remote-tracking branch 'origin/release' into code_sync_0203 4 years ago
  mindspore-ci-bot 08e4e60663 !1056 modify error_codes.h 4 years ago
  changzherui b4539d54cd modify error_codes.h 4 years ago
  mindspore-ci-bot 2d95c17f11 !1040 sync code 0126 4 years ago
  changzherui 46dbe7a2fc Merge remote-tracking branch 'upstream/release' into code_sync_126 4 years ago
  mindspore-ci-bot dfa4dd4acd !1024 Synchronization code .h 0125 4 years ago
  changzherui ed7e35f927 sync code 0125 .h 4 years ago
  mindspore-ci-bot 7183c03452 !1022 Synchronization code 0125 4 years ago
  changzherui a8b7570e71 Merge remote-tracking branch 'upstream/release' into code_sync_0125 4 years ago
  mindspore-ci-bot 687d5f2ab1 !966 sync code 0116 h 4 years ago
  yanghaoran e3b32cd2a0 clang-format 4 years ago
  changzherui 411e71f1f3 sync code h 4 years ago
  mindspore-ci-bot 50f17e37d9 !963 code sync 0116 4 years ago
  changzherui 2b200b25ed sync code 0116 4 years ago
  mindspore-ci-bot 20a0326976 !567 prioritize json downloading from gitee 4 years ago
  yanghaoran d77f36e017 prioritize json downloading from gitee 4 years ago
  mindspore-ci-bot 75572bb987 !562 Synchronize latest Ascend software suite 09 Dec 2020 4 years ago
  yanghaoran d5a82a7f98 Synchronize latest Ascend software suite 09 Dec 2020 4 years ago
  mindspore-ci-bot 24b53b9282 !546 fix geruntime missing files and error codes 4 years ago
  yanghaoran d731918198 fix geruntime missing files and error codes 4 years ago
  mindspore-ci-bot 8712387b3c !533 update headers for release branch, for mindspore use 4 years ago
  mindspore-ci-bot 22ac2e8c14 !537 find libraries from both atc and fwk paths 4 years ago
  yanghaoran fec2e70eda find libraries from both atc and fwk paths 4 years ago
  yanghaoran 8011e1ea9a update headers 4 years ago
100 changed files with 1604 additions and 4464 deletions
Unified View
  1. +2 -2 .gitmodules
  2. +3 -3 CMakeLists.txt
  3. +4 -4 build.sh
  4. +0 -4 cmake/external_libs/json.cmake
  5. +6 -6 cmake/external_libs/protobuf_shared.cmake
  6. +5 -7 cmake/external_libs/protobuf_static.cmake
  7. +5 -7 cmake/external_libs/protoc.cmake
  8. +10 -0 ge/CMakeLists.txt
  9. +0 -1 ge/client/proto/ge_api.proto
  10. +0 -193 ge/client/proto/ge_ir.proto
  11. +0 -140 ge/client/proto/insert_op.proto
  12. +0 -396 ge/client/proto/om.proto
  13. +0 -179 ge/client/proto/task.proto
  14. +1 -0 ge/common/CMakeLists.txt
  15. +2 -2 ge/common/dump/dump_manager.cc
  16. +1 -0 ge/common/dump/exception_dumper.cc
  17. +10 -2 ge/common/ge/tbe_plugin_manager.cc
  18. +2 -1 ge/common/ge/tbe_plugin_manager.h
  19. +38 -0 ge/common/profiling/ge_profiling.cc
  20. +0 -193 ge/common/proto/ge_ir.proto
  21. +0 -140 ge/common/proto/insert_op.proto
  22. +0 -396 ge/common/proto/om.proto
  23. +0 -75 ge/common/proto/op_mapping.proto
  24. +0 -179 ge/common/proto/task.proto
  25. +0 -70 ge/common/proto/tensorflow/attr_value.proto
  26. +0 -108 ge/common/proto/tensorflow/function.proto
  27. +0 -64 ge/common/proto/tensorflow/graph.proto
  28. +0 -22 ge/common/proto/tensorflow/graph_library.proto
  29. +0 -71 ge/common/proto/tensorflow/node_def.proto
  30. +0 -172 ge/common/proto/tensorflow/op_def.proto
  31. +0 -37 ge/common/proto/tensorflow/resource_handle.proto
  32. +0 -102 ge/common/proto/tensorflow/tensor.proto
  33. +0 -53 ge/common/proto/tensorflow/tensor_shape.proto
  34. +0 -82 ge/common/proto/tensorflow/types.proto
  35. +0 -39 ge/common/proto/tensorflow/versions.proto
  36. +1 -0 ge/executor/CMakeLists.txt
  37. +0 -113 ge/executor/proto/dump_task.proto
  38. +0 -193 ge/executor/proto/ge_ir.proto
  39. +0 -140 ge/executor/proto/insert_op.proto
  40. +0 -396 ge/executor/proto/om.proto
  41. +0 -75 ge/executor/proto/op_mapping.proto
  42. +0 -179 ge/executor/proto/task.proto
  43. +4 -7 ge/ge_local_engine/engine/host_cpu_engine.cc
  44. +0 -179 ge/ge_local_engine/proto/task.proto
  45. +58 -0 ge/ge_opt_info/ge_opt_info.cc
  46. +31 -0 ge/ge_opt_info/ge_opt_info.h
  47. +1 -0 ge/ge_runtime/CMakeLists.txt
  48. +1 -0 ge/ge_runtime/runtime_model.cc
  49. +2 -14 ge/ge_runtime/task/hccl_task.cc
  50. +42 -16 ge/ge_runtime/task/label_goto_task.cc
  51. +10 -6 ge/ge_runtime/task/label_goto_task.h
  52. +119 -0 ge/ge_runtime/task/label_manager.cc
  53. +54 -0 ge/ge_runtime/task/label_manager.h
  54. +11 -14 ge/ge_runtime/task/label_switch_task.cc
  55. +4 -2 ge/ge_runtime/task/label_switch_task.h
  56. +6 -5 ge/generator/ge_generator.cc
  57. +5 -0 ge/graph/build/label_allocator.cc
  58. +5 -0 ge/graph/build/logical_stream_allocator.cc
  59. +1 -1 ge/graph/build/model_builder.cc
  60. +10 -1 ge/graph/build/stream_allocator.cc
  61. +47 -29 ge/graph/build/task_generator.cc
  62. +2 -1 ge/graph/build/task_generator.h
  63. +0 -15 ge/graph/common/omg_util.cc
  64. +0 -9 ge/graph/common/omg_util.h
  65. +158 -49 ge/graph/load/model_manager/davinci_model.cc
  66. +6 -0 ge/graph/load/model_manager/davinci_model.h
  67. +14 -9 ge/graph/load/model_manager/model_manager.cc
  68. +393 -0 ge/graph/load/model_manager/task_info/ffts_task_info.cc
  69. +66 -0 ge/graph/load/model_manager/task_info/ffts_task_info.h
  70. +2 -0 ge/graph/load/model_manager/task_info/kernel_task_info.cc
  71. +20 -26 ge/graph/manager/graph_manager.cc
  72. +1 -1 ge/graph/manager/graph_manager.h
  73. +0 -2 ge/graph/optimize/graph_optimize.cc
  74. +29 -7 ge/graph/partition/dynamic_shape_partition.cc
  75. +4 -2 ge/graph/partition/dynamic_shape_partition.h
  76. +14 -7 ge/graph/partition/graph_partition.cc
  77. +63 -72 ge/graph/passes/mark_force_unknown_for_cond_pass.cc
  78. +11 -0 ge/graph/passes/mark_force_unknown_for_cond_pass.h
  79. +6 -0 ge/graph/passes/mark_graph_unknown_status_pass.cc
  80. +2 -3 ge/graph/passes/merge_to_stream_merge_pass.cc
  81. +23 -6 ge/graph/passes/next_iteration_pass.cc
  82. +39 -19 ge/graph/passes/parallel_group_pass.cc
  83. +1 -0 ge/graph/passes/parallel_group_pass.h
  84. +20 -0 ge/graph/passes/replace_with_empty_const_pass.cc
  85. +9 -8 ge/graph/passes/switch_to_stream_switch_pass.cc
  86. +66 -53 ge/graph/preprocess/graph_preprocess.cc
  87. +2 -1 ge/graph/preprocess/graph_preprocess.h
  88. +1 -0 ge/graph/preprocess/insert_op/util_insert_aipp_op.cc
  89. +1 -1 ge/graph/preprocess/multi_batch_copy_graph.cc
  90. +8 -0 ge/host_kernels/fill_kernel.cc
  91. +5 -3 ge/hybrid/executor/hybrid_model_async_executor.cc
  92. +8 -7 ge/hybrid/executor/hybrid_model_executor.cc
  93. +2 -1 ge/hybrid/executor/hybrid_model_executor.h
  94. +79 -8 ge/hybrid/executor/node_state.cc
  95. +12 -8 ge/hybrid/executor/node_state.h
  96. +19 -8 ge/hybrid/executor/subgraph_context.cc
  97. +3 -2 ge/hybrid/executor/subgraph_context.h
  98. +11 -15 ge/hybrid/executor/subgraph_executor.cc
  99. +2 -1 ge/hybrid/executor/subgraph_executor.h
  100. +1 -0 ge/hybrid/executor/worker/execution_engine.cc

+2 -2  .gitmodules

@@ -1,8 +1,8 @@
 [submodule "parser"]
     path = parser
     url = https://gitee.com/ascend/parser.git
-    branch = master
+    branch = r1.5.0
 [submodule "metadef"]
     path = metadef
     url = https://gitee.com/ascend/metadef.git
-    branch = master
+    branch = r1.5.0

+3 -3  CMakeLists.txt

@@ -95,6 +95,7 @@ else ()
     #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
   else()
     find_module(slog libalog.so ${ASCEND_ATC_DIR})
+    find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR})
     find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
     if(PLATFORM STREQUAL "train")
       find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
@@ -128,7 +129,6 @@ else ()
     message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
   endif()
 endif()
-
 set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
 set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser)
 set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
@@ -161,6 +161,7 @@ else ()
 elseif(ENABLE_MS_TESTCASES)
   include(cmake/external_libs/protobuf_static.cmake)
   include(cmake/external_libs/protoc.cmake)
+  include(cmake/external_libs/json.cmake)
   include(cmake/external_libs/securec.cmake)
   include(cmake/FindModule.cmake)
   include(cmake/intf_pub_linux.cmake)
@@ -178,5 +179,4 @@ else ()
 endif()
 
-
 add_subdirectory(ge)
 endif ()

+4 -4  build.sh

@@ -355,13 +355,13 @@ generate_package()
 
 
   if [ "x${PLATFORM}" = "xtrain" ]
   then
-    tar -cf graphengine_lib.tar fwkacllib
+    tar -zcf graphengine_lib.tar fwkacllib
   elif [ "x${PLATFORM}" = "xinference" ]
   then
-    tar -cf graphengine_lib.tar acllib atc
+    tar -zcf graphengine_lib.tar acllib atc
   elif [ "x${PLATFORM}" = "xall" ]
   then
-    tar -cf graphengine_lib.tar fwkacllib acllib atc
+    tar -zcf graphengine_lib.tar fwkacllib acllib atc
   fi
 }

@@ -371,6 +371,6 @@ elif [ "X$MINDSPORE_MODE" = "Xon" ]
 then
   cd "${OUTPUT_PATH}"
   find ./ -name graphengine_lib.tar -exec rm {} \;
-  tar -cf graphengine_lib.tar lib
+  tar -zcf graphengine_lib.tar lib
 fi
 echo "---------------- GraphEngine package archive generated ----------------"

+0 -4  cmake/external_libs/json.cmake

@@ -9,10 +9,6 @@ if (GE_PB_PKG)
   set(REQ_URL "${GE_PB_PKG}/libs/ge_nlohmann_json/include.zip")
   set(MD5 "0dc903888211db3a0f170304cd9f3a89")
   set(JSON_INCLUDE_DIR ${JSON_SRC_DIR})
-#elseif (ENABLE_GITEE)
-#  set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip")
-#  set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7")
-#  set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include")
 else()
   set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip")
   set(MD5 "0dc903888211db3a0f170304cd9f3a89")


+6 -6  cmake/external_libs/protobuf_shared.cmake

@@ -11,14 +11,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
   message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
 endif()
 if (GE_PB_PKG)
-  set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+  set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
   if (ENABLE_GITEE)
-    set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-    set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+    set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+    set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
   else()
-    set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-    set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+    set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+    set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
   endif ()
 endif()
 

@@ -58,7 +58,7 @@ target_include_directories(ascend_protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/
 set(INSTALL_BASE_DIR "")
 set(INSTALL_LIBRARY_DIR lib)
 
-install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.8.0.0 OPTIONAL
+install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.13.0.0 OPTIONAL
         DESTINATION ${INSTALL_LIBRARY_DIR})
 install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL
         DESTINATION ${INSTALL_LIBRARY_DIR})


+5 -7  cmake/external_libs/protobuf_static.cmake

@@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
 endif()
 
 if(GE_PB_PKG)
-  set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+  set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
   if (ENABLE_GITEE)
-    set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-    set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+    set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+    set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
   else()
-    set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-    set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+    set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+    set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
   endif ()
 endif()

@@ -29,8 +29,6 @@ set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
 set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static)
 ExternalProject_Add(protobuf_static_build
   URL ${REQ_URL}
-  #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
-  #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0
   TLS_VERIFY OFF
   CONFIGURE_COMMAND ${CMAKE_COMMAND}
   -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}


+5 -7  cmake/external_libs/protoc.cmake

@@ -13,14 +13,14 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
 endif()
 
 if(GE_PB_PKG)
-  set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.8.0.tar.gz")
+  set(REQ_URL "${GE_PB_PKG}/libs/protobuf/v3.13.0.tar.gz")
 else()
   if (ENABLE_GITEE)
-    set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
-    set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
+    set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
+    set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
   else()
-    set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
-    set(MD5 "3d9e32700639618a4d2d342c99d4507a")
+    set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
+    set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
   endif ()
 endif()

@@ -28,8 +28,6 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst
 set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
 ExternalProject_Add(protoc_build
   URL ${REQ_URL}
-  #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
-  #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0
   TLS_VERIFY OFF
   CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake
   BUILD_COMMAND $(MAKE)


+10 -0  ge/CMakeLists.txt

@@ -174,6 +174,7 @@ set(TRAIN_SRC_LIST
   "graph/load/model_manager/task_info/model_exit_task_info.cc"
   "graph/load/model_manager/task_info/event_record_task_info.cc"
   "graph/load/model_manager/task_info/event_wait_task_info.cc"
+  "graph/load/model_manager/task_info/ffts_task_info.cc"
   "graph/load/model_manager/task_info/fusion_start_task_info.cc"
   "graph/load/model_manager/task_info/fusion_stop_task_info.cc"
   "graph/load/model_manager/task_info/hccl_task_info.cc"
@@ -433,6 +434,7 @@ set(TRAIN_SRC_LIST
   "graph/build/memory/max_block_mem_assigner.cc"
   "graph/build/memory/var_mem_assign_util.cc"
   "graph/build/memory/buffer_pool_mem_assigner.cc"
+  "ge_opt_info/ge_opt_info.cc"
 )
 
 set(INFER_SRC_LIST
@@ -662,6 +664,7 @@ set(INFER_SRC_LIST
   "graph/load/model_manager/task_info/task_info.cc"
   "graph/load/model_manager/task_info/event_record_task_info.cc"
   "graph/load/model_manager/task_info/event_wait_task_info.cc"
+  "graph/load/model_manager/task_info/ffts_task_info.cc"
   "graph/load/model_manager/task_info/fusion_start_task_info.cc"
   "graph/load/model_manager/task_info/fusion_stop_task_info.cc"
   "graph/load/model_manager/task_info/kernel_ex_task_info.cc"
@@ -709,6 +712,7 @@ set(INFER_SRC_LIST
   "graph/build/memory/max_block_mem_assigner.cc"
   "graph/build/memory/var_mem_assign_util.cc"
   "graph/build/memory/buffer_pool_mem_assigner.cc"
+  "ge_opt_info/ge_opt_info.cc"
 )
 
 if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES)
@@ -770,11 +774,13 @@ target_include_directories(ge_runner SYSTEM PRIVATE
   ${GE_CODE_DIR}/../inc/cce
   ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
   ${GE_CODE_DIR}/../abl/adump/external
+  ${GE_CODE_DIR}/../abl/licctrl
   #### blue zone
   ${ASCEND_DIR}/driver/include
   ${ASCEND_DIR}/fwkacllib/include
   ${GE_CODE_DIR}/third_party/fwkacllib/inc
   ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
+  ${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info
 )
 
 target_link_options(ge_runner PRIVATE
@@ -797,6 +803,7 @@ target_link_libraries(ge_runner PRIVATE
   runtime
   error_manager
   ascend_hal_stub
+  opt_feature
   -Wl,--as-needed
   json
   -lrt
@@ -851,11 +858,13 @@ target_include_directories(ge_compiler SYSTEM PRIVATE
   ${GE_CODE_DIR}/../inc/cce
   ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
   ${GE_CODE_DIR}/../abl/adump/external
+  ${GE_CODE_DIR}/../abl/licctrl
   #### blue zone ####
   ${ASCEND_DIR}/driver/include
   ${ASCEND_DIR}/fwkacllib/include
   ${GE_CODE_DIR}/third_party/fwkacllib/inc
   ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
+  ${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info
 )
 
 target_link_options(ge_compiler PRIVATE
@@ -875,6 +884,7 @@ target_link_libraries(ge_compiler PRIVATE
   error_manager
   slog
   runtime_compile
+  opt_feature
   -Wl,--as-needed
   json
   -lrt

+0 -1  ge/client/proto/ge_api.proto

@@ -1 +0,0 @@
-../../proto/ge_api.proto

+0 -193  ge/client/proto/ge_ir.proto

@@ -1,193 +0,0 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
DT_BF16 = 27; // bf16 type
DT_INT4 = 28; // int4 type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto version
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef

map<string, AttrDef> attr = 11; // Extended field
}


+0 -140  ge/client/proto/insert_op.proto

@@ -1,140 +0,0 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP mode: distinguishes static AIPP from dynamic AIPP
AippMode aipp_mode = 1;

// related_input_rank is required; integer; valid range >= 0 and <= the number of input Data operators; default 0.
// Identifies which model input AIPP processing applies to. For example, if the model has two inputs and AIPP is needed on the second one, set related_input_rank to 1.
uint32 related_input_rank = 2;

// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;

// input_edge_idx is optional; integer; valid range >= 0.
// It applies different AIPP processing to different outputs of the Data operator. If unset, AIPP is applied by default to all output edges of the model input specified by related_input_rank.
// The configured value must be <= the number of output edges of the Data operator.
repeated uint32 input_edge_idx = 3;

// [Begin] Dynamic AIPP parameters; ignored when static AIPP is configured
uint32 max_src_image_size = 4;

// Whether rotation is supported. Off by default; enabling rotation incurs extra memory and performance overhead
bool support_rotation = 5;

// [End] Dynamic AIPP parameters


// [Begin] Static AIPP parameters; ignored when dynamic AIPP is configured
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
float padding_value = 72;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] Static AIPP parameters

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; // dynamic batch
resolution = 1; // dynamic resolution, reserved for extension
}

MultiShapeMode mode = 1; // operator mode
uint32 related_input_rank = 2; // which input the new operator is inserted at


repeated uint32 batch_list = 11; // batch_list values; the number of entries is between 2 and 8
}

+0 -396  ge/client/proto/om.proto

@@ -1,396 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4; // optional, [default = 1e-5]
bool use_global_stats = 5; // optional, true by default, testing mode
float moving_average_fraction = 6; // optional, [default = .999];

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // In default, we will use NPU.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load direction 0:col load 1:row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}


+0 -179  ge/client/proto/task.proto

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+1 -0  ge/common/CMakeLists.txt

@@ -106,6 +106,7 @@ target_link_libraries(ge_common PRIVATE
   c_sec
   error_manager
   slog
+  opt_feature
   -Wl,--as-needed
   json
   $<$<NOT:$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-lrt>


+2 -2  ge/common/dump/dump_manager.cc

@@ -33,7 +33,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpManager &DumpManager::GetIn
 
 bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump_properties) {
   if (dump_config.dump_status.empty() && dump_config.dump_debug.empty()) {
-    dump_properties_map_.emplace(kInferSessionId, dump_properties);
+    dump_properties_map_[kInferSessionId] = dump_properties;
     GELOGI("Dump does not open");
     return false;
   }
@@ -41,7 +41,7 @@ bool DumpManager::NeedDoDump(const DumpConfig &dump_config, DumpProperties &dump
   if ((dump_config.dump_status == kDumpoff || dump_config.dump_status == kDumpOFF) &&
       dump_config.dump_debug == kDumpoff) {
     dump_properties.ClearDumpPropertyValue();
-    dump_properties_map_.emplace(kInferSessionId, dump_properties);
+    dump_properties_map_[kInferSessionId] = dump_properties;
     return false;
   }
   if (dump_config.dump_status == kDumpOn && dump_config.dump_debug == kDumpOn) {
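Both hunks above replace dump_properties_map_.emplace(...) with operator[]: std::map::emplace does nothing when the key is already present, so re-issuing a dump config for kInferSessionId would silently keep the stale properties, whereas operator[] always assigns. A minimal standalone sketch of the difference (a plain std::map stands in for the real dump_properties_map_):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<uint64_t, std::string> dump_properties_map;
  dump_properties_map.emplace(0, "old config");
  dump_properties_map.emplace(0, "new config");  // no-op: key 0 already exists
  std::cout << dump_properties_map[0] << '\n';   // prints "old config"
  dump_properties_map[0] = "new config";         // operator[] overwrites unconditionally
  std::cout << dump_properties_map[0] << '\n';   // prints "new config"
  return 0;
}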


+1 -0  ge/common/dump/exception_dumper.cc

@@ -161,6 +161,7 @@ Status ExceptionDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> &ex
 
   uint64_t proto_size = dump_data.ByteSizeLong();
   std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]);
+  GE_CHECK_NOTNULL(proto_msg);
   bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size);
   if (!ret || proto_size == 0) {
     REPORT_INNER_ERROR("E19999", "Serialize proto to string fail");
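The one added line guards the nothrow allocation: new (std::nothrow) returns nullptr on failure instead of throwing, so the buffer must be checked before SerializeToArray writes through it. A sketch of the pattern, with a plain null check standing in for the repo's GE_CHECK_NOTNULL macro (whose exact expansion is not shown here):

#include <cstdio>
#include <memory>
#include <new>

bool BuildBuffer(size_t proto_size) {
  // nothrow new: allocation failure yields nullptr rather than std::bad_alloc.
  std::unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]);
  if (proto_msg == nullptr) {  // roughly what GE_CHECK_NOTNULL does: log and return an error
    std::fprintf(stderr, "failed to allocate %zu bytes\n", proto_size);
    return false;
  }
  // ... serialize into proto_msg.get() here ...
  return true;
}

int main() { return BuildBuffer(64) ? 0 : 1; }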


+10 -2  ge/common/ge/tbe_plugin_manager.cc

@@ -104,7 +104,15 @@ void TBEPluginManager::ProcessSoFullName(vector<string> &file_list, string &caff
   }
 }
 
-void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path) {
+void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_list,
+                                    string &caffe_parser_path, int recursive_depth) {
+  static const int kMaxRecursiveDepth = 20;  // For recursive depth protection
+
+  if (recursive_depth >= kMaxRecursiveDepth) {
+    GELOGW("Recursive depth is become %d, Please check input!", recursive_depth);
+    return;
+  }
+
   // Path, change to absolute path
   string real_path = RealPath(path.c_str());
   // Plugin path does not exist
@@ -138,7 +146,7 @@ void TBEPluginManager::FindParserSo(const string &path, vector<string> &file_lis
       ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff,
                         aicpu_host_so_suff);
     } else {
-      FindParserSo(full_name, file_list, caffe_parser_path);
+      FindParserSo(full_name, file_list, caffe_parser_path, recursive_depth + 1);
     }
   }
   mmScandirFree(entries, ret);


ge/common/ge/tbe_plugin_manager.h (+2 -1)

@@ -57,7 +57,8 @@ class TBEPluginManager {
   static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name,
                                 const string &caffe_parser_so_suff, const string &aicpu_so_suff,
                                 const string &aicpu_host_so_suff);
-  static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path);
+  static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path,
+                           int recursive_depth = 0);
   static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path);
   static void GetCustomOpPath(std::string &customop_path);
   void LoadCustomOpLib();
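The two hunks above are one fix: FindParserSo gains a recursive_depth counter, defaulted to 0 in the header so existing call sites compile unchanged, and bails out once the cap is reached. This bounds traversal of pathologically deep or cyclic (symlink-looped) plugin directories. The same pattern, transplanted onto a plain std::filesystem walk for illustration (the real code uses mmScandir and GE logging, not std::filesystem):

#include <filesystem>
#include <iostream>

namespace fs = std::filesystem;

void Walk(const fs::path &dir, int depth = 0) {
  static const int kMaxRecursiveDepth = 20;  // mirrors the new constant above
  if (depth >= kMaxRecursiveDepth) {
    std::cerr << "recursion depth " << depth << " hit the cap, bailing out\n";
    return;  // breaks cycles from symlink loops and pathological nesting
  }
  for (const auto &entry : fs::directory_iterator(dir)) {
    if (entry.is_directory()) {
      Walk(entry.path(), depth + 1);  // the depth threads through every call
    }
  }
}

int main() { Walk("."); }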


ge/common/profiling/ge_profiling.cc (+38 -0)

@@ -22,6 +22,7 @@
 #include "graph/load/graph_loader.h"
 #include "init/gelib.h"
 #include "framework/common/ge_inner_error_codes.h"
+#include "model/ge_model.h"
 
 namespace {
 const uint32_t kDeviceListIndex = 3;
@@ -42,6 +43,10 @@ const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap = {
     {kProfCommandhandleFinalize, kProfilingFinalize},
     {kProfCommandhandleModelSubscribe, kProfModelSubscribe},
     {kProfCommandhandleModelUnsubscribe, kProfModelUnsubscribe}};
+
+const uint64_t kModelId = ge::INVALID_MODEL_ID;
+const uint16_t kStepStart = 0;
+const uint16_t kStepEnd = 1;
 } // namespace
 
 bool TransProfConfigToParam(const ProfCommandHandleData &profCommand, vector<string> &prof_config_params) {
@@ -216,3 +221,36 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le
   return ge::SUCCESS;
 }
 
+ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) {
+  static bool is_first_run = true;
+  int32_t device_id = 0;
+  rtError_t rt_ret = rtGetDevice(&device_id);
+  if (rt_ret != RT_ERROR_NONE) {
+    GELOGE(rt_ret, "[Get][LogicDeviceId]Failed, ret 0x%X", rt_ret);
+    REPORT_CALL_ERROR("E19999", "Get logic device id failed, ret 0x%X", rt_ret);
+    return ge::FAILED;
+  }
+  if (is_first_run && tag_id == kStepStart) {
+    GE_CHK_STATUS_RET_NOLOG(ge::ProfilingManager::Instance().ProfileStepInfo(index_id,
+                                                                             kModelId,
+                                                                             tag_id,
+                                                                             stream,
+                                                                             device_id));
+    is_first_run = false;
+    return ge::SUCCESS;
+  }
+  if (!is_first_run && tag_id == kStepEnd) {
+    GE_CHK_STATUS_RET_NOLOG(ge::ProfilingManager::Instance().ProfileStepInfo(index_id,
+                                                                             kModelId,
+                                                                             tag_id,
+                                                                             stream,
+                                                                             device_id));
+    is_first_run = true;
+    return ge::SUCCESS;
+  }
+  GELOGE(ge::FAILED, "Param tag_id:%u invalid when is_first_run is %d", tag_id, is_first_run);
+  REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"value", "parameter", "reason"}),
+                     std::vector<std::string>({std::to_string(tag_id), "tag_id",
+                                               "tag id must be 0 when first run, must be 1 when second run"}));
+  return ge::FAILED;
+}
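ProfSetStepInfo above is a two-state latch: the first accepted call must carry the step-start tag (0), which opens a step; the next must carry step-end (1), which closes it; anything out of order is rejected. A reduced model of that toggle, with the device and profiling calls stripped out:

#include <cstdint>
#include <iostream>

enum : uint16_t { kStepStart = 0, kStepEnd = 1 };

bool SetStepInfo(uint16_t tag_id) {
  static bool is_first_run = true;           // same latch as in the diff
  if (is_first_run && tag_id == kStepStart) {
    is_first_run = false;                    // a step is now open
    return true;
  }
  if (!is_first_run && tag_id == kStepEnd) {
    is_first_run = true;                     // step closed, ready for the next
    return true;
  }
  return false;  // out-of-order tag: two starts, or an end with no open step
}

int main() {
  std::cout << SetStepInfo(kStepStart)   // 1: opens a step
            << SetStepInfo(kStepEnd)     // 1: closes it
            << SetStepInfo(kStepEnd)     // 0: no step is open
            << "\n";                     // prints "110"
}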

ge/common/proto/ge_ir.proto (+0 -193)

@@ -1,193 +0,0 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; // int32 type
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
DT_BF16 = 27; // bf16 type
DT_INT4 = 28; // int4 type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR proto version
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition; graph[0] represents the main graph of the model

map<string, AttrDef> attr = 11; // Extended field
}
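A side note on the AttrDef shape in the deleted file above: its oneof value is a tagged union over the supported attribute types, attached to ops by name via map<string, AttrDef>. On the C++ side the same shape can be modeled with std::variant; a sketch covering only a few of the alternatives:

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <variant>

// Tagged union over attribute value types, like AttrDef's oneof.
using AttrValue = std::variant<std::string, int64_t, float, bool>;

int main() {
  std::map<std::string, AttrValue> attr;  // like OpDef's map<string, AttrDef>
  attr["T"] = std::string("float");
  attr["N"] = int64_t{2};
  // visit dispatches on whichever alternative is stored, like a oneof case switch.
  std::visit([](const auto &v) { std::cout << v << "\n"; }, attr["N"]);
}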


ge/common/proto/insert_op.proto (+0 -140)

@@ -1,140 +0,0 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP mode, distinguishing static AIPP from dynamic AIPP
AippMode aipp_mode = 1;

// related_input_rank is required; it is an integer with valid range >= 0 and <= the number of input Data operators, defaulting to 0.
// It identifies which model input receives AIPP processing; e.g. if the model has two inputs and AIPP should be applied to the second, set related_input_rank to 1.
uint32 related_input_rank = 2;

// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;

// input_edge_idx is optional; it is an integer with valid range >= 0.
// It selects which outputs of the Data operator receive distinct AIPP processing. If unset, AIPP is applied by default to all output edges of the model input specified by related_input_rank.
// The configured value must be <= the number of output edges of the Data operator.
repeated uint32 input_edge_idx = 3;

// [Begin] dynamic AIPP parameters, ignored when static AIPP is configured
uint32 max_src_image_size = 4;

// Whether rotation is supported. Disabled by default; enabling rotation incurs extra memory and performance overhead.
bool support_rotation = 5;

// [End] dynamic AIPP parameters


// [Begin] static AIPP parameters, ignored when dynamic AIPP is configured
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
float padding_value = 72;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] static AIPP parameters

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; // dynamic batch
resolution = 1; // dynamic resolution, reserved for extension
}

MultiShapeMode mode = 1; // operator mode
uint32 related_input_rank = 2; // which input the new operator is inserted at


repeated uint32 batch_list = 11; // batch_list values; the number of entries is between 2 and 8
}

ge/common/proto/om.proto (+0 -396)

@@ -1,396 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4; // optional, [default = 1e-5]
bool use_global_stats = 5; // optional, true by default (testing mode)
float moving_average_fraction = 6; // optional, [default = .999];

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // By default, NPU is used.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load direction: 0 = column load, 1 = row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}


ge/common/proto/op_mapping.proto (+0 -75)

@@ -1,75 +0,0 @@
syntax = "proto3";
package toolkit.aicpu.dump;

message Shape {
repeated uint64 dim = 1;
}

message Output {
int32 data_type = 1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
string original_name = 5;
int32 original_output_index = 6;
int32 original_output_data_type = 7;
int32 original_output_format = 8;
uint64 size = 9;
Shape origin_shape = 10;
}

message Input {
int32 data_type =1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
uint64 size = 5;
Shape origin_shape = 6;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
uint64 address = 2;
uint64 size = 3;
}

message Op {
string op_name = 1;
string op_type = 2;
}

message Task {
uint32 task_id = 1;
uint32 stream_id = 2;
Op op = 3;
repeated Output output = 4;
bool end_graph = 5;
repeated Input input = 6;
repeated OpBuffer buffer = 7;
}

message OpMappingInfo {
string dump_path = 1;
oneof model_name_param {
string model_name = 2;
}
oneof model_id_param {
uint32 model_id = 3;
}
oneof step_id {
uint64 step_id_addr = 4;
}
oneof iterations_per_loop {
uint64 iterations_per_loop_addr = 5;
}
oneof loop_cond {
uint64 loop_cond_addr = 6;
}
uint32 flag = 7; // 0x01 load, 0x00 unload
repeated Task task = 8;
string dump_step = 9;
}
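The single-field oneofs in OpMappingInfo above (model_name_param, model_id_param, step_id, and so on) are a standard proto3 idiom: plain scalar fields carry no presence bit, so 0 and "never set" are indistinguishable, whereas a oneof member reports whether it was set. The effect, modeled in C++ with std::optional (names mirror the message but are illustrative):

#include <cstdint>
#include <iostream>
#include <optional>

struct OpMappingInfoSketch {
  uint64_t flag = 0;                     // plain scalar: 0 could mean "unset" or "set to 0"
  std::optional<uint64_t> step_id_addr;  // oneof member: presence is explicit
};

int main() {
  OpMappingInfoSketch info;
  std::cout << (info.step_id_addr.has_value() ? "set" : "unset") << "\n";  // unset
  info.step_id_addr = 0;  // setting it to zero is still detectably "set"
  std::cout << (info.step_id_addr.has_value() ? "set" : "unset") << "\n";  // set
}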

ge/common/proto/task.proto (+0 -179)

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

ge/common/proto/tensorflow/attr_value.proto (+0 -70)

@@ -1,70 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "AttrValueProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "tensor.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing the value for an attr used to configure an Op.
// Comment indicates the corresponding attr type. Only the field matching the
// attr type may be filled.
message AttrValue {
// LINT.IfChange
message ListValue {
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated DataType type = 6 [packed = true]; // "list(type)"
repeated TensorShapeProto shape = 7; // "list(shape)"
repeated TensorProto tensor = 8; // "list(tensor)"
repeated NameAttrList func = 9; // "list(attr)"
}
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)

oneof value {
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
DataType type = 6; // "type"
TensorShapeProto shape = 7; // "shape"
TensorProto tensor = 8; // "tensor"
ListValue list = 1; // any "list(...)"

// "func" represents a function. func.name is a function's name or
// a primitive op's name. func.attr.first is the name of an attr
// defined for that function. func.attr.second is the value for
// that attr in the instantiation.
NameAttrList func = 10;

// This is a placeholder only used in nodes defined inside a
// function. It indicates the attr value will be supplied when
// the function is instantiated. For example, let us suppose a
// node "N" in function "FN". "N" has an attr "A" with value
// placeholder = "foo". When FN is instantiated with attr "foo"
// set to "bar", the instantiated node N's attr A will have been
// given the value "bar".
string placeholder = 9;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NameAttrList {
string name = 1;
map<string, AttrValue> attr = 2;
}

ge/common/proto/tensorflow/function.proto (+0 -108)

@@ -1,108 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "FunctionProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "node_def.proto";
import "op_def.proto";

// A library is a set of named functions.
message FunctionDefLibrary {
repeated FunctionDef function = 1;
repeated GradientDef gradient = 2;
}

// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have binding for every attr defined in the signature.
// * device spec, etc.
message FunctionDef {
// The definition of the function's name, arguments, return values,
// attrs etc.
OpDef signature = 1;

// Attributes specific to this function definition.
map<string, AttrValue> attr = 5;

// NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21.
reserved 2;

// In both of the following fields, there is the need to specify an
// output that is used as either the input to another node (in
// `node_def`) or as a return value of the function (in `ret`).
// Unlike the NodeDefs in GraphDef, we need to be able to specify a
// list in some cases (instead of just single outputs). Also, we
// need to be able to deal with lists of unknown length (so the
// output index may not be known at function definition time). So
// we use the following format instead:
// * "fun_in" where "fun_in" is the name of a function input arg in
// the `signature` field above. This represents that input, whether
// it is a single tensor or a list.
// * "fun_in:0" gives the first element of a function input arg (a
// non-list input is considered a list of length 1 for these
// purposes).
// * "node:out" where "node" is the name of a node in `node_def` and
// "out" is the name one of its op's output arguments (the name
// comes from the OpDef of the node's op). This represents that
// node's output, whether it is a single tensor or a list.
// Note: We enforce that an op's output arguments are never
// renamed in the backwards-compatibility test.
// * "node:out:0" gives the first element of a node output arg (a
// non-list output is considered a list of length 1 for these
// purposes).
//
// NOT CURRENTLY SUPPORTED (but may be in the future):
// * "node:out:-1" gives last element in a node output list
// * "node:out:1:" gives a list with all but the first element in a
// node output list
// * "node:out::-1" gives a list with all but the last element in a
// node output list

// The body of the function. Unlike the NodeDefs in a GraphDef, attrs
// may have values of type `placeholder` and the `input` field uses
// the "output" format above.

// By convention, "op" in node_def is resolved by consulting with a
// user-defined library first. If not resolved, "func" is assumed to
// be a builtin op.
repeated NodeDef node_def = 3;

// A mapping from the output arg names from `signature` to the
// outputs from `node_def` that should be returned by the function.
map<string, string> ret = 4;
}

// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must follow the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g', which is a
// function taking N + M inputs and produces N outputs.
//
// I.e. if we have
// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
// dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i.
message GradientDef {
string function_name = 1; // The function name.
string gradient_func = 2; // The gradient function's name.
}

ge/common/proto/tensorflow/graph.proto (+0 -64)

@@ -1,64 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "GraphProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "node_def.proto";
import "function.proto";
import "versions.proto";

// Represents the graph of operations
message GraphDef {
repeated NodeDef node = 1;

// Compatibility versions of the graph. See core/public/version.h for version
// history. The GraphDef version is distinct from the TensorFlow version, and
// each release of TensorFlow will support a range of GraphDef versions.
VersionDef versions = 4;

// Deprecated single version field; use versions above instead. Since all
// GraphDef changes before "versions" was introduced were forward
// compatible, this field is entirely ignored.
int32 version = 3 [deprecated = true];

// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
//
// "library" provides user-defined functions.
//
// Naming:
// * library.function.name are in a flat namespace.
// NOTE: We may need to change it to be hierarchical to support
// different orgs. E.g.,
// { "/google/nn", { ... }},
// { "/google/vision", { ... }}
// { "/org_foo/module_bar", { ... }}
// map<string, FunctionDefLib> named_lib;
// * If node[i].op is the name of one function in "library",
// node[i] is deemed as a function call. Otherwise, node[i].op
// must be a primitive operation supported by the runtime.
//
//
// Function call semantics:
//
// * The callee may start execution as soon as some of its inputs
// are ready. The caller may want to use Tuple() mechanism to
// ensure all inputs are ready in the same time.
//
// * The consumer of return values may start executing as soon as
// the return values the consumer depends on are ready. The
// consumer may want to use Tuple() mechanism to ensure the
// consumer does not start until all return values of the callee
// function are ready.
FunctionDefLibrary library = 2;
};

ge/common/proto/tensorflow/graph_library.proto (+0 -22)

@@ -1,22 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;

import "graph.proto";

message GeGraphDef {
string name = 1;
GraphDef graph = 2;
}

message GraphDefLibrary {
repeated GeGraphDef graph_def = 1;
};

ge/common/proto/tensorflow/node_def.proto (+0 -71)

@@ -1,71 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "NodeProto";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";

message NodeDef {
// The name given to this operator. Used for naming inputs,
// logging, visualization, etc. Unique within a single GraphDef.
// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
string name = 1;

// The operation name. There may be custom parameters in attrs.
// Op names starting with an underscore are reserved for internal use.
string op = 2;

// Each input is "node:src_output" with "node" being a string name and
// "src_output" indicating which output tensor to use from "node". If
// "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
// may optionally be followed by control inputs that have the format
// "^node".
repeated string input = 3;

// A (possibly partial) specification for the device on which this
// node should be placed.
// The expected syntax for this string is as follows:
//
// DEVICE_SPEC ::= PARTIAL_SPEC
//
// PARTIAL_SPEC ::= ("/" CONSTRAINT) *
// CONSTRAINT ::= ("job:" JOB_NAME)
// | ("replica:" [1-9][0-9]*)
// | ("task:" [1-9][0-9]*)
// | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
//
// Valid values for this string include:
// * "/job:worker/replica:0/task:1/device:GPU:3" (full specification)
// * "/job:worker/device:GPU:3" (partial specification)
// * "" (no specification)
//
// If the constraints do not resolve to a single device (or if this
// field is empty or not present), the runtime will attempt to
// choose a device automatically.
string device = 4;

// Operation-specific graph-construction-time configuration.
// Note that this should include all attrs defined in the
// corresponding OpDef, including those with a value matching
// the default -- this allows the default to change and makes
// NodeDefs easier to interpret on their own. However, if
// an attr with a default is not specified in this list, the
// default will be used.
// The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
// one of the names from the corresponding OpDef's attr field).
// The values must have a type matching the corresponding OpDef
// attr's type field.
// Add some examples here showing best practices.
map<string, AttrValue> attr = 5;
};

ge/common/proto/tensorflow/op_def.proto (+0 -172)

@@ -1,172 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "OpDefProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "types.proto";

// Defines an operation. A NodeDef in a GraphDef specifies an Op by
// using the "op" field which should match the name of a OpDef.
// LINT.IfChange
message OpDef {
// Op names starting with an underscore are reserved for internal use.
// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
string name = 1;

// For describing inputs and outputs.
message ArgDef {
// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
string name = 1;

// Human readable description.
string description = 2;

// Describes the type of one or more tensors that are accepted/produced
// by this input/output arg. The only legal combinations are:
// * For a single tensor: either the "type" field is set or the
// "type_attr" field is set to the name of an attr with type "type".
// * For a sequence of tensors with the same type: the "number_attr"
// field will be set to the name of an attr with type "int", and
// either the "type" or "type_attr" field will be set as for
// single tensors.
// * For a sequence of tensors, the "type_list_attr" field will be set
// to the name of an attr with type "list(type)".
DataType type = 3;
string type_attr = 4; // if specified, attr must have type "type"
string number_attr = 5; // if specified, attr must have type "int"
// If specified, attr must have type "list(type)", and none of
// type, type_attr, and number_attr may be specified.
string type_list_attr = 6;

// For inputs: if true, the inputs are required to be refs.
// By default, inputs can be either refs or non-refs.
// For outputs: if true, outputs are refs, otherwise they are not.
bool is_ref = 16;
};

// Description of the input(s).
repeated ArgDef input_arg = 2;

// Description of the output(s).
repeated ArgDef output_arg = 3;

// Description of the graph-construction-time configuration of this
// Op. That is to say, this describes the attr fields that will
// be specified in the NodeDef.
message AttrDef {
// A descriptive name for the argument. May be used, e.g. by the
// Python client, as a keyword argument name, and so should match
// the regexp "[a-z][a-z0-9_]+".
string name = 1;

// One of the type names from attr_value.proto ("string", "list(string)",
// "int", etc.).
string type = 2;

// A reasonable default for this attribute if the user does not supply
// a value. If not specified, the user must supply a value.
AttrValue default_value = 3;

// Human-readable description.
string description = 4;


// --- Constraints ---
// These constraints are only in effect if specified. Default is no
// constraints.

// For type == "int", this is a minimum value. For "list(___)"
// types, this is the minimum length.
bool has_minimum = 5;
int64 minimum = 6;

// The set of allowed values. Has type that is the "list" version
// of the "type" field above (uses the "list" field of AttrValue).
// If type == "type" or "list(type)" above, then the "type" field
// of "allowed_values.list" has the set of allowed DataTypes.
// If type == "string" or "list(string)", then the "s" field of
// "allowed_values.list" has the set of allowed strings.
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4;

// Optional deprecation based on GraphDef versions.
OpDeprecation deprecation = 8;

// One-line human-readable description of what the Op does.
string summary = 5;

// Additional, longer human-readable description of what the Op does.
string description = 6;

// -------------------------------------------------------------------------
// Which optimizations this operation can participate in.

// True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs)
bool is_commutative = 18;

// If is_aggregate is true, then this operation accepts N >= 2
// inputs and produces 1 output all of the same type. Should be
// associative and commutative, and produce output with the same
// shape as the input. The optimizer may replace an aggregate op
// taking input from multiple devices with a tree of aggregate ops
// that aggregate locally within each device (and possibly within
// groups of nearby devices) before communicating.
bool is_aggregate = 16; // for things like add

// Other optimizations go here, like
// can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc.

// -------------------------------------------------------------------------
// Optimization constraints.

// Ops are marked as stateful if their behavior depends on some state beyond
// their input tensors (e.g. variable reading op) or if they have
// a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
// must always produce the same output for the same input and have
// no side-effects.
//
// By default Ops may be moved between devices. Stateful ops should
// either not be moved, or should only be moved if that state can also
// be moved (e.g. via some sort of save / restore).
// Stateful ops are guaranteed to never be optimized away by Common
// Subexpression Elimination (CSE).
bool is_stateful = 17; // for things like variables, queue

// -------------------------------------------------------------------------
// Non-standard options.

// By default, all inputs to an Op must be initialized Tensors. Ops
// that may initialize tensors for the first time should set this
// field to true, to allow the Op to take an uninitialized Tensor as
// input.
bool allows_uninitialized_input = 19; // for Assign, etc.
};
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc)

// Information about version-dependent deprecation of an op
message OpDeprecation {
// First GraphDef version at which the op is disallowed.
int32 version = 1;

// Explanation of why it was deprecated and what to use instead.
string explanation = 2;
};

// A collection of OpDefs
message OpList {
repeated OpDef op = 1;
};

ge/common/proto/tensorflow/resource_handle.proto (+0 -37)

@@ -1,37 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "ResourceHandle";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Protocol buffer representing a handle to a tensorflow resource. Handles are
// not valid across executions, but can be serialized back and forth from within
// a single run.
message ResourceHandleProto {
// Unique name for the device containing the resource.
string device = 1;

// Container in which this resource is placed.
string container = 2;

// Unique name of this resource.
string name = 3;

// Hash code for the type of the resource. Is only valid in the same device
// and in the same execution.
uint64 hash_code = 4;

// For debug-only, the name of the type pointed to by this handle, if
// available.
string maybe_type_name = 5;
};

ge/common/proto/tensorflow/tensor.proto (+0 -102)

@@ -1,102 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "resource_handle.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing a tensor.
message TensorProto {
DataType dtype = 1;

// Shape of the tensor.
TensorShapeProto tensor_shape = 2;

// Only one of the representations below is set, one of "tensor_contents" and
// the "xxx_val" attributes. We are not using oneof because as oneofs cannot
// contain repeated fields it would require another extra set of messages.

// Version number.
//
// In version 0, if the "repeated xxx" representations contain only one
// element, that element is repeated to fill the shape. This makes it easy
// to represent a constant Tensor with a single value.
int32 version_number = 3;

// Serialized raw tensor content from either Tensor::AsProtoTensorContent or
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
// can be used for all tensor types. The purpose of this representation is to
// reduce serialization overhead during RPC call by avoiding serialization of
// many repeated small items.
bytes tensor_content = 4;

// Type specific representations that make it easy to create tensor protos in
// all languages. Only the representation corresponding to "dtype" can
// be set. The values hold the flattened representation of the tensor in
// row major order.

// DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll
// have some pointless zero padding for each value here.
repeated int32 half_val = 13 [packed = true];

// DT_FLOAT.
repeated float float_val = 5 [packed = true];

// DT_DOUBLE.
repeated double double_val = 6 [packed = true];

// DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
repeated int32 int_val = 7 [packed = true];

// DT_STRING
repeated bytes string_val = 8;

// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
// and imaginary parts of i-th single precision complex.
repeated float scomplex_val = 9 [packed = true];

// DT_INT64
repeated int64 int64_val = 10 [packed = true];

// DT_BOOL
repeated bool bool_val = 11 [packed = true];

// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
// and imaginary parts of i-th double precision complex.
repeated double dcomplex_val = 12 [packed = true];

// DT_RESOURCE
repeated ResourceHandleProto resource_handle_val = 14;

// DT_VARIANT
repeated VariantTensorDataProto variant_val = 15;

// DT_UINT32
repeated uint32 uint32_val = 16 [packed = true];

// DT_UINT64
repeated uint64 uint64_val = 17 [packed = true];
};

// Protocol buffer representing the serialization format of DT_VARIANT tensors.
message VariantTensorDataProto {
// Name of the type of objects being serialized.
string type_name = 1;
// Portions of the object that are not Tensors.
bytes metadata = 2;
// Tensors contained within objects being serialized.
repeated TensorProto tensors = 3;
}

ge/common/proto/tensorflow/tensor_shape.proto (+0 -53)

@@ -1,53 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

// Protocol buffer representing the shape of tensors.

syntax = "proto3";
option cc_enable_arenas = true;
option java_outer_classname = "TensorShapeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

package domi.tensorflow;

// Dimensions of a tensor.
message TensorShapeProto {
// One dimension of the tensor.
message Dim {
// Size of the tensor in that dimension.
// This value must be >= -1, but values of -1 are reserved for "unknown"
// shapes (values of -1 mean "unknown" dimension). Certain wrappers
// that work with TensorShapeProto may fail at runtime when deserializing
// a TensorShapeProto containing a dim value of -1.
int64 size = 1;

// Optional name of the tensor dimension.
string name = 2;
};

// Dimensions of the tensor, such as {"input", 30}, {"output", 40}
// for a 30 x 40 2D tensor. If an entry has size -1, this
// corresponds to a dimension of unknown size. The names are
// optional.
//
// The order of entries in "dim" matters: It indicates the layout of the
// values in the tensor in-memory representation.
//
// The first entry in "dim" is the outermost dimension used to layout the
// values, the last entry is the innermost dimension. This matches the
// in-memory layout of RowMajor Eigen tensors.
//
// If "dim.size()" > 0, "unknown_rank" must be false.
repeated Dim dim = 2;

// If true, the number of dimensions in the shape is unknown.
//
// If true, "dim.size()" must be 0.
bool unknown_rank = 3;
};

ge/common/proto/tensorflow/types.proto (+0 -82)

@@ -1,82 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TypesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// LINT.IfChange
enum DataType {
// Not a legal value for DataType. Used to indicate a DataType field
// has not been set.
DT_INVALID = 0;

// Data types that all computation devices are expected to be
// capable to support.
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_INT32 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_INT8 = 6;
DT_STRING = 7;
DT_COMPLEX64 = 8; // Single-precision complex
DT_INT64 = 9;
DT_BOOL = 10;
DT_QINT8 = 11; // Quantized int8
DT_QUINT8 = 12; // Quantized uint8
DT_QINT32 = 13; // Quantized int32
DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
DT_QINT16 = 15; // Quantized int16
DT_QUINT16 = 16; // Quantized uint16
DT_UINT16 = 17;
DT_COMPLEX128 = 18; // Double-precision complex
DT_HALF = 19;
DT_RESOURCE = 20;
DT_VARIANT = 21; // Arbitrary C++ data types
DT_UINT32 = 22;
DT_UINT64 = 23;

// Do not use! These are only for parameters. Every enum above
// should have a corresponding value below (verified by types_test).
DT_FLOAT_REF = 101;
DT_DOUBLE_REF = 102;
DT_INT32_REF = 103;
DT_UINT8_REF = 104;
DT_INT16_REF = 105;
DT_INT8_REF = 106;
DT_STRING_REF = 107;
DT_COMPLEX64_REF = 108;
DT_INT64_REF = 109;
DT_BOOL_REF = 110;
DT_QINT8_REF = 111;
DT_QUINT8_REF = 112;
DT_QINT32_REF = 113;
DT_BFLOAT16_REF = 114;
DT_QINT16_REF = 115;
DT_QUINT16_REF = 116;
DT_UINT16_REF = 117;
DT_COMPLEX128_REF = 118;
DT_HALF_REF = 119;
DT_RESOURCE_REF = 120;
DT_VARIANT_REF = 121;
DT_UINT32_REF = 122;
DT_UINT64_REF = 123;
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/c/c_api.h,
// https://www.tensorflow.org/code/tensorflow/go/tensor.go,
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.h,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py)
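The *_REF values in the enum above sit at a fixed offset of 100 from their base types (DT_FLOAT = 1, DT_FLOAT_REF = 101, and so on), so reference-ness can be tested and stripped arithmetically. A compile-checked sketch restating two of the constants (the offset convention is taken from the enum values themselves):

#include <cstdint>

enum DataType : int32_t { DT_FLOAT = 1, DT_FLOAT_REF = 101 };
constexpr int32_t kRefOffset = 100;
constexpr int32_t BaseType(int32_t dt) { return dt > kRefOffset ? dt - kRefOffset : dt; }
static_assert(BaseType(DT_FLOAT_REF) == DT_FLOAT, "a ref type maps back to its base type");

int main() { return 0; }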

ge/common/proto/tensorflow/versions.proto (+0 -39)

@@ -1,39 +0,0 @@
/**
* This file is part of Open Source Software TensorFlow, version 1.15.0 https://github.com/tensorflow/tensorflow
*
* This file is included by GraphEngine so as to support model format conversion from tensorflow model to GraphEngine model.
* This file in this distribution may have been modified by Huawei Technologies Co., Ltd ("Huawei Modifications").
* All Huawei Modifications are Copyright 2019-2020 Huawei Technologies Co., Ltd.
*/

syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "VersionsProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Version information for a piece of serialized data
//
// There are different types of versions for each type of data
// (GraphDef, etc.), but they all have the same common shape
// described here.
//
// Each consumer has "consumer" and "min_producer" versions (specified
// elsewhere). A consumer is allowed to consume this data if
//
// producer >= min_producer
// consumer >= min_consumer
// consumer not in bad_consumers
//
message VersionDef {
// The version of the code that produced this data.
int32 producer = 1;

// Any consumer below this version is not allowed to consume this data.
int32 min_consumer = 2;

// Specific consumer versions which are disallowed (e.g. due to bugs).
repeated int32 bad_consumers = 3;
};

ge/executor/CMakeLists.txt (+1 -0)

@@ -37,6 +37,7 @@ set(SRC_LIST
     "../graph/load/model_manager/task_info/task_info.cc"
     "../graph/load/model_manager/task_info/event_record_task_info.cc"
     "../graph/load/model_manager/task_info/event_wait_task_info.cc"
+    "../graph/load/model_manager/task_info/ffts_task_info.cc"
     "../graph/load/model_manager/task_info/fusion_start_task_info.cc"
     "../graph/load/model_manager/task_info/fusion_stop_task_info.cc"
     "../graph/load/model_manager/task_info/kernel_ex_task_info.cc"


+ 0
- 113
ge/executor/proto/dump_task.proto View File

@@ -1,113 +0,0 @@
syntax = "proto3";
package toolkit.dump;

enum OutputDataType {
DT_UNDEFINED = 0;
DT_FLOAT = 1;
DT_FLOAT16 = 2;
DT_INT8 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_UINT16 = 6;
DT_INT32 = 7;
DT_INT64 = 8;
DT_UINT32 = 9;
DT_UINT64 = 10;
DT_BOOL = 11;
DT_DOUBLE = 12;
DT_STRING = 13;
DT_DUAL_SUB_INT8 = 14;
DT_DUAL_SUB_UINT8 = 15;
DT_COMPLEX64 = 16;
DT_COMPLEX128 = 17;
DT_QINT8 = 18;
DT_QINT16 = 19;
DT_QINT32 = 20;
DT_QUINT8 = 21;
DT_QUINT16 = 22;
DT_RESOURCE = 23;
DT_STRING_REF = 24;
DT_DUAL = 25;
DT_VARIANT = 26;
}

enum OutputFormat {
FORMAT_NCHW = 0;
FORMAT_NHWC = 1;
FORMAT_ND = 2;
FORMAT_NC1HWC0 = 3;
FORMAT_FRACTAL_Z = 4;
FORMAT_NC1C0HWPAD = 5;
FORMAT_NHWC1C0 = 6;
FORMAT_FSR_NCHW = 7;
FORMAT_FRACTAL_DECONV = 8;
FORMAT_C1HWNC0 = 9;
FORMAT_FRACTAL_DECONV_TRANSPOSE = 10;
FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11;
FORMAT_NC1HWC0_C04 = 12;
FORMAT_FRACTAL_Z_C04 = 13;
FORMAT_CHWN = 14;
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15;
FORMAT_HWCN = 16;
FORMAT_NC1KHKWHWC0 = 17;
FORMAT_BN_WEIGHT = 18;
FORMAT_FILTER_HWCK = 19;
FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20;
FORMAT_HASHTABLE_LOOKUP_KEYS = 21;
FORMAT_HASHTABLE_LOOKUP_VALUE = 22;
FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23;
FORMAT_HASHTABLE_LOOKUP_HITS = 24;
FORMAT_C1HWNCoC0 = 25;
FORMAT_MD = 26;
FORMAT_NDHWC = 27;
FORMAT_FRACTAL_ZZ = 28;
FORMAT_FRACTAL_NZ = 29;
FORMAT_RESERVED = 30;
}

message OriginalOp {
string name = 1;
uint32 output_index = 2;
OutputDataType data_type = 3;
OutputFormat format = 4;
}

message Shape {
repeated uint64 dim = 1;
}

message OpOutput {
OutputDataType data_type = 1;
OutputFormat format = 2;
Shape shape = 3;
OriginalOp original_op = 4; // the original op corresponding to the output
bytes data = 5;
uint64 size = 6;
}

message OpInput {
OutputDataType data_type = 1;
OutputFormat format = 2;
Shape shape = 3;
bytes data = 4;
uint64 size = 5;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
bytes data = 2;
uint64 size = 3;
}

message DumpData {
string version = 1;
uint64 dump_time = 2;
repeated OpOutput output = 3;
repeated OpInput input = 4;
repeated OpBuffer buffer = 5;
string op_name = 6;
}
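For illustration, a dump record under this schema would be assembled roughly as follows, assuming the C++ classes generated by protoc from this file (the header name dump_task.pb.h and all field values are assumptions):

#include "dump_task.pb.h"  // assumed output of protoc for the proto above

toolkit::dump::DumpData MakeDumpRecord() {
  toolkit::dump::DumpData dump;
  dump.set_version("2.0");             // illustrative value
  dump.set_op_name("Conv2D_1");        // illustrative op name
  toolkit::dump::OpOutput *out = dump.add_output();
  out->set_data_type(toolkit::dump::DT_FLOAT);
  out->set_format(toolkit::dump::FORMAT_NCHW);
  out->mutable_shape()->add_dim(1);
  out->mutable_shape()->add_dim(16);
  out->set_size(64);                   // byte size of the data blob
  return dump;
}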

+ 0
- 193
ge/executor/proto/ge_ir.proto View File

@@ -1,193 +0,0 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; // int32 type
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
DT_VARIANT = 26; // variant type
DT_BF16 = 27; // bf16 type
DT_INT4 = 28; // int4 type
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor = 15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input in the form "op_name:index" (source op name + output index)

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id = 22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto version
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition; graph[0] is the main graph of the model

map<string, AttrDef> attr = 11; // Extended field
}
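The AttrDef/ListValue pair above is the part of this schema that is easiest to get wrong: a scalar lives directly in the oneof, while a list goes into ListValue with val_type recording which repeated field is populated. A sketch against the generated API (the header name ge_ir.pb.h and the protoc-mangled enum constant are assumptions):

#include "ge_ir.pb.h"  // assumed output of protoc for the proto above

void FillAttrs(ge::proto::OpDef *op) {
  // Scalar attribute: exactly one member of the oneof is set.
  ge::proto::AttrDef n_attr;
  n_attr.set_i(64);
  (*op->mutable_attr())["N"] = n_attr;

  // List attribute: values go into ListValue, and val_type records
  // which of the parallel repeated fields is in use.
  ge::proto::AttrDef axes_attr;
  auto *list = axes_attr.mutable_list();
  list->add_i(0);
  list->add_i(2);
  list->set_val_type(ge::proto::AttrDef_ListValue_VT_LIST_INT);
  (*op->mutable_attr())["axes"] = axes_attr;
}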


+ 0
- 140
ge/executor/proto/insert_op.proto View File

@@ -1,140 +0,0 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP mode: distinguishes static AIPP from dynamic AIPP
AippMode aipp_mode = 1;

// related_input_rank is required; an integer >= 0 and <= the number of input Data operators, default 0.
// It identifies which model input gets AIPP processing; e.g. if the model has two inputs and the
// second one needs AIPP, set related_input_rank to 1.
uint32 related_input_rank = 2;

// related_input_name is optional; it is the top name of the Data node that AIPP is inserted on
string related_input_name = 6;

// input_edge_idx is optional; an integer >= 0.
// It lets different outputs of the Data operator get different AIPP processing; if unset, AIPP is
// applied to all output edges of the model input specified by related_input_rank.
// Configured values must be <= the number of output edges of the Data operator.
repeated uint32 input_edge_idx = 3;

// [Begin] dynamic AIPP parameters; ignored when static AIPP is configured
uint32 max_src_image_size = 4;

// Whether rotation is supported. Off by default; enabling it incurs extra memory and performance cost.
bool support_rotation = 5;

// [End] dynamic AIPP parameters


// [Begin] static AIPP parameters; ignored when dynamic AIPP is configured
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;
float padding_value = 72;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] static AIPP parameters

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; // dynamic batch
resolution = 1; // dynamic resolution, reserved for extension
}

MultiShapeMode mode = 1; // operator mode
uint32 related_input_rank = 2; // which input the new operator is inserted on


repeated uint32 batch_list = 11; // batch_list values; the number of entries is between 2 and 8
}
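As an illustration of the static branch of these parameters, a minimal configuration might look like the sketch below (the generated header name and every value are assumptions; aipp_mode and the remaining fields are omitted):

#include "insert_op.pb.h"  // assumed output of protoc for the proto above

domi::AippOpParams MakeStaticAipp() {
  domi::AippOpParams aipp;
  aipp.set_related_input_rank(0);  // apply AIPP to the first Data input
  aipp.set_input_format(domi::AippOpParams::RGB888_U8);
  aipp.set_src_image_size_w(224);  // illustrative image size
  aipp.set_src_image_size_h(224);
  aipp.set_csc_switch(true);       // enable color space conversion
  aipp.set_mean_chn_0(104);        // illustrative per-channel means
  aipp.set_mean_chn_1(117);
  aipp.set_mean_chn_2(123);
  // aipp_mode (static/dynamic) would also be set in a real config
  return aipp;
}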

+ 0
- 396
ge/executor/proto/om.proto View File

@@ -1,396 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4; // optional, [default = 1e-5]
bool use_global_stats = 5; // optional, true by default (testing mode)
float moving_average_fraction = 6; // optional, [default = .999]

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // NPU is used by default.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load direction: 0 = column load, 1 = row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}


+ 0
- 75
ge/executor/proto/op_mapping.proto View File

@@ -1,75 +0,0 @@
syntax = "proto3";
package toolkit.aicpu.dump;

message Shape {
repeated uint64 dim = 1;
}

message Output {
int32 data_type = 1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
string original_name = 5;
int32 original_output_index = 6;
int32 original_output_data_type = 7;
int32 original_output_format = 8;
uint64 size = 9;
Shape origin_shape = 10;
}

message Input {
int32 data_type = 1;
int32 format = 2;
Shape shape = 3;
uint64 address = 4;
uint64 size = 5;
Shape origin_shape = 6;
}

enum BufferType {
L1 = 0;
}

message OpBuffer {
BufferType buffer_type = 1;
uint64 address = 2;
uint64 size = 3;
}

message Op {
string op_name = 1;
string op_type = 2;
}

message Task {
uint32 task_id = 1;
uint32 stream_id = 2;
Op op = 3;
repeated Output output = 4;
bool end_graph = 5;
repeated Input input = 6;
repeated OpBuffer buffer = 7;
}

message OpMappingInfo {
string dump_path = 1;
oneof model_name_param {
string model_name = 2;
}
oneof model_id_param {
uint32 model_id = 3;
}
oneof step_id {
uint64 step_id_addr = 4;
}
oneof iterations_per_loop {
uint64 iterations_per_loop_addr = 5;
}
oneof loop_cond {
uint64 loop_cond_addr = 6;
}
uint32 flag = 7; // 0x01 load, 0x00 unload
repeated Task task = 8;
string dump_step = 9;
}
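The oneof wrappers above exist so that an unset scalar can be told apart from a legitimate zero value. A sketch of filling the message for a load operation (the header name, path, and values are assumptions):

#include "op_mapping.pb.h"  // assumed output of protoc for the proto above

toolkit::aicpu::dump::OpMappingInfo MakeLoadInfo() {
  toolkit::aicpu::dump::OpMappingInfo info;
  info.set_dump_path("/tmp/npu_dump/");  // illustrative path
  info.set_model_name("resnet50");       // fills the model_name_param oneof
  info.set_flag(0x01);                   // 0x01 load, 0x00 unload (per the comment above)
  toolkit::aicpu::dump::Task *task = info.add_task();
  task->set_task_id(0);
  task->set_stream_id(0);
  task->mutable_op()->set_op_name("Conv2D_1");
  task->mutable_op()->set_op_type("Conv2D");
  return info;
}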

+ 0
- 179
ge/executor/proto/task.proto View File

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+ 4
- 7
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -13,15 +13,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "host_cpu_engine.h"
#include "graph/common/omg_util.h"
#include "ge_local_engine/engine/host_cpu_engine.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_adapter.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/type_utils.h"
#include "register/op_kernel_registry.h"
#include "register/host_cpu_context.h"
#include "common/ge/ge_util.h"
#include "common/ge/plugin_manager.h"
#include "graph/utils/type_utils.h"
#include "common/fp16_t.h"
#include "common/math/math_util.h"


@@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) {
}

Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) {
std::string op_type;
auto status = GetOriginalType(node, op_type);
GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status);

const std::string op_type = NodeUtils::GetNodeType(node);
auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type);
if (kernel == nullptr) {
GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str());


+ 0
- 179
ge/ge_local_engine/proto/task.proto View File

@@ -1,179 +0,0 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+ 58
- 0
ge/ge_opt_info/ge_opt_info.cc View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_opt_info/ge_opt_info.h"

#include <string>
#include <map>
#include "graph/ge_local_context.h"
#include "ge/ge_api_types.h"
#include "common/debug/ge_log.h"
#include "opt_info.h"

namespace ge {
Status GeOptInfo::SetOptInfo() {
std::string soc_ver;
graphStatus ret = GetThreadLocalContext().GetOption(SOC_VERSION, soc_ver);
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Get soc version failed.");
GELOGE(FAILED, "[Get][SocVersion]Get soc version failed.");
return FAILED;
}
GELOGD("Soc version:%s.", soc_ver.c_str());
std::map<std::string, std::string> opt_info;
// The first argument has no effect at present.
if (gelc::GetOptInfo(gelc::kOffline, soc_ver, opt_info) != gelc::SUCCESS) {
REPORT_CALL_ERROR("E19999", "Get optional information failed, is_offline:%d, soc version:%s",
gelc::kOffline, soc_ver.c_str());
GELOGE(FAILED, "[Get][OptInfo]Get optional information failed, is_offline:%d, soc version:%s",
gelc::kOffline, soc_ver.c_str());
return FAILED;
}
// do nothing if get empty information
if (opt_info.empty()) {
GELOGI("Optional information is empty.");
return SUCCESS;
}
std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
for (const auto &itr : opt_info) {
graph_options.emplace(itr.first, itr.second);
GELOGI("Get optional information success, key:%s, value:%s.", itr.first.c_str(), itr.second.c_str());
}
GetThreadLocalContext().SetGraphOption(graph_options);
return SUCCESS;
}
} // namespace ge
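A minimal usage sketch for the function above, assuming SOC_VERSION has already been set in the thread-local context:

#include <map>
#include <string>
#include "ge_opt_info/ge_opt_info.h"
#include "graph/ge_local_context.h"

// Usage sketch: merge platform tuning info into the graph options.
ge::Status ApplyOptInfo() {
  ge::Status ret = ge::GeOptInfo::SetOptInfo();
  if (ret != ge::SUCCESS) {
    return ret;  // SOC_VERSION missing or the gelc lookup failed
  }
  // The key/value pairs returned by gelc are now visible as graph options.
  std::map<std::string, std::string> options = ge::GetThreadLocalContext().GetAllGraphOptions();
  (void)options;
  return ge::SUCCESS;
}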

+ 31
- 0
ge/ge_opt_info/ge_opt_info.h View File

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_OPT_INFO_GE_OPT_INFO_H_
#define GE_OPT_INFO_GE_OPT_INFO_H_

#include "ge/ge_api_error_codes.h"
#include "register/register_types.h"

namespace ge {
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeOptInfo {
public:
GeOptInfo() = default;
static Status SetOptInfo();
};
} // namespace ge

#endif // GE_OPT_INFO_GE_OPT_INFO_H_

+ 1
- 0
ge/ge_runtime/CMakeLists.txt View File

@@ -16,6 +16,7 @@ set(GE_SRC_LIST
"task/label_goto_task.cc" "task/label_goto_task.cc"
"task/label_set_task.cc" "task/label_set_task.cc"
"task/label_switch_task.cc" "task/label_switch_task.cc"
"task/label_manager.cc"
) )


add_library(ge_runtime SHARED ${GE_SRC_LIST}) add_library(ge_runtime SHARED ${GE_SRC_LIST})


+ 1
- 0
ge/ge_runtime/runtime_model.cc View File

@@ -21,6 +21,7 @@
#include "common/ge_inner_error_codes.h" #include "common/ge_inner_error_codes.h"
#include "common/types.h" #include "common/types.h"
#include "common/util.h" #include "common/util.h"
#include "common/math/math_util.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/op/op_parser_util.h" #include "framework/common/op/op_parser_util.h"
#include "graph/types.h" #include "graph/types.h"


+ 2
- 14
ge/ge_runtime/task/hccl_task.cc View File

@@ -52,15 +52,7 @@ HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr<Hccl
}
}


HcclTask::~HcclTask() {
if (workspace_mem_ != nullptr) {
rtError_t rt_ret = rtFree(workspace_mem_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtFree workspace_mem_ failed! ret: 0x%X.", rt_ret);
}
workspace_mem_ = nullptr;
}
}
HcclTask::~HcclTask() {}


bool HcclTask::Distribute() {
// Ops kernel info store
@@ -79,11 +71,7 @@ bool HcclTask::Distribute() {
SetSecondaryStream();

if (task_info_->workspace_size() > 0) {
rtError_t rt_ret = rtMalloc(&workspace_mem_, task_info_->workspace_size(), RT_MEMORYINFO_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
workspace_mem_ = task_info_->workspace_addr();
}

GELOGI("HcclTaskInfo Distribute Start. begin to call function LoadTask in hccl.");


+ 42
- 16
ge/ge_runtime/task/label_goto_task.cc View File

@@ -16,33 +16,46 @@


#include "ge_runtime/task/label_goto_task.h" #include "ge_runtime/task/label_goto_task.h"
#include "ge_runtime/task/task_factory.h" #include "ge_runtime/task/task_factory.h"
#include "framework/common/util.h"


namespace ge { namespace ge {
namespace model_runner { namespace model_runner {
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), task_info_(task_info) {
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
index_value_(nullptr) {
if (task_info_ == nullptr) { if (task_info_ == nullptr) {
GELOGW("task_info_ is null!"); GELOGW("task_info_ is null!");
return; return;
} }
auto stream_list = model_context.stream_list(); auto stream_list = model_context.stream_list();
auto label_list = model_context.label_list(); auto label_list = model_context.label_list();
rt_model_handle_ = model_context.rt_model_handle();
uint32_t stream_id = task_info->stream_id(); uint32_t stream_id = task_info->stream_id();
uint32_t label_id = task_info->label_id();
label_id_ = task_info->label_id();
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id);
if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id_);
if (stream_id >= stream_list.size() || label_id_ >= label_list.size()) {
GELOGW("Stream/Label id invalid."); GELOGW("Stream/Label id invalid.");
return; return;
} }
stream_ = stream_list[stream_id]; stream_ = stream_list[stream_id];
label_ = label_list[label_id];
label_manager_ = LabelManager::GetInstance();
if (label_manager_ == nullptr) {
GELOGW("Get label manager instance failed.");
return;
}
label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, {label_id_}, label_list);
} }


LabelGotoTask::~LabelGotoTask() { LabelGotoTask::~LabelGotoTask() {
GE_FREE_RT_LOG(label_info_);
GE_FREE_RT_LOG(index_value_);
if (index_value_ != nullptr) {
rtError_t rt_ret = rtFree(index_value_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtFree index_value_ failed! ret: 0x%X.", rt_ret);
}
index_value_ = nullptr;
}
} }


bool LabelGotoTask::Distribute() { bool LabelGotoTask::Distribute() {
@@ -72,7 +85,7 @@ bool LabelGotoTask::Distribute() {
return false; return false;
} }


rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size);
rt_ret = rtLabelListCpy(reinterpret_cast<void**>(label_list.data()), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false; return false;
@@ -94,21 +107,34 @@ bool LabelGotoTask::CheckParamValid() {
return false; return false;
} }


if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
if (label_info_ == nullptr) {
GELOGE(PARAM_INVALID, "label info is null!");
return false; return false;
} }


if (label_info_ != nullptr) {
GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
return false;
if (index_value_ == nullptr) {
rtError_t rt_ret = rtMalloc(&index_value_, sizeof(uint64_t), RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

uint64_t index = 0;
rt_ret = rtMemcpy(index_value_, sizeof(uint64_t), &index, sizeof(index), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
} }


if (index_value_ != nullptr) {
GELOGE(PARAM_INVALID, "index_value_ has dirty data.");
void *label_info = label_info_->GetLabelInfo();
rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, 1, label_info, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false; return false;
} }


GELOGI("DistributeTask end.");
return true; return true;
} }




+ 10
- 6
ge/ge_runtime/task/label_goto_task.h View File

@@ -18,7 +18,11 @@
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

#include <memory>
#include <vector>
#include <map>
#include <mutex>
#include "ge_runtime/task/task.h"
#include "ge_runtime/task/label_manager.h"

namespace ge {
namespace model_runner {
@@ -31,13 +35,13 @@ class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
bool Distribute() override;

private:
bool CheckParamValid();

std::shared_ptr<LabelGotoTaskInfo> task_info_;
void *stream_{nullptr};
void *label_{nullptr};
void *label_info_{nullptr};
void *index_value_{nullptr};
void *stream_;
std::shared_ptr<LabelGuard> label_info_;
void *index_value_;
uint32_t label_id_;
rtModel_t rt_model_handle_;
std::shared_ptr<LabelManager> label_manager_;
};
} // namespace model_runner
} // namespace ge


+ 119
- 0
ge/ge_runtime/task/label_manager.cc View File

@@ -0,0 +1,119 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_manager.h"
#include <algorithm>
#include <string>
#include "runtime/mem.h"
#include "runtime/rt_model.h"
#include "common/ge_inner_error_codes.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
namespace model_runner {
std::weak_ptr<LabelManager> LabelManager::instance_;
std::mutex LabelManager::instance_mutex_;

template <class T>
static std::string GetVectorString(const std::vector<T> &vec) {
std::string ret;
for (size_t i = 0; i < vec.size(); ++i) {
if (i != 0) {
ret.push_back(',');
}
ret += std::to_string(vec[i]);
}
return ret;
}

LabelGuard::~LabelGuard() {
void *label_info = GetLabelInfo();
if (label_info != nullptr) {
rtError_t rt_ret = rtFree(label_info);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtFree label_info failed! ret: 0x%X.", rt_ret);
}
}
}

std::shared_ptr<LabelManager> LabelManager::GetInstance() {
std::lock_guard<std::mutex> lock(instance_mutex_);
auto instance = instance_.lock();
if (instance != nullptr) {
return instance;
}

instance = std::make_shared<LabelManager>();
instance_ = instance;
return instance;
}

std::shared_ptr<LabelGuard> LabelManager::GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids,
const std::vector<void *> &all_label) {
std::lock_guard<std::mutex> lock(model_info_mapping_mutex_);
rtError_t rt_ret;
auto model_iter = model_info_mapping_.find(model);
if (model_iter == model_info_mapping_.end()) {
model_info_mapping_.emplace(model, std::map<std::string, std::weak_ptr<LabelGuard>>());
model_iter = model_info_mapping_.find(model);
}

std::string label_id_str = GetVectorString(label_ids);
auto &label_map = model_iter->second;
auto label_iter = label_map.find(label_id_str);
if (label_iter != label_map.end()) {
auto label_guard = label_iter->second.lock();
if (label_guard != nullptr) {
GELOGI("model %p find same label id %s.", model, label_id_str.c_str());
return label_guard;
}
}

GELOGI("Alloc label id %s for model %p.", label_id_str.c_str(), model);
void *label_info;
std::vector<void *> label_list;
bool status = true;
std::transform(label_ids.begin(), label_ids.end(), std::back_inserter(label_list),
[&all_label, &status](uint32_t idx) -> void * {
if (idx >= all_label.size()) {
GELOGE(PARAM_INVALID, "Invalid label id %u, all label list size %zu.", idx, all_label.size());
status = false;
return nullptr;
}
return all_label[idx];
});
if (!status) {
GELOGE(PARAM_INVALID, "Get label info failed.");
return nullptr;
}
uint32_t label_info_size = sizeof(rtLabelDevInfo) * label_list.size();
rt_ret = rtMalloc(&label_info, label_info_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return nullptr;
}

rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return nullptr;
}

auto label_guard = std::make_shared<LabelGuard>(label_info);
label_map.emplace(label_id_str, label_guard);
return label_guard;
}
} // namespace model_runner
} // namespace ge
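A minimal usage sketch for the cache above, assuming a valid rtModel_t handle and the model's label list are available from the surrounding model context (the label ids {0, 1} are illustrative):

#include <memory>
#include <vector>
#include "ge_runtime/task/label_manager.h"

using ge::model_runner::LabelGuard;
using ge::model_runner::LabelManager;

std::shared_ptr<LabelGuard> AcquireLabels(rtModel_t model,
                                          const std::vector<void *> &all_label) {
  std::shared_ptr<LabelManager> manager = LabelManager::GetInstance();
  if (manager == nullptr) {
    return nullptr;
  }
  // Repeated calls with the same model and id list return the cached device
  // buffer; the memory is freed when the last LabelGuard owner goes away.
  return manager->GetLabelInfo(model, {0, 1}, all_label);
}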

+ 54
- 0
ge/ge_runtime/task/label_manager.h View File

@@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_
#define GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_

#include <vector>
#include <memory>
#include <mutex>
#include <map>
#include <runtime/base.h>

namespace ge {
namespace model_runner {
class LabelGuard {
public:
explicit LabelGuard(void *label_info) : label_info_(reinterpret_cast<uintptr_t>(label_info)) {}
~LabelGuard();
void *GetLabelInfo() { return reinterpret_cast<void *>(label_info_); }

private:
uintptr_t label_info_;
};

class LabelManager {
public:
static std::shared_ptr<LabelManager> GetInstance();
std::shared_ptr<LabelGuard> GetLabelInfo(rtModel_t model, const std::vector<uint32_t> &label_ids,
const std::vector<void *> &all_label);

private:
std::mutex model_info_mapping_mutex_;
std::map<rtModel_t, std::map<std::string, std::weak_ptr<LabelGuard>>> model_info_mapping_;

static std::weak_ptr<LabelManager> instance_;
static std::mutex instance_mutex_;
};


} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_MANAGER_H_

+ 11
- 14
ge/ge_runtime/task/label_switch_task.cc View File

@@ -24,14 +24,14 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
: TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
all_label_resource_(),
label_info_(nullptr) {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
return;
}

all_label_resource_ = model_context.label_list();
rt_model_handle_ = model_context.rt_model_handle();
auto all_label_resource = model_context.label_list();
auto stream_list = model_context.stream_list();
uint32_t stream_id = task_info->stream_id();
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
@@ -40,18 +40,16 @@ LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
return;
}
stream_ = stream_list[stream_id];
}

LabelSwitchTask::~LabelSwitchTask() {
if (label_info_ != nullptr) {
rtError_t rt_ret = rtFree(label_info_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret);
}
label_info_ = nullptr;
label_manager_ = LabelManager::GetInstance();
if (label_manager_ == nullptr) {
GELOGW("Get label manager instance failed.");
return;
}
label_info_ = label_manager_->GetLabelInfo(rt_model_handle_, task_info_->label_list(), all_label_resource);
}

LabelSwitchTask::~LabelSwitchTask() {}

bool LabelSwitchTask::Distribute() {
GELOGI("LabelSwitchTask Distribute start.");
if (!CheckParamValid()) {
@@ -117,8 +115,8 @@ bool LabelSwitchTask::CheckParamValid() {
return false;
}

if (label_info_ != nullptr) {
GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
if (label_info_ == nullptr) {
GELOGE(PARAM_INVALID, "CopyLabelList failed, label info is null.");
return false;
}

@@ -126,6 +124,5 @@
}

REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);

} // namespace model_runner
} // namespace ge

+ 4
- 2
ge/ge_runtime/task/label_switch_task.h View File

@@ -19,6 +19,7 @@


#include <memory>
#include "ge_runtime/task/task.h"
#include "ge_runtime/task/label_manager.h"

namespace ge {
namespace model_runner {
@@ -35,8 +36,9 @@ class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {

std::shared_ptr<LabelSwitchTaskInfo> task_info_;
void *stream_;
std::vector<void *> all_label_resource_;
void *label_info_;
rtModel_t rt_model_handle_;
std::shared_ptr<LabelGuard> label_info_;
std::shared_ptr<LabelManager> label_manager_;
};
} // namespace model_runner
} // namespace ge


+ 6
- 5
ge/generator/ge_generator.cc View File

@@ -674,6 +674,12 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
GELOGD("Current ctx is null."); GELOGD("Current ctx is null.");
ctx = nullptr; ctx = nullptr;
} }
std::function<void()> callback = [&]() {
if (ctx != nullptr) {
(void)rtCtxSetCurrent(ctx);
}
};
GE_MAKE_GUARD(restore, callback);


GeRootModelPtr ge_root_model = nullptr; GeRootModelPtr ge_root_model = nullptr;
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
@@ -712,11 +718,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
} }
return ret; return ret;
} }

if (ctx != nullptr) {
(void)rtCtxSetCurrent(ctx);
}

return SUCCESS; return SUCCESS;
} }




+ 5
- 0
ge/graph/build/label_allocator.cc View File

@@ -86,6 +86,11 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::set<Node
return false;
}

if (func_node->GetOpDesc() != nullptr && func_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
GELOGD("Graph[%s] is ffts subgraph, skip label allocator.", graph->GetName().c_str());
return true;
}

ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph();
if (owner_graph == nullptr) {
REPORT_INNER_ERROR("E19999", "ComputeGraph owner not set in node:%s(%s), graph:%s",


+ 5
- 0
ge/graph/build/logical_stream_allocator.cc View File

@@ -474,6 +474,11 @@ Status UpdateForSkippedEnginePass::Run(ComputeGraphPtr graph, const vector<Subgr
for (ge::NodePtr &node : graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID)) {
op_desc->SetStreamId(kInvalidStream);
GELOGI("Ffts node %s of type %s reassign to invalid stream.", node->GetName().c_str(), node->GetType().c_str());
continue;
}
int64_t stream_id = op_desc->GetStreamId();
if (ops_without_label.find(op_desc) != ops_without_label.end()) {
if (AreAllPredStreamsInvalid(node) && op_desc->GetSubgraphInstanceNames().empty()) {


+ 1
- 1
ge/graph/build/model_builder.cc View File

@@ -707,7 +707,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) {
if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) {
GE_CHECK_NOTNULL(kernel_buffer.GetData());
std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize());
tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data));
tbe_kernel = MakeShared<OpKernelBin>(kernel_name, std::move(data));
GE_CHECK_NOTNULL(tbe_kernel);
GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(),
node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str());


+ 10
- 1
ge/graph/build/stream_allocator.cc View File

@@ -432,7 +432,11 @@ Status StreamAllocator::SetActiveStreamsForSubgraphs() {


// Insert the send/recv event id to the graph
Status StreamAllocator::InsertSyncEvents() {
for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) {
auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) {
return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH);
};

for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag(), nullptr, ffts_filter)) {
// Take the adjacent points, then judge whether need to insert the event
for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) {
for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) {
@@ -531,6 +535,11 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const
Status StreamAllocator::InsertEventsForSubgraph() {
for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) {
GE_CHECK_NOTNULL(subgraph);
const auto parent_node = subgraph->GetParentNode();
if (parent_node != nullptr && parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
GELOGD("Skip ffts subgraph, parent node is %s.", parent_node->GetName().c_str());
continue;
}
for (const auto &node : subgraph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);


+ 47
- 29
ge/graph/build/task_generator.cc View File

@@ -354,7 +354,10 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra
};
GE_MAKE_GUARD(release, callback);


for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
auto ffts_filter = [](const Node &node, const char *, const ComputeGraphPtr &) {
return !node.GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH);
};
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag(), nullptr, ffts_filter)) {
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
node_index++;
@@ -380,10 +383,8 @@
GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str());
continue;
}
if (op_kernel_lib_name.empty()) {
GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str());
continue;
}
GE_CHK_BOOL_EXEC_INFO(!op_kernel_lib_name.empty(), continue,
"Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str());
auto kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name);
if (kernel_info_store == nullptr) {
REPORT_INNER_ERROR("E19999", "Get ops kernel info store failed for op:%s(%s), op_kernel_name:%s",
@@ -394,6 +395,10 @@
}
GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "[Call][UpdateAnchorStatus] node:%s(%s) failed", name.c_str(),
type.c_str());
if (node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_SUB_GRAPH)) {
GE_CHK_STATUS_RET(UpdateAnchorStatusForFfts(node), "[Call][UpdateAnchorStatusForFfts] node:%s(%s) failed",
name.c_str(), type.c_str());
}
// Profiling task // Profiling task
size_t task_list_size_before = task_def_list.size(); size_t task_list_size_before = task_def_list.size();
GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list));
@@ -571,7 +576,24 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info
return ret;
}


Status TaskGenerator::UpdateAnchorStatusForFfts(const NodePtr &node) {
GELOGD("Start UpdateAnchorStatusForFfts for %s.", node->GetName().c_str());
if (!node->GetOpDesc()->GetSubgraphInstanceNames().empty()) {
for (size_t i = 0; i < node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) {
auto sub_graph = NodeUtils::GetSubgraph(*node, i);
GE_CHECK_NOTNULL(sub_graph);
GELOGD("Start update anchor status for %s.", sub_graph->GetName().c_str());
for (auto &ffts_node : sub_graph->GetDirectNode()) {
GE_CHK_STATUS_RET(UpdateAnchorStatus(ffts_node), "[Call][UpdateAnchorStatus] node:%s(%s) failed",
ffts_node->GetName().c_str(), ffts_node->GetType().c_str());
}
}
}
return SUCCESS;
}

Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) {
GELOGD("Start UpdateAnchorStatus for %s.", node->GetName().c_str());
if (NodeUtils::SetAllAnchorStatus(node) != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "SetAllAnchorStatus fail for op:%s(%s)",
node->GetName().c_str(), node->GetType().c_str());
@@ -771,7 +793,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
GELOGI("Start AutoFindBpOpIndex"); GELOGI("Start AutoFindBpOpIndex");
NodePtr bp_node = nullptr; NodePtr bp_node = nullptr;
uint32_t current_idx = 0; uint32_t current_idx = 0;
uint32_t netoutput_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
OpDescPtr op_desc = node->GetOpDesc(); OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
@@ -789,7 +810,6 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) { if (op_desc->GetName() == NODE_NAME_NET_OUTPUT) {
if (bp_node == nullptr) { if (bp_node == nullptr) {
bp_node = node; bp_node = node;
netoutput_idx = current_idx - 1;
} }
} }
if (graph->GetNeedIteration()) { if (graph->GetNeedIteration()) {
@@ -814,34 +834,30 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (bp_node == nullptr) {
GELOGW("not find bp_node.");
return SUCCESS;
} else if (bp_node->GetName() == NODE_NAME_NET_OUTPUT) {
profiling_point.bp_index = netoutput_idx;
GELOGI("First bp name %s, idx %u", bp_node->GetName().c_str(), netoutput_idx);
} else {
profiling_point.bp_index = FindLastBpFromBpNode(graph, bp_node);
}

return SUCCESS;
return FindLastBpFromBpNode(graph, bp_node, profiling_point.bp_index);
}

uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const {
uint32_t last_bp = 0;
Status TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &target_node,
uint32_t &bp_index) const {
bp_index = 0;
auto target_desc = target_node->GetOpDesc();
GE_CHECK_NOTNULL(target_desc);
OpDescPtr bp_op_desc = nullptr;
for (auto &in_anchor : bp_node->GetAllInDataAnchors()) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) {
continue;
for (auto &in_node : target_node->GetInAllNodes()) {
GE_CHECK_NOTNULL(in_node);
auto in_node_desc = in_node->GetOpDesc();
GE_CHECK_NOTNULL(in_node_desc);
if ((bp_op_desc == nullptr || (in_node_desc->GetId() > bp_op_desc->GetId())) &&
(in_node_desc->GetStreamId() == target_desc->GetStreamId())){
bp_op_desc = in_node_desc;
}
auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc();
GE_CHECK_NOTNULL(out_node_desc);
if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) {
bp_op_desc = out_node_desc;
}
GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId());
}

if (bp_op_desc == nullptr) {
return last_bp;
GELOGI("Did not find bp node.");
return SUCCESS;
}
uint32_t current_idx = 0;
for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
@@ -849,12 +865,14 @@ uint32_t TaskGenerator::FindLastBpFromBpNode(const ComputeGraphPtr &graph, const
GE_CHECK_NOTNULL(op_desc);
current_idx++;
if (op_desc->GetName() == bp_op_desc->GetName()) {
last_bp = current_idx;
GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp);
bp_index = current_idx;
GELOGI("Find bp name %s, idx %u", op_desc->GetName().c_str(), bp_index);
break;
}
}
return last_bp;
GELOGI("Last bp node[%s], type[%s], index[%u], stream id[%ld]", bp_op_desc->GetName().c_str(),
bp_op_desc->GetType().c_str(), bp_index, bp_op_desc->GetStreamId());
return SUCCESS;
}
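
The rework above changes FindLastBpFromBpNode from returning a raw index to returning Status with a bp_index out-parameter, and it now scans all in-nodes while keeping only predecessors on the same stream as the target. A standalone sketch of that selection rule, with simplified stand-ins for GE's NodePtr/OpDesc:

// Pick the predecessor with the largest id that shares the target's stream.
#include <cstdint>
#include <iostream>
#include <vector>

struct OpInfo {
  int64_t id;         // topological id assigned when the graph was built
  int64_t stream_id;  // stream the op was assigned to
  const char *name;
};

bool FindLastBp(const std::vector<OpInfo> &preds, const OpInfo &target, OpInfo &last_bp) {
  bool found = false;
  for (const auto &pred : preds) {
    if (pred.stream_id != target.stream_id) {
      continue;  // ops on other streams cannot be the bp boundary
    }
    if (!found || pred.id > last_bp.id) {
      last_bp = pred;
      found = true;
    }
  }
  return found;
}

int main() {
  const OpInfo net_output{10, 0, "Node_Output"};
  const std::vector<OpInfo> preds = {{7, 0, "grad_a"}, {9, 1, "allreduce"}, {8, 0, "grad_b"}};
  OpInfo last_bp{0, 0, ""};
  if (FindLastBp(preds, net_output, last_bp)) {
    std::cout << "last bp: " << last_bp.name << "\n";  // grad_b: id 8, stream 0
  }
  return 0;
}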


Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,


+2 -1  ge/graph/build/task_generator.h

@@ -80,6 +80,7 @@ class TaskGenerator {
Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
std::vector<uint32_t> &all_reduce_nodes);
private:
Status UpdateAnchorStatusForFfts(const NodePtr &node);
Status UpdateAnchorStatus(const NodePtr &node);


Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id);
@@ -115,7 +116,7 @@ class TaskGenerator {
Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const;
Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
vector<uint32_t> &all_reduce_nodes) const;
uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const;
Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const;


Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,
ProfilingPoint &profiling_point) const;


+0 -15  ge/graph/common/omg_util.cc

@@ -275,21 +275,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) {
}


///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) {
if (!force_unknown) {
return;
}

SetControlFlowGroup(node, group_index);
}

///
/// @brief Set Op _control_flow_group flag
/// @param [in] node
/// @param [in] group, condition group index of node.


+0 -9  ge/graph/common/omg_util.h

@@ -126,15 +126,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size);
bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc);


///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index);

///
/// @brief Set Op _control_flow_group flag
/// @param [in] node
/// @param [in] group, condition group index of node.


+158 -49  ge/graph/load/model_manager/davinci_model.cc

@@ -99,6 +99,9 @@ const uint32_t kEndOfSequenceNew = 507005;
const int32_t kModelAbortNormal = 0x0704000e;
const int32_t kModelAbortNormalNew = 507024;
const uint32_t kInteval = 2;
const uint32_t kFftsTbeHandleElementSize = 2;
const uint32_t kNonTailBlock = 0;
const uint32_t kTailBlock = 1;
const char *const kModelName = "model_name";
const char *const kModeleId = "model_id";
const char *const kLoadStartTime = "load_start_time";
@@ -116,14 +119,15 @@ const char *const kWorkSpaceSize = "workspace_size";
const char *const kTotalSize = "total_size";
const char *const kTaskCount = "task_count";
const char *const kTaskId = "task_id";
const char* const kRequestId = "request_id";
const char* const kThreadId = "thread_id";
const char* const kInputBeginTime = "input_begin_time";
const char* const kInputEndTime = "input_end_time";
const char* const kInferBeginTime = "infer_begin_time";
const char* const kInferEndTime = "infer_end_time";
const char* const kOutputBeginTime = "output_start_time";
const char* const kOutputEndTime = "output_end_time";
const char *const kRequestId = "request_id";
const char *const kThreadId = "thread_id";
const char *const kInputBeginTime = "input_begin_time";
const char *const kInputEndTime = "input_end_time";
const char *const kInferBeginTime = "infer_begin_time";
const char *const kInferEndTime = "infer_end_time";
const char *const kOutputBeginTime = "output_start_time";
const char *const kOutputEndTime = "output_end_time";
const char *const kStubFuncName = "_register_stub_func";
const uint32_t kStringHeadElems = 2;
const uint32_t kPlacementHostData = 0;
const size_t kAlignment = 64;
@@ -902,10 +906,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {
SetLabelForDynamic(node);
auto it = op_desc_handle.find(op_desc->GetType());
if (it != op_desc_handle.end()) {
if ((this->*it->second)(op_desc) != SUCCESS) {
GELOGE(PARAM_INVALID, "[Init][Node] failed, Name:%s", op_desc->GetName().c_str());
return PARAM_INVALID;
}
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((this->*it->second)(op_desc) != SUCCESS, return PARAM_INVALID,
"[Init][Node] failed, Name:%s", op_desc->GetName().c_str());
continue;
}


@@ -935,7 +937,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) {


GE_TIMESTAMP_RESTART(InitTbeHandle);
if (IsTbeTask(op_desc)) {
Status status = InitTbeHandle(op_desc);
Status status =
op_desc->HasAttr(ATTR_NAME_THREAD_SCOPE_ID) ? InitTbeHandleWithFfts(op_desc) : InitTbeHandle(op_desc);
if (status != SUCCESS) {
GELOGE(status, "[Init][TbeHandle] failed. op:%s", op_desc->GetName().c_str());
return status;
@@ -3700,6 +3703,7 @@ Status DavinciModel::InitConstant(const OpDescPtr &op_desc) {
/// @return Status
///
Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
string bin_file = op_desc->GetName();
auto kernel = ge_model_->GetTBEKernelStore().FindKernel(op_desc->GetName());
auto tbe_kernel = (kernel != nullptr) ? kernel : op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
if (tbe_kernel == nullptr) {
@@ -3708,12 +3712,61 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file!", op_desc->GetName().c_str());
return INTERNAL_ERROR;
}
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file, tbe_kernel, false), "Function register of bin file: %s failed",
bin_file.c_str());
return SUCCESS;
}


std::string session_graph_model_id;
GetUniqueId(op_desc, session_graph_model_id);
const char *bin_file_key = GetRegisterStub(op_desc->GetName(), session_graph_model_id); // from set, always valid.
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();
Status DavinciModel::InitTbeHandleWithFfts(const OpDescPtr &op_desc) {
std::vector<OpKernelBinPtr> tbe_kernel;
tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_THREAD_TBE_KERNEL, tbe_kernel);
GELOGD("Kernel bin ptr vec size is %zu.", tbe_kernel.size());
if (tbe_kernel.size() != kFftsTbeHandleElementSize) {
REPORT_INNER_ERROR("E19999", "Get tbe_kernel for op:%s(%s) fail, model_id:%u",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find tvm bin file, size is %zu when ffts",
op_desc->GetName().c_str(), tbe_kernel.size());
return INTERNAL_ERROR;
}
if (tbe_kernel[0] == nullptr || tbe_kernel[1] == nullptr) {
REPORT_INNER_ERROR("E19999", "Tbe kernel for op:%s is nullptr.", op_desc->GetName().c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: tvm bin file of %s is nullptr when ffts.", op_desc->GetName().c_str());
return INTERNAL_ERROR;
}
vector<string> bin_file_keys;
(void)AttrUtils::GetListStr(op_desc, kStubFuncName, bin_file_keys);
if (bin_file_keys.size() != kFftsTbeHandleElementSize) {
REPORT_INNER_ERROR("E19999", "Get bin_file for op:%s(%s) fail.", op_desc->GetName().c_str(),
op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "[Check][Param] TBE: %s can't find bin file keys, size is %zu when ffts",
op_desc->GetName().c_str(), bin_file_keys.size());
return INTERNAL_ERROR;
}
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kNonTailBlock], tbe_kernel[kNonTailBlock], true,
kNonTailBlock),
"Function register of first bin file %s failed.", bin_file_keys[kNonTailBlock].c_str());
GE_CHK_STATUS_RET(FunctionRegister(op_desc, bin_file_keys[kTailBlock], tbe_kernel[kTailBlock], true, kTailBlock),
"Function register of second bin file %s failed.", bin_file_keys[kTailBlock].c_str());
return SUCCESS;
}


Status DavinciModel::FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel,
bool is_ffts, size_t thread_index) {
if (thread_index > 1) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Thread index: %zu should less than 1.", thread_index);
return INTERNAL_ERROR;
}
const char *bin_file_key;
if (is_ffts) {
bin_file_key = GetRegisterStub(bin_file, "");
GELOGI("Node:%s inherit func name:%s directly.", op_desc->GetName().c_str(), bin_file_key);
} else {
std::string session_graph_model_id;
GetUniqueId(op_desc, session_graph_model_id);
bin_file_key = GetRegisterStub(bin_file, session_graph_model_id); // from set, always valid.
}

TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();
std::lock_guard<std::mutex> lock(tvm_bin_mutex_);
if (rtQueryFunctionRegistered(bin_file_key) != RT_ERROR_NONE) {
void *bin_handle = nullptr;
@@ -3721,59 +3774,115 @@ Status DavinciModel::InitTbeHandle(const OpDescPtr &op_desc) {
GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key); GELOGD("TBE: can't find the kernel_name[%s] in HandleMap", bin_file_key);


rtDevBinary_t binary;
std::string json_string;
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_MAGIC, json_string),
GELOGD("Get original type of session_graph_id."));
if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICPU") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICPU;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AIVEC") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AIVEC;
} else if (json_string == "RT_DEV_BINARY_MAGIC_ELF_AICUBE") {
binary.magic = RT_DEV_BINARY_MAGIC_ELF_AICUBE;
} else {
REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
return PARAM_INVALID;
}

GE_CHK_STATUS_RET(InitBinaryMagic(op_desc, is_ffts, thread_index, binary), "Init binary magic of %s failed.",
op_desc->GetName().c_str());
binary.version = 0;
binary.data = tbe_kernel->GetBinData();
binary.length = tbe_kernel->GetBinDataSize();

GELOGD("TBE: binary.length: %lu", binary.length);
GE_CHK_RT_RET(rtDevBinaryRegister(&binary, &bin_handle));


std::string meta_data;
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, TVM_ATTR_NAME_METADATA, meta_data),
GELOGI("Get original type of json_string"));
GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str());
GE_IF_BOOL_EXEC(!meta_data.empty(), GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str())));

GE_CHK_STATUS_RET(InitMetaData(op_desc, is_ffts, thread_index, bin_handle), "Init tvm meta data of %s failed.",
op_desc->GetName().c_str());
kernel_store.StoreTBEHandle(bin_file_key, bin_handle, tbe_kernel);
} else {
GELOGI("TBE: find the kernel_name[%s] in HandleMap", bin_file_key);
kernel_store.ReferTBEHandle(bin_file_key);
}

std::string kernel_name;
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name),
GELOGD("Get original type of kernel_name"));
GE_CHK_STATUS_RET(InitKernelName(op_desc, is_ffts, thread_index, kernel_name), "Init kernel name of %s failed.",
op_desc->GetName().c_str());
GE_CHK_RT_RET(rtFunctionRegister(bin_handle, bin_file_key, bin_file_key, kernel_name.c_str(), 0));
used_tbe_handle_map_[bin_file_key] = 1; // Init used num to 1.
return SUCCESS;
}

// Kernel registed, Increase used num in store.
StoreTbeHandle(bin_file_key);
return SUCCESS;
}


Status DavinciModel::InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index,
rtDevBinary_t &binary) {
string json_string;
const string &tvm_magic = is_ffts ? TVM_ATTR_NAME_THREAD_MAGIC : TVM_ATTR_NAME_MAGIC;
const static std::map<std::string, uint32_t> binary_magics = {
{"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU},
{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF},
{"RT_DEV_BINARY_MAGIC_ELF_AIVEC", RT_DEV_BINARY_MAGIC_ELF_AIVEC},
{"RT_DEV_BINARY_MAGIC_ELF_AICUBE", RT_DEV_BINARY_MAGIC_ELF_AICUBE}
};
if (is_ffts) {
vector<string> json_list;
(void)AttrUtils::GetListStr(op_desc, tvm_magic, json_list);
if (json_list.size() != kFftsTbeHandleElementSize) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Attr is %s, thread index is %zu, json list size is %zu.",
tvm_magic.c_str(), thread_index, json_list.size());
return INTERNAL_ERROR;
}
json_string = json_list[thread_index];
} else {
(void)AttrUtils::GetStr(op_desc, tvm_magic, json_string);
}
auto iter = binary_magics.find(json_string);
if (iter == binary_magics.end()) {
REPORT_INNER_ERROR("E19999", "Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
tvm_magic.c_str(), json_string.c_str(), op_desc->GetName().c_str(),
op_desc->GetType().c_str(), model_id_);
GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s value:%s in op:%s(%s), model_id:%u, check invalid",
TVM_ATTR_NAME_MAGIC.c_str(), json_string.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str(), model_id_);
return PARAM_INVALID;
}
binary.magic = iter->second;
return SUCCESS;
}

Status DavinciModel::InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle) {
string meta_data;
const string &tvm_metadata = is_ffts ? TVM_ATTR_NAME_THREAD_METADATA : TVM_ATTR_NAME_METADATA;
if (is_ffts) {
vector<string> meta_data_list;
(void)AttrUtils::GetListStr(op_desc, tvm_metadata, meta_data_list);
if (meta_data_list.size() != kFftsTbeHandleElementSize) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, meta data list size is %zu.",
tvm_metadata.c_str(), thread_index, meta_data_list.size());
return INTERNAL_ERROR;
}
meta_data = meta_data_list[thread_index];
} else {
(void)AttrUtils::GetStr(op_desc, tvm_metadata, meta_data);
}
GELOGD("TBE: meta data: %s", meta_data.empty() ? "null" : meta_data.c_str());
if (!meta_data.empty()) {
GE_CHK_RT_RET(rtMetadataRegister(bin_handle, meta_data.c_str()));
}
return SUCCESS;
}

Status DavinciModel::InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name) {
if (is_ffts) {
// delete prefix, eg: *sgt_graph_nodes*/loss_scale/gradient/fp32_vals/Mean_grad/Tile
vector<string> kernel_name_list;
auto pos = op_desc->GetName().find("/");
if (pos == std::string::npos) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, subgraph node name: %s.", op_desc->GetName().c_str());
return INTERNAL_ERROR;
}
string attr_kernel_name = op_desc->GetName().substr(pos + 1) + "_thread_kernelname";
(void)AttrUtils::GetListStr(op_desc, attr_kernel_name, kernel_name_list);
if (kernel_name_list.size() != kFftsTbeHandleElementSize) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed, attr is %s, thread index is %zu, kernel name list size is %zu.",
attr_kernel_name.c_str(), thread_index, kernel_name_list.size());
return INTERNAL_ERROR;
}
kernel_name = kernel_name_list[thread_index];
} else {
string attr_kernel_name = op_desc->GetName() + "_kernelname";
(void)AttrUtils::GetStr(op_desc, attr_kernel_name, kernel_name);
}
return SUCCESS;
}
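
A worked example of the attribute-key derivation InitKernelName performs in the FFTS branch: everything up to and including the first '/' is stripped from the node name before "_thread_kernelname" is appended. The node name below is illustrative.

#include <iostream>
#include <string>

int main() {
  const std::string node_name = "sgt_graph_nodes/loss_scale/gradient/Mean_grad/Tile";
  const auto pos = node_name.find("/");
  if (pos == std::string::npos) {
    std::cerr << "unexpected subgraph node name: " << node_name << "\n";
    return 1;
  }
  // Drop the subgraph prefix, then append the per-thread kernel-name suffix.
  const std::string attr_key = node_name.substr(pos + 1) + "_thread_kernelname";
  std::cout << attr_key << "\n";  // loss_scale/gradient/Mean_grad/Tile_thread_kernelname
  return 0;
}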

void DavinciModel::StoreTbeHandle(const std::string &handle_key) {
// Online mode FE may call rtFunctionRegister.
TBEHandleStore &kernel_store = TBEHandleStore::GetInstance();


+6 -0  ge/graph/load/model_manager/davinci_model.h

@@ -771,6 +771,12 @@ class DavinciModel {
/// @return Status
///
Status InitTbeHandle(const OpDescPtr &op_desc);
Status InitTbeHandleWithFfts(const OpDescPtr &op_desc);
Status FunctionRegister(const OpDescPtr &op_desc, string &bin_file, OpKernelBinPtr &tbe_kernel, bool is_ffts,
size_t thread_index = 0);
Status InitBinaryMagic(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, rtDevBinary_t &binary);
Status InitMetaData(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, void *bin_handle);
Status InitKernelName(const OpDescPtr &op_desc, bool is_ffts, size_t thread_index, string &kernel_name);


void StoreTbeHandle(const string &handle_key);
void CleanTbeHandle();


+14 -9  ge/graph/load/model_manager/model_manager.cc

@@ -1378,7 +1378,9 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_
Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str());
std::lock_guard<std::mutex> lock(cust_aicpu_mutex_);
if (cust_aicpu_so_.size() == 0) return SUCCESS;
if (cust_aicpu_so_.empty()) {
return SUCCESS;
}
// get current context
rtContext_t rt_cur_ctx = nullptr;
auto rt_error = rtCtxGetCurrent(&rt_cur_ctx);
@@ -1394,9 +1396,19 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
return SUCCESS;
}


rtStream_t stream = nullptr;
vector<void *> allocated_mem;
std::function<void()> callback = [&]() {
for (auto mem : allocated_mem) {
GE_CHK_RT(rtFree(mem));
}
if (stream != nullptr) {
GE_CHK_RT(rtStreamDestroy(stream));
}
};
GE_MAKE_GUARD(release, callback);

rtError_t status;
rtStream_t stream = nullptr;
vector<CustAicpuSoBuf> v_cust_so;
void *args = nullptr;


@@ -1471,13 +1483,6 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
GELOGE(RT_FAILED, "[Call][RtStreamSynchronize] fail, ret = 0x%X", status);
return RT_ERROR_TO_GE_STATUS(status);
}
std::function<void()> callback = [&]() {
for (auto mem : allocated_mem) {
GE_CHK_RT(rtFree(mem));
}
GE_CHK_RT(rtStreamDestroy(stream));
};
GE_MAKE_GUARD(release, callback);
GELOGI("Cpu kernel launch task success."); GELOGI("Cpu kernel launch task success.");
return SUCCESS; return SUCCESS;
} }


+393 -0  ge/graph/load/model_manager/task_info/ffts_task_info.cc

@@ -0,0 +1,393 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/load/model_manager/task_info/ffts_task_info.h"

#include <vector>

#include "graph/load/model_manager/davinci_model.h"

namespace {
constexpr uint32_t kAddrLen = sizeof(void *);
}
namespace ge {
FftsTaskInfo::~FftsTaskInfo() {
GE_FREE_RT_LOG(args_);
}

Status FftsTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
GELOGI("FftsTaskInfo Init Start.");
GE_CHECK_NOTNULL(davinci_model);
davinci_model_ = davinci_model;
GE_CHK_STATUS_RET_NOLOG(SetStream(task_def.stream_id(), davinci_model_->GetStreamList()));

const domi::FftsTaskDef &ffts_task_def = task_def.ffts_task();
OpDescPtr op_desc = davinci_model_->GetOpByIndex(ffts_task_def.op_index());
GE_CHECK_NOTNULL(op_desc);

if ((ffts_task_def.sub_task_size() > static_cast<int>(RT_FFTS_MAX_SUB_TASK_NUM)) ||
(ffts_task_def.ticket_cache_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_NUM))) {
GELOGE(INTERNAL_ERROR, "[Check][Param] failed. Node: %s, sub task desc size: %d, ticket cache size: %d",
op_desc->GetName().c_str(), ffts_task_def.sub_task_size(), ffts_task_def.ticket_cache_size());
return INTERNAL_ERROR;
}
args_size_ = kAddrLen * ffts_task_def.addr_size();
GE_CHK_RT_RET(rtMalloc(&args_, args_size_, RT_MEMORY_HBM));
InitFftsDescInfo(ffts_task_def.ffts_desc(), sub_task_info_.fftsDesc);

sub_task_info_.fftsType = static_cast<rtFftsType_t>(ffts_task_def.ffts_type());
sub_task_info_.subTaskNum = ffts_task_def.sub_task_size();
for (int idx = 0; idx < ffts_task_def.sub_task_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitSubTaskInfo(ffts_task_def.sub_task(idx), sub_task_info_.subTask[idx]));
}

sub_task_info_.tickCacheNum = ffts_task_def.ticket_cache_size();
for (int idx = 0; idx < ffts_task_def.ticket_cache_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitTicketCache(ffts_task_def.ticket_cache(idx), sub_task_info_.ticketCache[idx]));
}

size_t data_size = kAddrLen * io_addrs_.size();
GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs_.data(), data_size, RT_MEMCPY_HOST_TO_DEVICE));
GELOGI("FftsTaskInfo::Init Success. Node: %s, input/output size: %zu", op_desc->GetName().c_str(), io_addrs_.size());
return SUCCESS;
}

void FftsTaskInfo::InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc) {
ffts_desc.tm = static_cast<uint8_t>(ffts_desc_def.tm());
ffts_desc.di = static_cast<uint8_t>(ffts_desc_def.di());
ffts_desc.dw = static_cast<uint8_t>(ffts_desc_def.dw());
ffts_desc.df = static_cast<uint8_t>(ffts_desc_def.df());
ffts_desc.dataSplitUnit = static_cast<uint8_t>(ffts_desc_def.data_split_unit());
ffts_desc.prefetchOstNum = static_cast<uint8_t>(ffts_desc_def.prefetch_ost_num());
ffts_desc.cacheMaintainOstNum = static_cast<uint8_t>(ffts_desc_def.cache_maintain_ost_num());
ffts_desc.aicPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_upper());
ffts_desc.aicPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aic_prefetch_lower());
ffts_desc.aivPrefetchUpper = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_upper());
ffts_desc.aivPrefetchLower = static_cast<uint8_t>(ffts_desc_def.aiv_prefetch_lower());
}

Status FftsTaskInfo::InitSubTaskInfo(const domi::FftsSubTaskDef &sub_task_def, rtFftsSubTaskInfo_t &sub_task_desc) {
if ((sub_task_def.dst_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) ||
(sub_task_def.src_tick_cache_id_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) {
GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, dst tick cache id size: %d, src tick cache id size: %d",
sub_task_def.dst_tick_cache_id_size(), sub_task_def.src_tick_cache_id_size());
return FAILED;
}

if (sub_task_def.has_auto_thread_aic_aiv() == sub_task_def.has_manual_thread_aic_aiv()) {
GELOGE(FAILED, "[Check][Param] Invalid FftsSubTaskInfo, auto thread aic/aiv: %d, manual thread aic/aiv: %d",
sub_task_def.has_auto_thread_aic_aiv(), sub_task_def.has_manual_thread_aic_aiv());
return FAILED;
}

thread_dim_ = sub_task_def.thread_dim();
GE_CHK_BOOL_RET_STATUS(thread_dim_ != 0, FAILED, "[Get][thread_dim] failed, Invalid thread dim: %u!", thread_dim_);

sub_task_desc.subTaskType = static_cast<rtFftsSubTaskType_t>(sub_task_def.sub_task_type());
sub_task_desc.threadDim = sub_task_def.thread_dim();

sub_task_desc.dstTickCacheVldBitmap = sub_task_def.dst_tick_cache_vld_bitmap();
sub_task_desc.srcTickCacheVldBitmap = sub_task_def.src_tick_cache_vld_bitmap();
sub_task_desc.srcDataOutOfSubGraphBitmap = sub_task_def.src_data_out_of_subgraph_bitmap();

for (int idx = 0; idx < sub_task_def.dst_tick_cache_id_size(); ++idx) {
sub_task_desc.dstTickCacheID[idx] = sub_task_def.dst_tick_cache_id(idx);
}

for (int idx = 0; idx < sub_task_def.src_tick_cache_id_size(); ++idx) {
sub_task_desc.srcTickCacheID[idx] = sub_task_def.src_tick_cache_id(idx);
}

if (sub_task_def.has_auto_thread_aic_aiv()) {
GE_CHK_STATUS_RET_NOLOG(InitAutoAicAiv(sub_task_def.auto_thread_aic_aiv(), sub_task_desc.custom.autoThreadAicAiv));
}

if (sub_task_def.has_manual_thread_aic_aiv()) {
GE_CHK_STATUS_RET_NOLOG(
InitManualAicAiv(sub_task_def.manual_thread_aic_aiv(), sub_task_desc.custom.manualThreadAicAiv));
}

if (sub_task_def.has_manual_thread_nop()) {
GE_CHK_STATUS_RET_NOLOG(InitManualNop(sub_task_def.manual_thread_nop(), sub_task_desc.custom.manualThreadNop));
}

return SUCCESS;
}

Status FftsTaskInfo::InitTicketCache(const domi::TicketCacheDef &ticket_cache_def, rtTicketCache_t &ticket_cache) {
if (ticket_cache_def.has_auto_thread_cache() == ticket_cache_def.has_manual_thread_cache()) {
GELOGE(FAILED, "[Check][Param] Invalid TicketCacheDef, has auto thread cache: %d, has manual thread cache: %d",
ticket_cache_def.has_auto_thread_cache(), ticket_cache_def.has_manual_thread_cache());
return FAILED;
}

ticket_cache.cacheOption = static_cast<rtCacheOp_t>(ticket_cache_def.cache_option());
ticket_cache.ticketCacheWindow = ticket_cache_def.ticket_cache_window();

if (ticket_cache_def.has_auto_thread_cache()) {
InitAutoCacheInfo(ticket_cache_def.auto_thread_cache(), ticket_cache.custom.autoThreadCache);
}
if (ticket_cache_def.has_manual_thread_cache()) {
GE_CHK_STATUS_RET_NOLOG(
InitManualCacheInfo(ticket_cache_def.manual_thread_cache(), ticket_cache.custom.manualThreadCache));
}

return SUCCESS;
}

// task_addr = {0,200,700,1000,2000, 3500}
// task_addr_offset = {20,40,2,100,200}
template <typename T>
Status FftsTaskInfo::InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim,
uint32_t addr_count) {
for (uint32_t i = 0; i < addr_count; ++i) {
uintptr_t logic_addr = aic_aiv_def.task_addr(i) + thread_dim * aic_aiv_def.task_addr_offset(i);
uint8_t *io_addr = nullptr;
if (ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "[Check][GetRtAddress]GetRtAddress failed.");
return INTERNAL_ERROR;
}
GELOGD("aic_aiv_def task base addr is %ld, offset is %ld, thread is %d, logic addrs is 0x%lx, io addr is %p",
aic_aiv_def.task_addr(i), aic_aiv_def.task_addr_offset(i), thread_dim, logic_addr, io_addr);
io_addrs_.emplace_back(io_addr);
}
return SUCCESS;
}
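
A worked example of the per-thread address arithmetic above, using the sample arrays from the comment: each I/O address resolves to task_addr(i) + thread_dim * task_addr_offset(i), and task_addr entries beyond task_addr_offset's length (here 3500) are last-thread workspace. Values are illustrative.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<uintptr_t> task_addr = {0, 200, 700, 1000, 2000, 3500};
  const std::vector<uintptr_t> task_addr_offset = {20, 40, 2, 100, 200};
  const uintptr_t thread_dim = 2;  // which thread slice to resolve
  for (std::size_t i = 0; i < task_addr_offset.size(); ++i) {
    const uintptr_t logic_addr = task_addr[i] + thread_dim * task_addr_offset[i];
    std::cout << "io " << i << ": " << logic_addr << "\n";
  }
  // io 0: 40, io 1: 280, io 2: 704, io 3: 1200, io 4: 2400
  return 0;
}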

Status FftsTaskInfo::InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv) {
if (aic_aiv_def.src_prefetch_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) {
GELOGE(FAILED, "[Check][Param] Invalid AutoThreadAicAivInfo, prefetch size: %d", aic_aiv_def.src_prefetch_size());
return FAILED;
}

aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size();
GELOGD("AutoThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr);
const auto &rts_param = davinci_model_->GetRuntimeParam();
for (uint32_t i = 0; i < thread_dim_ - 1; ++i) {
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i,
static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size())));
}
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count()));
int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size();
for (int k = 0; k < last_thread_workspace_size; ++k) {
uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k);
uint8_t *io_addr = nullptr;
GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr));
GELOGD("logic addr is 0x%lx, io addr is %p.", logic_addr, io_addr);
io_addrs_.emplace_back(io_addr);
}

aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset();
GELOGD("args_: %p, io_addrs size: %zu, task param offset: %u.", args_, io_addrs_.size(), aic_aiv.taskParamOffset);
aic_aiv.satMode = aic_aiv_def.sat_mode();
aic_aiv.scheduleMode = aic_aiv_def.schedule_mode();
aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt();

aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap();
aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap();

aic_aiv.tailBlkDim = aic_aiv_def.tail_blk_dim();
aic_aiv.nonTailBlkDim = aic_aiv_def.non_tail_blk_dim();

aic_aiv.nonTailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.non_tail_task_func_stub(), "");
aic_aiv.tailTaskFuncStub = davinci_model_->GetRegisterStub(aic_aiv_def.tail_task_func_stub(), "");

GELOGI("Set func name[%s][%s] succ.", aic_aiv.nonTailTaskFuncStub, aic_aiv.tailTaskFuncStub);
for (int idx = 0; idx < aic_aiv_def.src_prefetch_size(); ++idx) {
InitAutoPrefetch(aic_aiv_def.src_prefetch(idx), aic_aiv.srcPrefetch[idx]);
}

return SUCCESS;
}

void FftsTaskInfo::InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache) {
cache.dataAddr = cache_def.data_addr();
cache.dataAddrOffset = cache_def.data_addr_offset();
cache.nonTailDataLen = cache_def.non_tail_data_len();
cache.tailDataLen = cache_def.tail_data_len();
cache.ticketCacheRefCnt = cache_def.ticket_cache_ref_cnt();
}

void FftsTaskInfo::InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch) {
prefetch.dataAddr = prefetch_def.data_addr();
prefetch.dataAddrOffset = prefetch_def.data_addr_offset();
prefetch.nonTailDataLen = prefetch_def.non_tail_data_len();
prefetch.tailDataLen = prefetch_def.tail_data_len();
}

Status FftsTaskInfo::InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def,
rtManualThreadAicAivInfo_t &aic_aiv) {
if ((aic_aiv_def.thread_prefetch_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(aic_aiv_def.thread_blk_dim_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(aic_aiv_def.thread_task_func_stub_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(aic_aiv_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK))) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadAicAivInfo, thread prefetch dmu desc size: %d, "
"thread blk dim size: %d, thread task func stub size: %d, src dep tbl size: %d",
aic_aiv_def.thread_prefetch_dmu_idx_size(), aic_aiv_def.thread_blk_dim_size(),
aic_aiv_def.thread_task_func_stub_size(), aic_aiv_def.src_dep_tbl_size());
return FAILED;
}
aic_aiv.taskParamAddr = reinterpret_cast<uintptr_t>(args_) + kAddrLen * io_addrs_.size();
GELOGD("ManualThreadAicAivDef: task param addr is %lu.", aic_aiv.taskParamAddr);
const auto &rts_param = davinci_model_->GetRuntimeParam();
for (uint32_t i = 0; i < thread_dim_ - 1; ++i) {
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, i,
static_cast<uint32_t>(aic_aiv_def.task_addr_offset_size())));
}
GE_CHK_STATUS_RET_NOLOG(InitIoAddrs(rts_param, aic_aiv_def, thread_dim_ - 1, aic_aiv_def.input_output_count()));
int last_thread_workspace_size = aic_aiv_def.task_addr_size() - aic_aiv_def.task_addr_offset_size();
for (int k = 0; k < last_thread_workspace_size; ++k) {
uintptr_t logic_addr = aic_aiv_def.task_addr(aic_aiv_def.task_addr_offset_size() + k);
uint8_t *io_addr = nullptr;
GE_CHK_STATUS_RET_NOLOG(ModelUtils::GetRtAddress(rts_param, logic_addr, io_addr));
io_addrs_.emplace_back(io_addr);
}
aic_aiv.taskParamOffset = aic_aiv_def.task_param_offset();

aic_aiv.satMode = aic_aiv_def.sat_mode();
aic_aiv.scheduleMode = aic_aiv_def.schedule_mode();
aic_aiv.iCachePrefetchCnt = aic_aiv_def.cache_prefetch_cnt();

aic_aiv.prefetchEnableBitmap = aic_aiv_def.prefetch_enable_bitmap(); // 8 bit bitmap 1 0 1 0
aic_aiv.prefetchOnceBitmap = aic_aiv_def.prefetch_once_bitmap(); // 8 bit bitmap 1 0 1 0
aic_aiv.prefetchOnceDmuNum = aic_aiv_def.prefetch_once_dmu_num();

for (int idx = 0; idx < aic_aiv_def.thread_prefetch_dmu_idx_size(); ++idx) {
aic_aiv.threadPrefetchDmuIdx[idx] = aic_aiv_def.thread_prefetch_dmu_idx(idx);
}
for (int idx = 0; idx < aic_aiv_def.thread_blk_dim_size(); ++idx) {
aic_aiv.threadBlkDim[idx] = aic_aiv_def.thread_blk_dim(idx);
}
for (int idx = 0; idx < aic_aiv_def.thread_task_func_stub_size(); ++idx) {
aic_aiv.threadTaskFuncStub[idx] = aic_aiv_def.thread_task_func_stub(idx).c_str();
}

InitManualDmuInfo(aic_aiv_def, aic_aiv.prefetchList);
for (int idx = 0; idx < aic_aiv_def.src_dep_tbl_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitManualDependency(aic_aiv_def.src_dep_tbl(idx), aic_aiv.srcDepTbl[idx]));
}

return SUCCESS;
}

Status FftsTaskInfo::InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def,
rtManualThreadCacheInfo_t &cache_info) {
if ((cache_def.slice_dmu_idx_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM)) ||
(cache_def.ticket_cache_ref_cnt_tbl_size() > static_cast<int>(RT_FFTS_MAX_MANUAL_THREAD_NUM))) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadCacheInfo slice dum desc index %d, ticket cache ref cnt %d",
cache_def.slice_dmu_idx_size(), cache_def.ticket_cache_ref_cnt_tbl_size());
return FAILED;
}

InitManualDmuInfo(cache_def, cache_info.dmuList);
for (int idx = 0; idx < cache_def.slice_dmu_idx_size(); ++idx) {
cache_info.sliceDmuIdx[idx] = cache_def.slice_dmu_idx(idx);
}

for (int idx = 0; idx < cache_def.ticket_cache_ref_cnt_tbl_size(); ++idx) {
cache_info.ticketCacheRefCntTbl[idx] = cache_def.ticket_cache_ref_cnt_tbl(idx);
}

return SUCCESS;
}

Status FftsTaskInfo::InitManualDependency(const domi::ManualThreadDependencyDef &dependency_def,
rtManualThreadDependency_t &dependency) {
if (dependency_def.dependency_size() > static_cast<int>(RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN)) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadDependency size: %d", dependency_def.dependency_size());
return FAILED;
}

for (int idx = 0; idx < dependency_def.dependency_size(); ++idx) {
dependency.dependency[idx] = dependency_def.dependency(idx);
}

return SUCCESS;
}

Status FftsTaskInfo::InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop_info) {
if (nop_def.src_dep_tbl_size() > static_cast<int>(RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK)) {
GELOGE(FAILED, "[Check][Param] Invalid ManualThreadNopInfo, src dep tbl size: %d", nop_def.src_dep_tbl_size());
return FAILED;
}

for (int idx = 0; idx < nop_def.src_dep_tbl_size(); ++idx) {
GE_CHK_STATUS_RET_NOLOG(InitManualDependency(nop_def.src_dep_tbl(idx), nop_info.srcDepTbl[idx]));
}

return SUCCESS;
}

void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu) {
if (aic_aiv_def.prefetch_list().empty()) {
return;
}

std::vector<uint8_t> buffer(sizeof(rtManualThreadDmuInfo_t) * aic_aiv_def.prefetch_list_size());
dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(buffer.data());
for (int idx = 0; idx < aic_aiv_def.prefetch_list_size(); ++idx) {
InitManualDmuInfo(aic_aiv_def.prefetch_list(idx), dmu[idx]);
}
}

void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu) {
if (cache_def.dmu_list().empty()) {
return;
}

std::vector<uint8_t> buffer(sizeof(rtManualThreadDmuInfo_t) * cache_def.dmu_list_size());
dmu = reinterpret_cast<rtManualThreadDmuInfo_t *>(buffer.data());
for (int idx = 0; idx < cache_def.dmu_list_size(); ++idx) {
InitManualDmuInfo(cache_def.dmu_list(idx), dmu[idx]);
}
}

void FftsTaskInfo::InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu) {
dmu.dataAddr = dmu_def.data_addr();
dmu.numOuter = dmu_def.num_outer();
dmu.numInner = dmu_def.num_inner();
dmu.strideOuter = dmu_def.stride_outer();
dmu.lenInner = dmu_def.len_inner();
dmu.strideInner = dmu_def.stride_inner();
}

Status FftsTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
return SUCCESS;
}

Status FftsTaskInfo::UpdateArgs() {
GE_CHECK_NOTNULL(davinci_model_);
std::vector<void *> io_addrs = io_addrs_;
davinci_model_->UpdateKnownZeroCopyAddr(io_addrs);
auto addr_size = kAddrLen * io_addrs.size();
GE_CHK_RT_RET(rtMemcpy(args_, args_size_, io_addrs.data(), addr_size, RT_MEMCPY_HOST_TO_DEVICE));
return SUCCESS;
}

Status FftsTaskInfo::Distribute() {
GELOGI("FftsTaskInfo Distribute Start.");
rtError_t rt_ret = rtFftsTaskLaunch(&sub_task_info_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "[Check][RT_ret] Call rtFftsTaskLaunch failed, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

GELOGI("FftsTaskInfo Distribute Success.");
return SUCCESS;
}

REGISTER_TASK_INFO(RT_MODEL_TASK_FFTS_TASK, FftsTaskInfo);
} // namespace ge
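
The closing REGISTER_TASK_INFO line binds RT_MODEL_TASK_FFTS_TASK to this class. A standalone sketch of the factory-registration pattern such a macro typically expands to, so model load can construct the right TaskInfo from a task type without a switch statement; the types and the numeric task-type value here are illustrative stand-ins, not the real GE registry.

#include <functional>
#include <iostream>
#include <map>
#include <memory>

struct TaskInfoBase {
  virtual ~TaskInfoBase() = default;
  virtual const char *Name() const = 0;
};

using Creator = std::function<std::unique_ptr<TaskInfoBase>()>;

// Function-local static avoids the static initialization order problem.
std::map<int, Creator> &Registry() {
  static std::map<int, Creator> registry;
  return registry;
}

struct Registrar {
  Registrar(int type, Creator creator) { Registry()[type] = std::move(creator); }
};

struct FftsTaskInfoSketch : TaskInfoBase {
  const char *Name() const override { return "FftsTaskInfo"; }
};

constexpr int kFftsTaskType = 29;  // illustrative value, not the real enum
static Registrar g_ffts_registrar(kFftsTaskType,
                                  [] { return std::unique_ptr<TaskInfoBase>(new FftsTaskInfoSketch()); });

int main() {
  const auto it = Registry().find(kFftsTaskType);
  if (it != Registry().end()) {
    std::cout << it->second()->Name() << "\n";  // FftsTaskInfo
  }
  return 0;
}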

+66 -0  ge/graph/load/model_manager/task_info/ffts_task_info.h

@@ -0,0 +1,66 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_

#include "graph/load/model_manager/task_info/task_info.h"
#include "graph/op_desc.h"

namespace ge {
class FftsTaskInfo : public TaskInfo {
public:
FftsTaskInfo() = default;
~FftsTaskInfo() override;

Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;

Status Distribute() override;

Status UpdateArgs() override;

Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override;

private:
void InitFftsDescInfo(const domi::FftsDescInfoDef &ffts_desc_def, rtFftsDescInfo_t &ffts_desc);
Status InitSubTaskInfo(const domi::FftsSubTaskDef &task_def, rtFftsSubTaskInfo_t &task);
Status InitTicketCache(const domi::TicketCacheDef &cache_def, rtTicketCache_t &cache);

Status InitAutoAicAiv(const domi::AutoThreadAicAivDef &aic_aiv_def, rtAutoThreadAicAivInfo_t &aic_aiv);
void InitAutoCacheInfo(const domi::AutoThreadCacheDef &cache_def, rtAutoThreadCacheInfo_t &cache);
void InitAutoPrefetch(const domi::AutoThreadPrefetchDef &prefetch_def, rtAutoThreadPrefetch_t &prefetch);

Status InitManualAicAiv(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadAicAivInfo_t &aic_aiv);
Status InitManualCacheInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadCacheInfo_t &cache);
Status InitManualDependency(const domi::ManualThreadDependencyDef &depend_def, rtManualThreadDependency_t &depend);
Status InitManualNop(const domi::ManualThreadNopDef &nop_def, rtManualThreadNopInfo_t &nop);

void InitManualDmuInfo(const domi::ManualThreadDmuDef &dmu_def, rtManualThreadDmuInfo_t &dmu);
void InitManualDmuInfo(const domi::ManualThreadCacheDef &cache_def, rtManualThreadDmuInfo_t *&dmu);
void InitManualDmuInfo(const domi::ManualThreadAicAivDef &aic_aiv_def, rtManualThreadDmuInfo_t *&dmu);

template<typename T>
Status InitIoAddrs(const RuntimeParam &rts_param, const T &aic_aiv_def, uint32_t thread_dim, uint32_t addr_count);

DavinciModel *davinci_model_{nullptr};
rtFftsTaskInfo_t sub_task_info_;
std::vector<void *> io_addrs_;
uint32_t thread_dim_{0};
void *args_{nullptr}; // runtime args memory
uint32_t args_size_{0}; // runtime args memory length
};
} // namespace ge
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_FFTS_TASK_INFO_H_

+2 -0  ge/graph/load/model_manager/task_info/kernel_task_info.cc

@@ -645,6 +645,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne
GE_CHECK_NOTNULL(op_desc);


args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]);
GE_CHECK_NOTNULL(args_addr);
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_);
if (sec_ret != EOK) {
REPORT_CALL_ERROR("E19999", "Call memcpy_s fail, size:%u, ret:0x%X", args_size_, sec_ret);
@@ -1000,6 +1001,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k


// copy args to new host memory
args_addr = std::unique_ptr<uint8_t[]>(new (std::nothrow) uint8_t[args_size_]);
GE_CHECK_NOTNULL(args_addr);
GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_)
errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_);
if (sec_ret != EOK) {


+20 -26  ge/graph/manager/graph_manager.cc

@@ -27,6 +27,7 @@
#include "common/math/math_util.h" #include "common/math/math_util.h"
#include "common/thread_pool.h" #include "common/thread_pool.h"
#include "common/dump/dump_manager.h" #include "common/dump/dump_manager.h"
#include "ge_opt_info/ge_opt_info.h"
#include "analyzer/analyzer.h" #include "analyzer/analyzer.h"
#include "graph/common/ge_call_wrapper.h" #include "graph/common/ge_call_wrapper.h"
#include "graph/common/local_context.h" #include "graph/common/local_context.h"
@@ -120,7 +121,6 @@ const char *const kCheckPointForGetVar = "CheckPointGraphForGetVar";
const char *const kCheckPointGraph = "checkpoint_graph";
const char *const kVectorEngine = "VectorEngine";
const char *const kAIcoreEngine = "AIcoreEngine";
const char *const kRunFlagOffline = "0";
const int32_t kDynamicDimsTypeIsGetNext = 0;
const int32_t kDynamicDimsTypeIsData = 1;
const char *const kGetNextName = "IteratorV2";
@@ -950,7 +950,7 @@ Status GraphManager::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint


rtError_t rt_ret = rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtCtxCreate faileded, session_id:%lu, graph_id:%u, mode:%d",
REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, session_id:%lu, graph_id:%u, mode:%d",
session_id, graph_id, mode);
GELOGE(FAILED, "[Call][RtCtxCreate] faileded, session_id:%lu, graph_id:%u, mode:%d", session_id, graph_id, mode);
return FAILED;
@@ -1002,6 +1002,12 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge
return ret;
}


ret = GeOptInfo::SetOptInfo();
if (ret != SUCCESS) {
GELOGE(ret, "[Set][OptInfo] Set optional information failed.");
return ret;
}

/// 1. BUILD_MODE_TUNING with BUILD_STEP_AFTER_UB_MATCH no need PreRunOptimizeOriginalGraph;
/// 2. BUILD_MODE_TUNING with BUILD_STEP_AFTER_MERGE no need PreRunOptimizeOriginalGraph.
/// 3. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need PreRunOptimizeOriginalGraph.
@@ -1789,8 +1795,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti
return GE_GRAPH_OPTIONS_INVALID);


// ge.graphType
ret =
ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid");
return GE_GRAPH_OPTIONS_INVALID);
@@ -2436,6 +2441,8 @@ Status GraphManager::RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute
continue;
}
if (n->GetOpDesc()->GetType() == CONSTANT || n->GetOpDesc()->GetType() == CONSTANTOP) {
// reset const type depend on train_flag
options_.train_graph_flag ? n->GetOpDesc()->SetType(CONSTANTOP) : n->GetOpDesc()->SetType(CONSTANT);
if (n->GetOutAllNodes().empty() && n->GetInAllNodes().empty()) {
// it is an isolated constant, just remove it
if (GraphUtils::RemoveJustNode(compute_graph, n) != GRAPH_SUCCESS) {
@@ -2762,35 +2769,22 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) {
"Please pay attention to it."); "Please pay attention to it.");
} }


GE_CHK_STATUS_RET(ChangeConstType(compute_graph));
ChangeConstTypeWhenTraining(compute_graph);


GELOGI("End optimize after merge sub graph."); GELOGI("End optimize after merge sub graph.");
return SUCCESS; return SUCCESS;
} }


Status GraphManager::ChangeConstType(const ComputeGraphPtr &compute_graph) {
// run_flag off means offline, on means online
string run_flag;
(void)ge::GetContext().GetOption(ge::RUN_FLAG, run_flag);
// The constant for online is CONSTANTOP, and is CONSTANT for offline. They will be unified in future.
if (run_flag == kRunFlagOffline) {
GELOGI("Offline mode, change all Constant to Const.");
} else {
GELOGI("Online mode, change all Const to Constant.");
}
for (NodePtr &n : compute_graph->GetAllNodes()) {
GE_CHECK_NOTNULL(n);
if (n->GetType() == CONSTANT || n->GetType() == CONSTANTOP) {
auto op_desc = n->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if (run_flag == kRunFlagOffline) {
op_desc->SetType(CONSTANT);
} else {
op_desc->SetType(CONSTANTOP);
void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph) {
// The constant for train is CONSTANTOP, and is CONSTANT for inference. They will be unified in future.
if (options_.train_graph_flag) {
for (NodePtr &n : compute_graph->GetAllNodes()) {
// This can ensure that n is not a null pointer
if (n->GetOpDesc()->GetType() == CONSTANT) {
n->GetOpDesc()->SetType(CONSTANTOP);
}
}
}
return SUCCESS;
}
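
The replacement above reverts const-type rewriting from the run-flag option back to the training flag: only training graphs rewrite CONSTANT nodes to CONSTANTOP, and the offline/online distinction no longer applies. A standalone sketch of the reverted behaviour; the type strings are stand-ins for GE's CONSTANT/CONSTANTOP constants.

#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::string kConst = "Const";          // stand-in for CONSTANT
  const std::string kConstantOp = "Constant";  // stand-in for CONSTANTOP
  const bool train_graph_flag = true;
  std::vector<std::string> op_types = {kConst, "Add", kConstantOp};
  if (train_graph_flag) {
    for (auto &type : op_types) {
      if (type == kConst) {
        type = kConstantOp;  // training graphs use the Constant op type
      }
    }
  }
  for (const auto &type : op_types) {
    std::cout << type << "\n";  // Constant, Add, Constant
  }
  return 0;
}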


Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) {
@@ -3145,10 +3139,10 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) {
}
// Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency
if (count > 1 && graph_node->GetBuildFlag()) {
graph_node->Lock();
GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id);
// In online inference concurrency senario, graph_node is allowed to be locked for 'count' times
graph_node->SetSemSize(count);
graph_node->Lock();
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context,
args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback }));
GELOGI("[PreRunThread] Loop end. Start to run with cached build model.");


+1 -1  ge/graph/manager/graph_manager.h

@@ -375,7 +375,7 @@ class GraphManager {
static void ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_node, RunAsyncCallback callback,
Status ret, const string &log);


Status ChangeConstType(const ComputeGraphPtr &compute_graph);
void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph);


Status PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, Status PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs,
ge::ComputeGraphPtr &compute_graph, uint64_t session_id); ge::ComputeGraphPtr &compute_graph, uint64_t session_id);


ge/graph/optimize/graph_optimize.cc  (+0 -2)

@@ -336,10 +336,8 @@ Status GraphOptimize::OptimizeAfterStage1(ComputeGraphPtr &compute_graph) {
       GELOGI("[OptimizeAfterStage1]: engine type will exclude:%s.", exclude_core_type.c_str());
       continue;
     }
-#ifndef ONLY_COMPILE_OPEN_SRC
     GELOGI("Begin to optimize graph after stage1 by engine %s.", iter->first.c_str());
     ret = (iter->second)->OptimizeAfterStage1(*compute_graph);
-#endif
     if (ret != SUCCESS) {
       REPORT_INNER_ERROR("E19999", "Call OptimizeAfterStage1 failed, ret:%d, engine_name:%s, "
                          "graph_name:%s.", ret, iter->first.c_str(), compute_graph->GetName().c_str());


ge/graph/partition/dynamic_shape_partition.cc  (+29 -7)

@@ -284,9 +284,6 @@ Status DynamicShapePartitioner::InitClusters() {
     auto cluster = MakeShared<Cluster>(rank++, type, node, this);
     REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed.");
     node_2_cluster_[node] = cluster;
-    if (cluster->IsUnknownShape()) {
-      ordered_cluster_.push_back(cluster);
-    }
 
     int64_t group_index = -1;
     if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
@@ -306,7 +303,7 @@ Status DynamicShapePartitioner::InitClusters() {
   return SUCCESS;
 }
 
-Status DynamicShapePartitioner::TopologicalSortClusters() {
+Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) {
   ordered_cluster_.clear();
   // BFS topological sort clusters for known shape cluster
   std::queue<ClusterPtr> ready_clusters;
@@ -331,7 +328,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() {
     auto cluster = ready_clusters.front();
     ready_clusters.pop();
     cluster->UpdateRank(rank++);
-    if (cluster->IsKnownShape() || cluster->IsInputNode()) {
+    if (ordered_filter == nullptr || ordered_filter(cluster)) {
       ordered_cluster_.push_back(cluster);
     }
     for (const auto &out_cluster : cluster->Outputs()) {
@@ -364,6 +361,7 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) {
 }
 
 void DynamicShapePartitioner::MergeClustersControlFlow() {
+  std::unordered_set<ClusterPtr> all_merged_clusters;
   for (const auto &item : control_clusters_) {
     const auto &control_cluster = item.second;
     auto rit = control_cluster.rbegin();
@@ -373,12 +371,21 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
     }
 
     const auto &cluster = *rit;
+    if (all_merged_clusters.count(cluster) > 0) {
+      continue;
+    }
+
     for (++rit; rit != control_cluster.rend(); ++rit) {
       const auto &cluster_from = *rit;
+      if (all_merged_clusters.count(cluster_from) > 0) {
+        continue;
+      }
+
       auto merged_clusters = cluster->MergeAllPathFrom(cluster_from);
       GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(),
              ToString(merged_clusters).c_str());
       for (const auto &merged_cluster : merged_clusters) {
+        all_merged_clusters.emplace(merged_cluster);
         for (const auto &node : merged_cluster->Nodes()) {
           node_2_cluster_[node] = cluster;
         }
@@ -459,9 +466,19 @@ void DynamicShapePartitioner::MergeClustersInputData() {
 }
 
 Status DynamicShapePartitioner::MergeClusters() {
+  const auto filter_known = [](const ClusterPtr &cluster) {
+    return cluster->IsKnownShape() || cluster->IsInputNode();
+  };
+  const auto filter_unknown = [](const ClusterPtr &cluster) {
+    return cluster->IsUnknownShape();
+  };
+
   MergeClustersControlFlow();
+  REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown),
+                  "[TopologicalSort][Clusters] after merge control flow clusters failed.");
   MergeClustersUnknownShape();
-  REQUIRE_SUCCESS(TopologicalSortClusters(), "[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
+  REQUIRE_SUCCESS(TopologicalSortClusters(filter_known),
+                  "[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
   MergeClustersKnownShape();
   MergeClustersInputData();
   return SUCCESS;
@@ -703,7 +720,12 @@ void Cluster::Merge(ClusterPtr other) {
   if (other->min_ < min_) {
     min_ = other->min_;
   }
-};
+
+  if (!IsUnknownShape() && other->IsUnknownShape()) {
+    type_ = UNKNOWN_SHAPE;
+  }
+}
 
 bool Cluster::TryMerge(ClusterPtr other) {
   std::queue<ClusterPtr> forward_reached;
   forward_reached.push(other);
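The OrderedFilter refactor lets one BFS topological sort serve both partitioning phases: after merging control-flow clusters it runs with filter_unknown, after merging unknown-shape clusters with filter_known, so ordered_cluster_ always holds exactly the kind of cluster the next merge step consumes. A self-contained sketch of the pattern, using toy types rather than GE's Cluster:

#include <functional>
#include <queue>
#include <vector>

struct Cluster {
  bool known_shape = false;
  std::vector<int> outputs;  // indices of downstream clusters
  int pending_inputs = 0;    // in-degree, filled in by the caller
};

using OrderedFilter = std::function<bool(const Cluster &)>;

std::vector<int> TopoSort(std::vector<Cluster> &clusters, const OrderedFilter &filter) {
  std::queue<int> ready;
  for (size_t i = 0; i < clusters.size(); ++i) {
    if (clusters[i].pending_inputs == 0) ready.push(static_cast<int>(i));
  }
  std::vector<int> ordered;
  while (!ready.empty()) {
    int id = ready.front();
    ready.pop();
    // the filter only gates what is collected; traversal always covers everything
    if (filter == nullptr || filter(clusters[id])) ordered.push_back(id);
    for (int out : clusters[id].outputs) {
      if (--clusters[out].pending_inputs == 0) ready.push(out);
    }
  }
  return ordered;
}

Called twice with two different lambdas, this mirrors the filter_unknown / filter_known pair introduced in MergeClusters above.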


ge/graph/partition/dynamic_shape_partition.h  (+4 -2)

@@ -111,6 +111,8 @@ class DynamicShapePartitioner {
 
   Status Partition();
 
+  using OrderedFilter = std::function<bool(const std::shared_ptr<Cluster> &cluster)>;
+
  private:
   Status PartitionImpl();
   // Collect nodes that satisfy the unknown-shape rules:
@@ -138,7 +140,7 @@ class DynamicShapePartitioner {
   // Merge clusters step3
   void MergeClustersInputData();
   // Topological sort clusters after merge unknown shape clusters.
-  Status TopologicalSortClusters();
+  Status TopologicalSortClusters(const OrderedFilter &ordered_filter);
   // Deduplicate merged clusters
   void PruneUniqueClusters();
   // Establish the input-output anchors for each partition of the cluster and record links to other clusters
@@ -161,7 +163,7 @@ class DynamicShapePartitioner {
   ge::ComputeGraphPtr root_graph_;  // The original graph to partition
   std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_;  // Record nodes and the cluster it belongs to
   // V1 control flow cluster, need merge to one Graph.
-  std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
+  std::map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_;
   // topological sorted clusters, this field will change with the splitting.
   // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters
   // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters


ge/graph/partition/graph_partition.cc  (+14 -7)

@@ -179,6 +179,7 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr
     GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret);
   }
   GE_CHECK_NOTNULL(original_compute_graph);
+  output_merged_compute_graph->SetName(original_compute_graph->GetName());
   // partition sub graph
   for (const auto &sub_graph : original_compute_graph->GetAllSubgraphs()) {
     ComputeGraphPtr merged_sub_graph = nullptr;
@@ -188,8 +189,16 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr
       GELOGE(ret, "[Merge][SubGraph] Failed, ret:%d", ret);
       continue;
     }
+    // this means the subgraph was added during subgraph optimization and has no partitions, so just add it to the root graph
+    if (merged_sub_graph == sub_graph) {
+      GELOGI("Just add subgraph %s (parent node is %s) to root graph %s.", sub_graph->GetName().c_str(),
+             sub_graph->GetParentNode()->GetName().c_str(), output_merged_compute_graph->GetName().c_str());
+      sub_graph->SetParentGraph(sub_graph->GetParentNode()->GetOwnerComputeGraph());
+      GE_IF_BOOL_EXEC(output_merged_compute_graph->AddSubgraph(sub_graph->GetName(), merged_sub_graph) != SUCCESS,
+                      return FAILED;)
+      continue;
+    }
     // add sub graph
-    output_merged_compute_graph->SetName(original_compute_graph->GetName());
     merged_sub_graph->SetName(sub_graph->GetName());
     merged_sub_graph->SetInputSize(sub_graph->GetInputSize());
     merged_sub_graph->SetOutputSize(sub_graph->GetOutputSize());
@@ -245,12 +254,9 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co
   }
   if ((graph_2_graph_partition_info_.find(original_compute_graph) == graph_2_graph_partition_info_.end()) ||
       (graph_2_subgraph_list_.find(original_compute_graph) == graph_2_subgraph_list_.end())) {
-    REPORT_INNER_ERROR("E19999", "original_compute_graph:%s is not find in graph_2_graph_partition_info_.",
-                       original_compute_graph->GetName().c_str());
-    GELOGE(GE_GRAPH_NULL_INPUT,
-           "[Check][Param] original_compute_graph:%s is not find in graph_2_graph_partition_info_.",
-           original_compute_graph->GetName().c_str());
-    return FAILED;
+    GELOGW("[GraphPartition]: compute_graph was not found, just return the original.");
+    output_merged_compute_graph = original_compute_graph;
+    return SUCCESS;
   }
   GraphPartitionInfo &subgraph_info = graph_2_graph_partition_info_[original_compute_graph];
   const auto &sub_graph_list = graph_2_subgraph_list_[original_compute_graph];
@@ -708,6 +714,7 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr
   }
   auto &engine_name = graph_info_.partitions_.at(sub_graph);
   (void)AttrUtils::SetStr(sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName());
+  (void)sub_graph->SetExtAttr("part_src_graph", compute_graph);
   GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(),
          compute_graph->GetName().c_str());
   GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]);
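The new "part_src_graph" ext attr records which graph a partition was cut from. Unlike AttrUtils attrs, ext attrs hold arbitrary in-memory C++ objects and are not serialized with the model; a later stage would read it back roughly like this (a sketch only, with the TryGetExtAttr accessor and its default-value behaviour assumed from GE's AttrHolder interface):

// Hedged sketch: recover the source graph recorded by AddPartitionsToGraphNode.
ComputeGraphPtr src_graph = sub_graph->TryGetExtAttr("part_src_graph", ComputeGraphPtr());
if (src_graph != nullptr) {
  GELOGD("Subgraph %s was partitioned from %s", sub_graph->GetName().c_str(),
         src_graph->GetName().c_str());
}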


ge/graph/passes/mark_force_unknown_for_cond_pass.cc  (+63 -72)

@@ -16,8 +16,6 @@
 
 #include "mark_force_unknown_for_cond_pass.h"
 
-#include <queue>
-
 #include "graph/utils/node_utils.h"
 #include "graph/common/omg_util.h"
 
@@ -26,17 +24,7 @@ namespace {
 inline bool IsMergeInLoop(const NodePtr &node) {
   const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION };
 
-  std::string node_type;
-  (void)GetOriginalType(node, node_type);
-  return kLoopMergeInputs.count(node_type) > 0;
-}
-
-inline bool IsSwitchInLoop(const NodePtr &node) {
-  const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND };
-
-  std::string node_type;
-  (void)GetOriginalType(node, node_type);
-  return kLoopSwitchInputs.count(node_type) > 0;
+  return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0;
 }
 }
 
@@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
   GELOGD("MarkForceUnknownForCondPass Enter");
   std::map<NodePtr, std::vector<NodePtr>> switch_groups;
   for (const auto &node : graph->GetDirectNode()) {
-    std::string node_type;
-    GE_CHK_STATUS_RET(GetOriginalType(node, node_type),
-                      "[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str());
-    if (kMergeOpTypes.count(node_type) == 0) {
+    if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) {
       continue;
     }
 
@@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
 }
 
 ///
+/// @brief Deal with Switch node for LoopCond
+/// @param [in] Switch node
+/// @param [in] dest span
+/// @param [out] Search queue
+/// @return true: Switch in while loop / false: not in while loop.
+///
+bool MarkForceUnknownForCondPass::DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span,
+                                                   std::queue<std::pair<NodePtr, uint32_t>> &search_queue) {
+  /// LoopCond --->\.
+  ///               \.
+  /// Enter-----------+ \.
+  ///                  +--> Merge --> Switch --> Exit
+  /// NextIteration---+
+  const auto is_loop_op = [](const NodePtr &n) {
+    return NodeUtils::GetNodeType(n) == LOOPCOND;
+  };
+  const auto is_exit_op = [](const NodePtr &n) {
+    return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0;
+  };
+
+  const auto src_nodes = node->GetInAllNodes();
+  const auto dst_nodes = node->GetOutAllNodes();
+  if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) &&
+      std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) {
+    return false;
+  }
+
+  for (const auto &m : src_nodes) {
+    if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) {
+      for (const auto &n : m->GetInAllNodes()) {
+        if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) {
+          continue;
+        }
+
+        search_queue.push({n, dst_span});
+        GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(),
+               n->GetName().c_str(), dst_span);
+      }
+    }
+  }
+
+  return true;
+}
+
+///
 /// @brief Mark force unknown shape for Switch node
 /// @param [in] merge node
 /// @param [out] switch group
@@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
 ///
 void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) {
   // Switch --> {Switch --> Merge} --> Merge
+  GELOGD("Search Switch node for Merge: %s", node->GetName().c_str());
   std::unordered_set<NodePtr> nodes_seen;
   std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}});
   while (!search_queue.empty()) {
@@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
     const auto dst_span = search_queue.front().second;
     search_queue.pop();
 
-    // Switch --> Identity --> Constant
-    for (const auto &in_node : dst_node->GetInControlNodes()) {
-      if (nodes_seen.count(in_node) > 0) {
-        GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str());
-        continue;
-      }
-      nodes_seen.insert(in_node);
-
-      if (in_node->GetType() == IDENTITY) {
-        GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(),
-               in_node->GetName().c_str(), dst_span);
-        search_queue.push({in_node, dst_span});
-      }
-    }
-
-    for (const auto &in_node : dst_node->GetInDataNodes()) {
+    for (const auto &in_node : dst_node->GetInAllNodes()) {
       if (nodes_seen.count(in_node) > 0) {
         GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str());
        continue;
       }
       nodes_seen.insert(in_node);
 
-      std::string node_type;
-      (void)GetOriginalType(in_node, node_type);
+      const std::string node_type = NodeUtils::GetNodeType(in_node);
       GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(),
              in_node->GetName().c_str(), dst_span);
       if (kSwitchOpTypes.count(node_type) > 0) {  // Switch input node.
+        if (DealAsLoopSwitch(in_node, dst_span, search_queue)) {
+          continue;
+        }
+
         if (dst_span > 0) {
           search_queue.push({in_node, dst_span - 1});
         } else {
-          const auto &all_in_nodes = in_node->GetInDataNodes();
-          if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) {
-            GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(),
-                   in_node->GetName().c_str());
-          } else {
-            switch_group.emplace_back(in_node);
-          }
+          switch_group.emplace_back(in_node);
         }
       } else if (kMergeOpTypes.count(node_type) > 0) {  // Merge input node.
         search_queue.push({in_node, dst_span + 1});
@@ -132,39 +145,17 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
 /// @return
 ///
 void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) {
-  std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
-    return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP);
-  };
-
-  for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) {
-    const auto &op_node1 = it1->first;
-    const auto &op_desc1 = op_node1->GetOpDesc();
-    if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
+  for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) {
+    const auto &op_node = it->first;
+    const auto &op_desc = op_node->GetOpDesc();
+    if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
       continue;
     }
 
-    if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) {
-      int64_t group_index = op_desc1->GetId();
-      GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index);
-      MarkForceUnknownShape(op_node1, true, group_index);
-      for (const auto &n : it1->second) {
-        MarkForceUnknownShape(n, true, group_index);
-      }
-
-      for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) {
-        const auto &op_node2 = it2->first;
-        const auto &op_desc2 = op_node2->GetOpDesc();
-        if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
-          continue;
-        }
-
-        if (std::any_of(it2->second.begin(), it2->second.end(), callback)) {
-          MarkForceUnknownShape(op_node2, true, group_index);
-          for (const auto &n : it2->second) {
-            MarkForceUnknownShape(n, true, group_index);
-          }
-        }
-      }
+    int64_t group_index = op_desc->GetId();
+    SetControlFlowGroup(op_node, group_index);
+    for (const auto &n : it->second) {
+      SetControlFlowGroup(n, group_index);
     }
   }
 }
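The rewritten search walks upstream from a Merge with a span counter: every nested Merge passed raises the span, and a Switch met at span > 0 merely closes that inner cond, so only Switch nodes reached at span == 0 are grouped with the outer Merge. A toy illustration of just the counter logic (the visited-set bookkeeping and the while-loop handling above are omitted):

#include <cstdint>
#include <queue>
#include <string>
#include <utility>
#include <vector>

struct ToyNode {
  std::string type;               // "Switch", "Merge", or anything else
  std::vector<ToyNode *> inputs;  // upstream nodes
};

std::vector<ToyNode *> CollectSwitches(ToyNode *merge) {
  std::vector<ToyNode *> group;
  std::queue<std::pair<ToyNode *, uint32_t>> q({{merge, 0u}});
  while (!q.empty()) {
    auto [node, span] = q.front();
    q.pop();
    for (ToyNode *in : node->inputs) {
      if (in->type == "Switch") {
        if (span > 0) {
          q.push({in, span - 1});  // closes a nested Merge/Switch pair
        } else {
          group.push_back(in);     // a Switch of the outer cond
        }
      } else if (in->type == "Merge") {
        q.push({in, span + 1});    // entering a nested cond
      } else {
        q.push({in, span});        // pass-through op
      }
    }
  }
  return group;
}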


ge/graph/passes/mark_force_unknown_for_cond_pass.h  (+11 -0)

@@ -19,6 +19,8 @@
 
 #include "inc/graph_pass.h"
 
+#include <queue>
+
 namespace ge {
 class MarkForceUnknownForCondPass : public GraphPass {
  public:
@@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass {
 
  private:
   ///
+  /// @brief Deal with Switch node for LoopCond
+  /// @param [in] Switch node
+  /// @param [in] dest span
+  /// @param [out] Search queue
+  /// @return true: Switch in while loop / false: not in while loop.
+  ///
+  bool DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue<std::pair<NodePtr, uint32_t>> &search_queue);
+
+  ///
   /// @brief Mark force unknown shape for Switch node
   /// @param [in] merge node
   /// @param [out] switch group


ge/graph/passes/mark_graph_unknown_status_pass.cc  (+6 -0)

@@ -40,6 +40,12 @@ Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) {
     }
   }
 
+  const auto &node = graph->GetParentNode();
+  if (!is_unknown_shape && node != nullptr && node->GetType() == PARTITIONEDCALL) {
+    GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape),
+                            "[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str());
+  }
+
   for (const auto &node : graph->GetDirectNode()) {
     GELOGD("Set OwnerGraphIsUnknown attr to node[%s]", node->GetName().c_str());
     (void)AttrUtils::SetBool(node->GetOpDesc(), kOwnerGraphIsUnknown, is_unknown_shape);


ge/graph/passes/merge_to_stream_merge_pass.cc  (+2 -3)

@@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
                   REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
                   return FAILED, "[Check][Param] Param of pre node is nullptr.");
   int64_t group_index = -1;
-  bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
-  MarkForceUnknownShape(node, force_unknown, group_index);
+  (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
   for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
     OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
     GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
@@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
       GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str());
       return FAILED;
     }
-    MarkForceUnknownShape(active_node, force_unknown, group_index);
+    SetControlFlowGroup(active_node, group_index);
   }
 
   return SUCCESS;
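This is the same MarkForceUnknownShape to SetControlFlowGroup swap that recurs across the passes in this change set: the separate force_unknown bool disappears because an absent or invalid group index already expresses "no group". A hedged sketch of what the helper presumably does (the real one lives with the other control-flow utilities; its exact body is assumed here, including the guard suggested by the commit "Fix Set Control flow group for -1"):

// Sketch only, not GE's actual implementation.
void SetControlFlowGroupSketch(const ge::NodePtr &node, int64_t group_index) {
  const auto &op_desc = node->GetOpDesc();
  if (op_desc == nullptr || group_index < 0) {
    return;  // -1 means no control-flow group; record nothing
  }
  (void)ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
}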


ge/graph/passes/next_iteration_pass.cc  (+23 -6)

@@ -24,7 +24,9 @@ using std::string;
 
 namespace ge {
 namespace {
-const int64_t kLoopType = 1;
+constexpr int64_t kLoopType = 1;
+constexpr uint8_t kMaxTransOp = 3;
+constexpr uint8_t kTransOpIoSize = 1;
 }
 
 Status NextIterationPass::Run(ComputeGraphPtr graph) {
@@ -284,13 +286,28 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
 /// @return void
 ///
 void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) {
-  std::string node_type;
   for (const auto &switch_node : loop_group.switch_nodes) {
     SetControlFlowGroup(switch_node, group_index);
-    for (const auto &node : switch_node->GetOutDataNodes()) {
-      std::string node_type;
-      (void)GetOriginalType(node, node_type);
-      if (kExitOpTypes.count(node_type) > 0) {
-        SetControlFlowGroup(node, group_index);
-      }
+    for (auto node : switch_node->GetOutDataNodes()) {
+      // Switch --> Exit
+      // Switch --> Cast --> Exit
+      // Switch --> TransData --> Cast --> Exit
+      for (uint8_t i = 0; i < kMaxTransOp; ++i) {
+        if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) {
+          break;
+        }
+
+        if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
+          SetControlFlowGroup(node, group_index);
+          break;
+        }
+
+        const auto &all_nodes = node->GetOutAllNodes();
+        if (all_nodes.size() != kTransOpIoSize) {
+          break;
+        }
+        node = all_nodes.at(0);
+      }
     }
   }
 }
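HandleSwitchExitNodes now tolerates up to kMaxTransOp single-input/single-output trans ops (Cast, TransData) between Switch and Exit, matching the commit message "DSP: Switch -> TransData -> Cast -> Exit". The generic shape of that bounded walk, on a toy node type (illustrative only; GE checks anchor sizes rather than plain vectors):

#include <string>
#include <vector>

struct ToyNode {
  std::string type;
  std::vector<ToyNode *> inputs;
  std::vector<ToyNode *> outputs;
};

// Hop through at most max_hops pass-through ops looking for an Exit node.
ToyNode *FindExitThroughTransOps(ToyNode *node, int max_hops) {
  for (int i = 0; i < max_hops && node != nullptr; ++i) {
    if (node->type == "Exit") {
      return node;
    }
    if (node->inputs.size() != 1 || node->outputs.size() != 1) {
      return nullptr;  // fan-in or fan-out: not a simple trans op, stop
    }
    node = node->outputs.front();  // follow the single consumer
  }
  return nullptr;
}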


ge/graph/passes/parallel_group_pass.cc  (+39 -19)

@@ -15,7 +15,7 @@
  */
 
 #include "graph/passes/parallel_group_pass.h"
+#include <queue>
 #include "framework/common/debug/ge_log.h"
 #include "common/ge/ge_util.h"
 #include "framework/common/ge_inner_error_codes.h"
@@ -299,24 +299,19 @@ Status ParallelGroupPass::ReplaceWithSwitchAndMerge(NodePtr pre_node, NodePtr cu
       for (const auto &switch_node : cur_itr->second.first) {
         int64_t pre_id = pre_node->GetOpDesc()->GetId();
         int64_t switch_id = switch_node->GetOpDesc()->GetId();
-        // avoid ring
-        if (pre_id > switch_id) {
-          auto merge_node = cur_itr->second.second;
-          if (AddCtrlEdge(merge_node, pre_node) != SUCCESS) {
-            GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
-                   pre_node->GetName().c_str(), switch_node->GetName().c_str());
-            REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
-                              pre_node->GetName().c_str(), switch_node->GetName().c_str());
-            return FAILED;
-          }
-        } else {
-          if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
-            GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
-                   pre_node->GetName().c_str(), switch_node->GetName().c_str());
-            REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
-                              pre_node->GetName().c_str(), switch_node->GetName().c_str());
-            return FAILED;
-          }
+        NodePtr first_node = pre_node;
+        NodePtr second_node = switch_node;
+        if (pre_id > switch_id && IsIndirectConnect(switch_node, pre_node)) {
+          // avoid ring, merge->pre_node
+          first_node = cur_itr->second.second;
+          second_node = pre_node;
+        }
+        if (AddCtrlEdge(first_node, second_node) != SUCCESS) {
+          GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
+                 first_node->GetName().c_str(), second_node->GetName().c_str());
+          REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
+                            first_node->GetName().c_str(), second_node->GetName().c_str());
+          return FAILED;
         }
       }
     } else {
@@ -345,4 +340,29 @@ bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) {
   return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) &&
           stream_switch_type == kLoopType);
 }
+
+bool ParallelGroupPass::IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b) {
+  if (node_a == nullptr || node_b == nullptr) {
+    GELOGW("node_a or node_b is nullptr.");
+    return false;
+  }
+  int64_t end_id = node_b->GetOpDesc()->GetId();
+  std::queue<NodePtr> nodes;
+  nodes.push(node_a);
+  while (!nodes.empty()) {
+    NodePtr tmp_node = nodes.front();
+    nodes.pop();
+    if (tmp_node == nullptr || tmp_node->GetOpDesc() == nullptr ||
+        tmp_node->GetOpDesc()->GetId() > end_id) {
+      continue;
+    }
+    if (tmp_node == node_b) {
+      return true;
+    }
+    for (const auto &out_node : tmp_node->GetOutAllNodes()) {
+      nodes.push(out_node);
+    }
+  }
+  return false;
+}
 }  // namespace ge
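IsIndirectConnect leans on the fact that OpDesc ids are assigned in topological order: a visited node whose id already exceeds node_b's can only reach nodes with still larger ids, so the BFS discards it instead of expanding it, which is what bounds the search. The same check on a toy adjacency list (a visited set, omitted here as in the pass, would additionally avoid re-expanding shared successors):

#include <queue>
#include <vector>

// Nodes are labelled by their topological ids; out_edges[i] lists successors of i.
bool Reaches(const std::vector<std::vector<int>> &out_edges, int a, int b) {
  std::queue<int> q;
  q.push(a);
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    if (cur > b) continue;  // topological-id bound: b is no longer reachable from here
    if (cur == b) return true;
    for (int nxt : out_edges[cur]) q.push(nxt);
  }
  return false;
}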

ge/graph/passes/parallel_group_pass.h  (+1 -0)

@@ -48,6 +48,7 @@ class ParallelGroupPass : public GraphPass {
 
   bool IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc);
   bool IsWhileStreamSwitch(OpDescPtr switch_op_desc);
+  bool IsIndirectConnect(const NodePtr &node_a, const NodePtr &node_b);
 };
 }  // namespace ge
 #endif  // GE_GRAPH_PASSES_PARALLEL_GROUP_PASS_H

ge/graph/passes/replace_with_empty_const_pass.cc  (+20 -0)

@@ -21,7 +21,23 @@
 #include "framework/common/debug/ge_log.h"
 #include "framework/common/ge_inner_error_codes.h"
 #include "graph/utils/graph_utils.h"
+#include "graph/utils/node_utils.h"
 
+namespace {
+const std::unordered_set<std::string> kControlFlowOps = {
+  ge::SWITCH,
+  ge::REFSWITCH,
+  ge::MERGE,
+  ge::REFMERGE,
+  ge::ENTER,
+  ge::REFENTER,
+  ge::NEXTITERATION,
+  ge::REFNEXTITERATION,
+  ge::EXIT,
+  ge::REFEXIT,
+  ge::LOOPCOND
+};
+}
 namespace ge {
 Status ReplaceWithEmptyConstPass::Run(NodePtr &node) {
   GELOGD("ReplaceWithEmptyConstPass in.");
@@ -39,6 +55,10 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) {
     GELOGI("Node %s is const. Ignore current pass.", node->GetName().c_str());
     return SUCCESS;
   }
+  if (kControlFlowOps.count(NodeUtils::GetNodeType(node)) != 0) {
+    GELOGI("Node %s is control flow op. Ignore current pass.", node->GetName().c_str());
+    return SUCCESS;
+  }
   // Node like no op, it has no output
   if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) {
     GELOGI("Node %s has no output desc. Ignore current pass.", node->GetName().c_str());


ge/graph/passes/switch_to_stream_switch_pass.cc  (+9 -8)

@@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &
          peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str());
 
   int64_t group_index = -1;
-  bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
-  MarkForceUnknownShape(stream_switch, force_unknown, group_index);
+  if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
+    SetControlFlowGroup(stream_switch, group_index);
+  }
   return stream_switch;
 }
 
@@ -491,8 +492,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
 Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) {
   for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) {
     for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) {
-      std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
-      std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
+      const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
+      const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
       std::set<NodePtr> same_cond_switch;
       same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end());
       same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end());
@@ -524,13 +525,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
       std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) {
        return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
       };
-      bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
-      MarkForceUnknownShape(active_node, is_unknown_shape, group_index);
+      (void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
+      SetControlFlowGroup(active_node, group_index);
 
       const std::string &cond_group = cond_node->GetName();
       for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
         bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
-        std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
+        const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
         GE_IF_BOOL_EXEC(switch_list.empty(), continue);
 
         // select first stream_switch
@@ -559,7 +560,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
                           "[Add][Edge] between %s and %s failed.",
                           cast_node->GetName().c_str(), stream_switch->GetName().c_str());
 
-        MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index);
+        SetControlFlowGroup(stream_switch, group_index);
         for (const NodePtr &node : switch_list) {
           GE_IF_BOOL_EXEC(node != stream_switch, {
             GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),


ge/graph/preprocess/graph_preprocess.cc  (+66 -53)

@@ -1420,9 +1420,10 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) {
   return SUCCESS;
 }
 
-Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag) {
+Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc) {
   auto format = desc.GetFormat();
   auto origin_format = desc.GetOriginFormat();
+  auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
   bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op) && (!tune_flag);
   if (need_check_internal_format) {
     bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format);
@@ -1439,6 +1440,63 @@ Status GraphPrepare::CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag) {
   return SUCCESS;
 }
 
+Status GraphPrepare::UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc) {
+  auto data_type = desc.GetDataType();
+  uint32_t length = 1;
+  bool type_ret = TypeUtils::GetDataTypeLength(data_type, length);
+  if (!type_ret) {
+    std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" +
+                         std::to_string(index) + " input tensor is not supported";
+    REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
+    GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not supported.",
+           TypeUtils::DataTypeToSerialString(data_type).c_str());
+    return FAILED;
+  }
+  int64_t desc_shape = desc.GetShape().GetShapeSize();
+  FMK_INT64_UINT32_MULCHECK(desc_shape, length);
+  int64_t shape_size = desc_shape * length;
+  GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast<int64_t>(length));
+  int64_t size = 0;
+  GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS,
+                  REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index);
+                  GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index);
+                  return FAILED);
+  bool size_check = (size != 0 && shape_size != size);
+  if (size_check) {
+    std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) +
+                         "] != shape_size[" + std::to_string(shape_size) + "], check invalid";
+    REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
+    GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size);
+    return FAILED;
+  }
+  ge::TensorUtils::SetSize(desc, shape_size);
+
+  auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
+  if (!tune_flag) {
+    graphStatus graph_ret = op->UpdateInputDesc(0, desc);
+    if (graph_ret != GRAPH_SUCCESS) {
+      REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0",
+                        op->GetName().c_str(), op->GetType().c_str());
+      GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0",
+             op->GetName().c_str(), op->GetType().c_str());
+      return graph_ret;
+    }
+    // Size will be recalculated in the build stage
+    ge::TensorUtils::SetSize(desc, 0);
+    graph_ret = op->UpdateOutputDesc(0, desc);
+    if (graph_ret != GRAPH_SUCCESS) {
+      REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0",
+                        op->GetName().c_str(), op->GetType().c_str());
+      GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0",
+             op->GetName().c_str(), op->GetType().c_str());
+      return graph_ret;
+    }
+  } else {
+    GELOGI("data %s skip update info in tune mode", op->GetName().c_str());
+  }
+
+  return SUCCESS;
+}
+
 Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input,
                                  const std::map<string, string> &graph_option) {
   // Get shape range of input in dynamic_execute mode
@@ -1471,63 +1529,18 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input,
     }
     GeTensorDesc desc(user_input[index].GetTensorDesc());
     // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM.
-    auto tune_flag = (options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_BUILDER);
-    ret = CheckInternalFormat(input_node, desc, tune_flag);
+    ret = CheckInternalFormat(input_node, desc);
     if (ret != SUCCESS) {
       GELOGE(INTERNAL_ERROR, "[Check][InternalFormat] on %s failed", op->GetName().c_str());
       return ret;
     }
-    auto data_type = desc.GetDataType();
-    uint32_t length = 1;
-    bool type_ret = TypeUtils::GetDataTypeLength(data_type, length);
-    if (!type_ret) {
-      std::string reason = "Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "] of index:" +
-                           std::to_string(index) + " input tensor is not support";
-      REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
-      GELOGE(PARAM_INVALID, "[Check][Param] Input datatype %s is not support.",
-             TypeUtils::DataTypeToSerialString(data_type).c_str());
-      return FAILED;
-    }
-    int64_t desc_shape = desc.GetShape().GetShapeSize();
-    FMK_INT64_UINT32_MULCHECK(desc_shape, length);
-    int64_t shape_size = desc_shape * length;
-    GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast<int64_t>(length));
-    int64_t size = 0;
-    GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS,
-                    REPORT_CALL_ERROR("E19999", "Get size of user input tensor failed, index:%ld", index);
-                    GELOGE(INTERNAL_ERROR, "[Get][Size] of user input tensor failed, index:%ld", index);
-                    return FAILED);
-    bool size_check = (size != 0 && shape_size != size);
-    if (size_check) {
-      std::string reason = "input tensor[index:" + std::to_string(index) + "]'s data size[" + std::to_string(size) +
-                           "] != shape_size[" + std::to_string(size) + "], check invalid";
-      REPORT_INPUT_ERROR("E19025", std::vector<std::string>({"reason"}), std::vector<std::string>({reason}));
-      GELOGE(PARAM_INVALID, "[Check][Param] input data size = %ld, shape_size = %ld.", size, shape_size);
-      return FAILED;
-    }
-    ge::TensorUtils::SetSize(desc, shape_size);
-    if (!tune_flag) {
-      graphStatus graph_ret = op->UpdateInputDesc(0, desc);
-      if (graph_ret != GRAPH_SUCCESS) {
-        REPORT_CALL_ERROR("E19999", "Update input desc of op:%s(%s) failed, index:0",
-                          op->GetName().c_str(), op->GetType().c_str());
-        GELOGE(graph_ret, "[Update][InputDesc] of op:%s(%s) failed, index:0",
-               op->GetName().c_str(), op->GetType().c_str());
-        return graph_ret;
-      }
-      // Size will be recalculated in the build stage
-      ge::TensorUtils::SetSize(desc, 0);
-      graph_ret = op->UpdateOutputDesc(0, desc);
-      if (graph_ret != GRAPH_SUCCESS) {
-        REPORT_CALL_ERROR("E19999", "Update output desc of op:%s(%s) failed, index:0",
-                          op->GetName().c_str(), op->GetType().c_str());
-        GELOGE(graph_ret, "[Update][OutputDesc] of op:%s(%s) failed, index:0",
-               op->GetName().c_str(), op->GetType().c_str());
-        return graph_ret;
-      }
-    } else {
-      GELOGI("data %s skip update info in tune mode", op->GetName().c_str());
+    ret = UpdateDataInputOutputDesc(index, op, desc);
+    if (ret != SUCCESS) {
+      GELOGE(FAILED, "[Update][DataInputOutputDesc] on %s failed", op->GetName().c_str());
+      return ret;
     }
 
     if (!dynamic_shape_range_vec.empty()) {
       ret = UpdateDynamicInputShapeRange(index, dynamic_shape_range_vec, op, desc);
       GE_CHK_STATUS_RET(ret, "[Update][DynamicInputShapeRange] on %s failed.", op->GetName().c_str());
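The size validation that moved into UpdateDataInputOutputDesc multiplies the element count by the element width and special-cases rank-0 scalars, whose GetShapeSize() reports 0 even though they occupy one element; FMK_INT64_UINT32_MULCHECK guards the multiplication against overflow. A worked toy version (names illustrative):

#include <cstdint>
#include <cstddef>

int64_t ExpectedBytes(int64_t element_count, uint32_t bytes_per_element, size_t dim_num) {
  int64_t bytes = element_count * static_cast<int64_t>(bytes_per_element);
  if (bytes == 0 && dim_num == 0) {
    bytes = static_cast<int64_t>(bytes_per_element);  // rank-0 scalar: one element
  }
  return bytes;  // a caller-provided size of 0 means "unset" and is accepted
}

// e.g. a float32 tensor of shape [2, 3]: ExpectedBytes(6, 4, 2) == 24;
// a float32 scalar (shape [], element count reported as 0): ExpectedBytes(0, 4, 0) == 4.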


ge/graph/preprocess/graph_preprocess.h  (+2 -1)

@@ -63,7 +63,8 @@ class GraphPrepare {
   Status CheckRefOp();
   Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode);
   Status AdjustDataOpOutput(const NodePtr &node);
-  Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc, bool tune_flag);
+  Status CheckInternalFormat(const NodePtr &input_node, const GeTensorDesc &desc);
+  Status UpdateDataInputOutputDesc(GeAttrValue::INT index, OpDescPtr &op, GeTensorDesc &desc);
   Status UpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option);
   Status CheckAndUpdateInput(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option);
   Status CheckConstOp();


ge/graph/preprocess/insert_op/util_insert_aipp_op.cc  (+1 -0)

@@ -568,6 +568,7 @@ Status InsertNewOpUtil::GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std:
   }
 
   std::unique_ptr<domi::AippOpParams> aipp_params(new (std::nothrow) domi::AippOpParams());
+  GE_CHECK_NOTNULL(aipp_params);
   ge::GeAttrValue::NAMED_ATTRS aipp_attr;
   GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(data_op, ATTR_NAME_AIPP, aipp_attr), ACL_ERROR_GE_AIPP_NOT_EXIST,
                          "[Get][Attr] %s from op:%s failed", ATTR_NAME_AIPP.c_str(), data_op->GetName().c_str());


ge/graph/preprocess/multi_batch_copy_graph.cc  (+1 -1)

@@ -1206,7 +1206,7 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_
     auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims();
     if (!IsAllDimsPositive(dims)) {
       REPORT_CALL_ERROR("E19999", "Failed to copy multi batch graph, the node %s still has unknown shape %s",
-                       node->GetName().c_str(), formats::ShapeToString(dims).c_str());
+                        node->GetName().c_str(), formats::ShapeToString(dims).c_str());
       GELOGE(INTERNAL_ERROR, "[Check][Param] Failed to copy multi batch graph, the node %s still has unknown shape %s",
              node->GetName().c_str(), formats::ShapeToString(dims).c_str());
       return INTERNAL_ERROR;


ge/host_kernels/fill_kernel.cc  (+8 -0)

@@ -45,6 +45,7 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge
     GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr.");
     return PARAM_INVALID;
   }
+  GELOGD("FillKernel in, name: %s.", op_desc_ptr->GetName().c_str());
 
   GE_CHECK_NOTNULL(input.at(kFillDimsInputIndex));
   GE_CHECK_NOTNULL(input.at(kFillDataInputIndex));
@@ -57,6 +58,13 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge
     return NOT_CHANGED;
   }
 
+  auto output_desc = op_desc_ptr->GetOutputDescPtr(0);
+  GE_CHECK_NOTNULL(output_desc);
+  if (output_desc->GetShape().IsUnknownShape()) {
+    GELOGD("Output is unknown shape, [%s] skip FillKernel.", op_desc_ptr->GetName().c_str());
+    return NOT_CHANGED;
+  }
+
   GeTensorPtr output_ptr;
   output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
   if (output_ptr == nullptr) {
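FillKernel is a compile-time (host) folding kernel, and NOT_CHANGED is its way of declining: with an unknown output shape there is no concrete size to allocate, so the node must be left for runtime execution. The guard pattern in isolation, as toy code:

#include <cstdint>
#include <vector>

enum FoldStatus { FOLD_OK, DECLINED };  // DECLINED plays the role of NOT_CHANGED

bool IsUnknownShape(const std::vector<int64_t> &dims) {
  for (int64_t d : dims) {
    if (d < 0) return true;  // -1 (or -2) dims mean unknown at compile time
  }
  return false;
}

FoldStatus TryFoldFill(const std::vector<int64_t> &out_dims) {
  if (IsUnknownShape(out_dims)) {
    return DECLINED;  // leave the node for runtime
  }
  // ... allocate the output tensor and fill it here ...
  return FOLD_OK;
}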


ge/hybrid/executor/hybrid_model_async_executor.cc  (+5 -3)

@@ -297,13 +297,15 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData &current_data, Hy
         }
       }
       tensor_desc->SetShape(shape);
-      args.input_desc[input_index] = tensor_desc;
-      GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str());
+      GELOGD("Update shape[%s] of input[%zu] to [%s]",
+             shape.ToString().c_str(), input_index, tensor_desc->MutableShape().ToString().c_str());
       GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size),
                               "[Invoke][GetTensorMemorySizeInBytes]Failed to calc tensor size,"
                               "index = %zu, shape = [%s], model_id = %u.",
                               input_index, tensor_desc->GetShape().ToString().c_str(), model_id_);
-      GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size);
+      GELOGD("Input tensor[%zu] size = %ld", input_index, tensor_size);
+      TensorUtils::SetSize(*tensor_desc, tensor_size);
+      args.input_desc[input_index] = tensor_desc;
     }
 
     GE_CHECK_GE(tensor_size, 0);


ge/hybrid/executor/hybrid_model_executor.cc  (+8 -7)

@@ -41,6 +41,8 @@ HybridModelExecutor::~HybridModelExecutor() {
 Status HybridModelExecutor::Init() {
   GELOGD("Start to init HybridGraphEngine.");
   GE_CHK_STATUS_RET_NOLOG(InitExecutionContext());
+  root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_));
+  GE_CHECK_NOTNULL(root_graph_executor_);
   GELOGD("HybridGraphEngine initialized successfully.");
   return SUCCESS;
 }
@@ -60,8 +62,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
     GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration,
                                 sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream));
   }
-  SubgraphExecutor executor(model_->GetRootGraphItem(), &context_);
-  auto ret = ExecuteGraphInternal(executor, args);
+  auto ret = ExecuteGraphInternal(args);
   Cleanup();
   RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End");
   GELOGD("Model executed successfully.");
@@ -69,6 +70,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
     context_.profiler->Dump(std::cout);
     context_.profiler->Reset();
   }
+  root_graph_executor_->ReleaseContext();
 
   context_.iteration += 1;
   if (ret == END_OF_SEQUENCE) {
@@ -79,8 +81,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
   return SUCCESS;
 }
 
-Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
-                                                 HybridModelExecutor::ExecuteArgs &args) {
+Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) {
   RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start");
   GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_));
   RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End");
@@ -94,7 +95,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
     GE_CHK_STATUS_RET_NOLOG(prof_mgr.ProfileStepInfo(index_id, model_id, 0, stream_, device_id));
   }
 
-  HYBRID_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc, args.outputs),
+  HYBRID_CHK_STATUS_RET(root_graph_executor_->ExecuteAsync(args.inputs, args.input_desc, args.outputs),
                         "Failed to execute partitioned call.");
   RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End");
 
@@ -103,7 +104,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
   }
 
   if (!model_->IsSingleOp()) {
-    Status ret = executor.Synchronize();
+    Status ret = root_graph_executor_->Synchronize();
     if (ret != ge::SUCCESS) {
       auto model_manager = ModelManager::GetInstance();
       GE_CHECK_NOTNULL(model_manager);
@@ -123,7 +124,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor,
   }
 
   args.outputs.clear();
-  HYBRID_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs");
+  HYBRID_CHK_STATUS_RET(root_graph_executor_->GetOutputs(args.outputs, args.output_desc), "Failed to get outputs");
   RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End");
   return SUCCESS;
 }


ge/hybrid/executor/hybrid_model_executor.h  (+2 -1)

@@ -48,7 +48,7 @@ class HybridModelExecutor {
   Status Execute(ExecuteArgs &args);
 
  private:
-  Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args);
+  Status ExecuteGraphInternal(ExecuteArgs &args);
   Status Cleanup();
   Status InitExecutionContext();
   static Status ResetExecutionContext(GraphExecutionContext &context);
@@ -58,6 +58,7 @@ class HybridModelExecutor {
   uint32_t device_id_;
   rtStream_t stream_;
   GraphExecutionContext context_;
+  std::unique_ptr<SubgraphExecutor> root_graph_executor_;
 };
 }  // namespace hybrid
 }  // namespace ge
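Across the .cc and .h hunks the SubgraphExecutor moves from a stack local rebuilt on every Execute() to a unique_ptr member created once in Init(), with ReleaseContext() presumably dropping only per-run state between executions. The shape of that refactor, reduced to toy types:

#include <memory>
#include <new>

struct Worker {
  // imagine heavy setup in the constructor ...
  void ResetPerRunState() { /* cheap; called after every Execute */ }
  int Run() { return 0; }
};

class Executor {
 public:
  bool Init() {
    worker_.reset(new (std::nothrow) Worker());  // built once, not per Execute()
    return worker_ != nullptr;
  }
  int Execute() {
    int ret = worker_->Run();
    worker_->ResetPerRunState();  // analogous to root_graph_executor_->ReleaseContext()
    return ret;
  }
 private:
  std::unique_ptr<Worker> worker_;
};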


ge/hybrid/executor/node_state.cc (+79 -8)

@@ -19,8 +19,9 @@
 #include "framework/common/debug/log.h"
 #include "graph/compute_graph.h"
 #include "graph/utils/tensor_utils.h"
-#include "hybrid_execution_context.h"
-#include "subgraph_context.h"
+#include "hybrid/executor/hybrid_execution_context.h"
+#include "hybrid/executor/subgraph_context.h"
+#include "hybrid/node_executor/task_context.h"

 #define INC_ITERATION_COUNT(iteration) \
   do { \
@@ -260,6 +261,16 @@ NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_contex
   this->op_desc_ = node_item.node->GetOpDesc();
 }

+Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) {
+  GE_CHECK_NOTNULL(frame_state);
+  group_ = group;
+  frame_state_ = frame_state;
+  auto unique_task_context = TaskContext::Create(this, subgraph_context_);
+  GE_CHECK_NOTNULL(unique_task_context);
+  task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());
+  return SUCCESS;
+}
+
 Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const {
   if (node_item_->IsMergeOp()) {
     GELOGD("[%s] merge index %d, input nodes: %zu", GetName().c_str(), merge_index_, node_item_->data_recv_.size());
@@ -314,15 +325,75 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
   return task_context_;
 }

+void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) {
+  const auto is_persist_tensor = [](const std::map<const NodeItem *, std::set<int>> &items, int idx) {
+    const auto is_exist = [&idx](const std::pair<const NodeItem *, std::set<int>> &items) {
+      return items.second.count(idx) > 0;
+    };
+    return std::any_of(items.begin(), items.end(), is_exist);
+  };
+
+  if (root_tensor_values_.count(input_idx) > 0) {
+    return;
+  }
+
+  if (is_persist_tensor(node_item_->root_data_, input_idx)) {
+    GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx);
+    root_tensor_values_[input_idx] = tensor;
+  } else if (is_persist_tensor(node_item_->enter_data_, input_idx)) {
+    GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx);
+    root_tensor_values_[input_idx] = tensor;
+  }
+}
+
+void NodeState::UpdatePersistTensor() {
+  const auto update_tensor = [&](const std::map<const NodeItem *, std::set<int>> &items) {
+    for (const auto &item : items) {
+      for (const auto idx : item.second) {
+        UpdatePersistTensor(idx);
+      }
+    }
+  };
+
+  if (root_tensor_values_.empty()) {
+    return;
+  }
+
+  update_tensor(node_item_->root_data_);
+  if (iteration_count_ > 0) {
+    update_tensor(node_item_->enter_data_);
+  }
+}
+
+void NodeState::UpdatePersistTensor(int input_idx) {
+  const auto it = root_tensor_values_.find(input_idx);
+  if (it == root_tensor_values_.end()) {
+    GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx);
+    return;
+  }
+
+  auto tensor = task_context_->MutableInput(input_idx);
+  if (tensor == nullptr) {
+    GELOGW("[%s] Not found input tensor: %d", GetName().c_str(), input_idx);
+    return;
+  }
+
+  *tensor = it->second;
+  GELOGD("[%s] Update input tensor: %d", GetName().c_str(), input_idx);
+}
+
 void NodeState::ResetContext(uint64_t iteration) {
   switch_index_ = -1;
   subgraph_context_->ResetContext(node_item_->node);
-  if (iteration == 0) {
-    data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
-    ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
-  } else {
-    data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size() + node_item_->enter_data_.size());
-    ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size() + node_item_->enter_ctrl_.size());
+  auto unique_task_context = TaskContext::Create(this, subgraph_context_);
+  GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context);
+  task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());
+
+  data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
+  ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
+  if (iteration > 0) {
+    data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size());
+    ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size());
   }

   iteration_count_ = iteration;
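The SavePersistTensor/UpdatePersistTensor pair added above caches, per input index, the tensor first delivered over a Root or Enter edge, and re-injects the cached value into the node's inputs on later iterations. A simplified, self-contained model of that bookkeeping, with TensorValue and the NodeItem maps reduced to standard containers:

#include <iostream>
#include <map>
#include <set>
#include <vector>

using Tensor = int;  // stand-in for ge::hybrid::TensorValue

struct PersistState {
  std::set<int> persist_inputs;  // input indices fed from Root/Enter edges
  std::map<int, Tensor> saved;   // first value seen per persisted index

  void Save(int idx, const Tensor &t) {
    if (persist_inputs.count(idx) > 0) {
      saved.emplace(idx, t);  // emplace keeps the first value, like the count() guard above
    }
  }
  void Update(std::vector<Tensor> &inputs) {
    for (const auto &kv : saved) {
      inputs[kv.first] = kv.second;  // re-inject the cached tensor
    }
  }
};

int main() {
  PersistState state;
  state.persist_inputs = {0};
  std::vector<Tensor> inputs = {42, 7};
  state.Save(0, inputs[0]);   // cache the Root input on first delivery
  inputs[0] = 0;              // clobbered during an iteration
  state.Update(inputs);       // restored before the next PrepareTask
  std::cout << inputs[0] << "\n";  // prints 42
  return 0;
}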


ge/hybrid/executor/node_state.h (+12 -8)

@@ -100,6 +100,8 @@ struct NodeState {
   NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context);
   ~NodeState() = default;

+  Status Init(int group, const shared_ptr<FrameState> &frame_state);
+
   OpDesc *GetOpDesc() const {
     return op_desc_.get();
   }
@@ -129,6 +131,9 @@ struct NodeState {
   void RunStreamActive();
   void RunNextIteration();

+  void SavePersistTensor(int input_idx, const TensorValue &tensor);
+  void UpdatePersistTensor();
+
   Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;

   void SetScheduleFuture(std::future<Status> &&future);
@@ -150,18 +155,10 @@ struct NodeState {
     return merge_index_;
   }

-  void SetGroup(int group) {
-    group_ = group;
-  }
-
   int GetGroup() const {
     return group_;
   }

-  void SetFrameState(const shared_ptr<FrameState> &frame_state) {
-    frame_state_ = frame_state;
-  }
-
   const shared_ptr<NodeTask> &GetKernelTask() const {
     return kernel_task_;
   }
@@ -181,12 +178,17 @@ struct NodeState {
   void SetTaskContext(std::shared_ptr<TaskContext> &task_context);
   std::shared_ptr<TaskContext> GetTaskContext();

+  void SetSkipInferShape(bool skip_infershape) { skip_infershape_ = skip_infershape; }
+
+  bool MaySkipShapeInference() const { return skip_infershape_; }
+
  private:
   bool IsScheduleReady() const;
   void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
   void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
   void ResetContext(uint64_t iteration);
   void ScheduleContext(const NodeState &node_state);
+  void UpdatePersistTensor(int input_idx);

   const NodeItem *node_item_ = nullptr;
   std::shared_ptr<NodeTask> kernel_task_ = nullptr;
@@ -199,6 +201,7 @@ struct NodeState {

   std::future<Status> schedule_future_;
   std::shared_ptr<FrameState> frame_state_;
+  std::map<int, TensorValue> root_tensor_values_;
   uint64_t active_count_ = 0;
   uint64_t iteration_count_ = 0;
   uint32_t ctrl_scheduled_ = 0;
@@ -206,6 +209,7 @@ struct NodeState {
   int merge_index_ = -1;   // Use for Execute (Reset after Executed).
   int switch_index_ = -1;  // Use for Schedule (Reset after Prepared).
   int group_ = -1;
+  bool skip_infershape_ = false;
 };
 }  // namespace hybrid
 }  // namespace ge
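Replacing the SetGroup()/SetFrameState() setters with a single Status-returning Init() is a two-phase construction pattern: the constructor stays infallible, and every step that can fail (here, creating the TaskContext) is checked in one place. A minimal sketch under that reading, with Status and the member types reduced to stand-ins:

#include <memory>

using Status = int;
constexpr Status SUCCESS = 0;
constexpr Status INTERNAL_ERROR = 1;

struct FrameState {};
struct TaskContext {
  static std::unique_ptr<TaskContext> Create() { return std::unique_ptr<TaskContext>(new TaskContext()); }
};

struct NodeState {
  NodeState() = default;  // construction itself cannot fail
  Status Init(int group, const std::shared_ptr<FrameState> &frame_state) {
    if (frame_state == nullptr) {
      return INTERNAL_ERROR;  // mirrors GE_CHECK_NOTNULL(frame_state)
    }
    group_ = group;
    frame_state_ = frame_state;
    task_context_ = TaskContext::Create();  // the fallible step, checked here
    return task_context_ != nullptr ? SUCCESS : INTERNAL_ERROR;
  }

  int group_ = -1;
  std::shared_ptr<FrameState> frame_state_;
  std::shared_ptr<TaskContext> task_context_;
};

int main() {
  NodeState state;
  return state.Init(0, std::make_shared<FrameState>()) == SUCCESS ? 0 : 1;
}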


ge/hybrid/executor/subgraph_context.cc (+19 -8)

@@ -19,7 +19,7 @@

 namespace ge {
 namespace hybrid {
-SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context)
+SubgraphContext::SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context)
     : graph_item_(graph_item), execution_context_(execution_context) {
 }

@@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
     return nullptr;
   }

+  return CreateNodeState(node_item);
+}
+
+NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) {
   GELOGD("[%s] lock for write", node_item->NodeName().c_str());
   if (mmRWLockWRLock(&rw_lock_) != EN_OK) {
     REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str());
     GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str());
     return nullptr;
   }

   auto &node_state = node_states_[node_item];
-  if (node_state == nullptr) {
-    const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
-    node_state.reset(new(std::nothrow)NodeState(*node_item, this));
-    node_state->SetFrameState(GetOrCreateFrameState(*node_item));
-    node_state->SetGroup(group_);
-    (void)guard;
-  }
+  do {
+    if (node_state == nullptr) {
+      const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
+      node_state.reset(new(std::nothrow)NodeState(*node_item, this));
+      if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) {
+        GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str());
+        REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str());
+        break;
+      }
+      (void)guard;
+    }
+  } while (0);

   GELOGD("[%s] unlock for write", node_item->NodeName().c_str());
   if (mmWRLockUnLock(&rw_lock_) != EN_OK) {
     REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str());
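The do { ... } while (0) wrapper added above is a single-exit idiom: break on the error path jumps past the remaining work to the code after the loop, so the write-lock is always released without an early return or goto. A stripped-down sketch of the same control flow, where std::mutex stands in for mmRWLock_t:

#include <mutex>

std::mutex g_rw_lock;  // stand-in for the real mmRWLock_t

bool CreateGuarded(bool fail_mid_way) {
  g_rw_lock.lock();
  bool ok = false;
  do {
    if (fail_mid_way) {
      break;    // error path: jump straight to the shared unlock below
    }
    ok = true;  // the "create and initialize" work goes here
  } while (0);
  g_rw_lock.unlock();  // reached on both the success and failure paths
  return ok;
}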


ge/hybrid/executor/subgraph_context.h (+3 -2)

@@ -30,7 +30,7 @@ namespace ge {
 namespace hybrid {
 class SubgraphContext {
  public:
-  explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context);
+  explicit SubgraphContext(const GraphItem *graph_item, GraphExecutionContext *execution_context);
   ~SubgraphContext();

   Status Init();
@@ -51,10 +51,11 @@ class SubgraphContext {
   void NodeDone(const NodePtr &node);

  private:
+  NodeStatePtr CreateNodeState(const NodeItem *node_item);
   FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item);  // no lock
   friend class TaskContext;
   const GraphItem *graph_item_;
-  const GraphExecutionContext *execution_context_;
+  GraphExecutionContext *execution_context_;
   mmRWLock_t rw_lock_;
   std::vector<TensorValue> all_inputs_;
   std::vector<TensorValue> all_outputs_;


ge/hybrid/executor/subgraph_executor.cc (+11 -15)

@@ -103,6 +103,14 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue
       auto node_state = subgraph_context_->GetOrCreateNodeState(input_node);
       GE_CHECK_NOTNULL(node_state);
       node_state->GetShapeInferenceState().UpdateInputShape(0, *tensor_desc);
+      auto op_desc = input_node->GetOpDesc();
+      GE_CHECK_NOTNULL(op_desc);
+      auto output_desc = op_desc->MutableOutputDesc(kDataInputIndex);
+      GE_CHECK_NOTNULL(output_desc);
+      output_desc->SetShape(tensor_desc->GetShape());
+      output_desc->SetOriginShape(tensor_desc->GetOriginShape());
+      output_desc->SetDataType(tensor_desc->GetDataType());
+      node_state->SetSkipInferShape(true);
     }
   }

@@ -175,16 +183,12 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue
   GE_CHECK_NOTNULL(node_state);
   node_state->SetKernelTask(node_item->kernel_task);

-  known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
-  GE_CHECK_NOTNULL(known_shape_task_context_);
-  node_state->SetTaskContext(known_shape_task_context_);
-
   std::function<void()> callback;
   GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback));
-  HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_, callback),
+  HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, node_state->GetTaskContext(), *context_, callback),
                         "[%s] Failed to execute node [%s] for known subgraph.",
                         graph_item_->GetName().c_str(),
-                        known_shape_task_context_->GetNodeName());
+                        node_state->GetName().c_str());

   GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str());
   return SUCCESS;
@@ -271,16 +275,12 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) {
   } else {
     node_state->SetKernelTask(node_item.kernel_task);
   }
-  auto unique_task_context = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
-  GE_CHECK_NOTNULL(unique_task_context);
   const auto &task = node_state->GetKernelTask();
   if (task == nullptr) {
     GELOGE(INTERNAL_ERROR, "[Get][KernelTask] failed for[%s], NodeTask is null.", node_state->GetName().c_str());
     REPORT_CALL_ERROR("E19999", "GetKernelTask failed for %s, nodetask is null.", node_state->GetName().c_str());
     return INTERNAL_ERROR;
   }
-  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
-  node_state->SetTaskContext(shared_task_context);
   GE_CHK_STATUS_RET_NOLOG(NodeEnqueue(p_node_state));
   return AfterPrepared(p_node_state);
 }
@@ -480,19 +480,15 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta
   } else {
     node_state.SetKernelTask(node_item.kernel_task);
   }
-  auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get());
-  GE_CHECK_NOTNULL(unique_task_context);
   const auto &task = node_state.GetKernelTask();
   if (task == nullptr) {
     GELOGE(INTERNAL_ERROR, "[Invoke][GetKernelTask] failed for[%s], NodeTask is null.", node_state.GetName().c_str());
     REPORT_CALL_ERROR("E19999", "invoke GetKernelTask failed for %s, NodeTask is null.", node_state.GetName().c_str());
     return INTERNAL_ERROR;
   }
-  auto shared_task_context = std::shared_ptr<TaskContext>(unique_task_context.release());
-  node_state.SetTaskContext(shared_task_context);
   GE_CHK_RT_RET(rtCtxSetCurrent(ctx->rt_context));
   RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] start");
-  GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*shared_task_context));  // update op_desc before alloc ws
+  GE_CHK_STATUS_RET_NOLOG(task->UpdateTilingData(*node_state.GetTaskContext()));  // update op_desc before alloc ws
   RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[UpdateTilingData] end");
   return SUCCESS;
 }
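The deleted lines in the last two hunks repeated the same handoff, std::shared_ptr<TaskContext>(unique_task_context.release()), which now lives in one place (NodeState::Init and NodeState::ResetContext). A minimal sketch of that conversion; std::move would be the more common spelling of the same ownership transfer:

#include <memory>

struct TaskContext {};  // stand-in for the real class

std::shared_ptr<TaskContext> Adopt(std::unique_ptr<TaskContext> unique_ctx) {
  // Releases ownership from the unique_ptr and re-binds it to a shared_ptr;
  // equivalent to: return std::shared_ptr<TaskContext>(std::move(unique_ctx));
  return std::shared_ptr<TaskContext>(unique_ctx.release());
}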


ge/hybrid/executor/subgraph_executor.h (+2 -1)

@@ -41,6 +41,8 @@ class SubgraphExecutor {

   Status PartialExecuteAsync(int task_group);

+  void ReleaseContext() { subgraph_context_.reset(nullptr); }
+
   /**
    * Execute subgraph async, output tensor address(not data) and output tensor descriptions are
    * valid after this method returned
@@ -125,7 +127,6 @@ class SubgraphExecutor {
   ThreadPool pre_run_pool_;
   BlockingQueue<NodeState *> ready_queue_;
   std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_;
-  std::shared_ptr<TaskContext> known_shape_task_context_;

   std::mutex mu_;  // Guard for prepare_queues_.
   std::map<int, BlockingQueue<const NodeItem *>> prepare_queues_;


ge/hybrid/executor/worker/execution_engine.cc (+1 -0)

@@ -373,6 +373,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
   auto executor = node_item.node_executor;
   GE_CHECK_NOTNULL(executor);
   RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start");
+  node_state.UpdatePersistTensor();
   GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.",
                     node_state.GetName().c_str());
   RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End");
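This one-line hook is what activates the persist-tensor logic added in node_state.cc: the cached Root inputs (and, after the first iteration, Enter inputs) are restored immediately before each PrepareTask call, so later iterations see the originally saved tensors rather than whatever the previous iteration left in those input slots.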


Some files were not shown because too many files changed in this diff
