Compare commits

...

62 Commits

Author SHA1 Message Date
  储星 a879159cd9 !181 parser one to many add original name 4 years ago
  y00500818 f5b9b56976 parser one to many add original name 4 years ago
  储星 2458a4fb10 !168 mod add control edge log 4 years ago
  root 921c0f4a4a mod add control edge log print 4 years ago
  储星 6ef2a32651 !159 remove specific process when switch control edge to other node 4 years ago
  wangxiaotian22 4bf949e8f4 remove specific process when switch control edge to other node 4 years ago
  储星 c1d95bac3a !145 restore code 4 years ago
  y00500818 118afdc74f restore code 4 years ago
  储星 befc2aac08 !141 add validation of fmk type for plugin load. 4 years ago
  y00500818 ed68fe4968 add validation of fmk type for plugin load. 4 years ago
  储星 b5d7c0d3da !140 add validation of fmk type for plugin load. 4 years ago
  y00500818 1196f14a49 add validation of fmk type for plugin load. 4 years ago
  储星 c841458262 !136 ONLY_COMPILE_OPEN_SRC_METADEF 4 years ago
  y00500818 e3607a9bab ONLY_COMPILE_OPEN_SRC_METADEF 4 years ago
  储星 13f2b268d9 !132 remove parser ONLY_COMPILE_OPEN_SRC 4 years ago
  y00500818 62ba68da76 remove parser ONLY_COMPILE_OPEN_SRC 4 years ago
  王涛 3380d801ae !127 parser one to many 4 years ago
  y00500818 1a2a156d7d parser one to many 4 years ago
  王涛 5dd26bcc0c !124 parser one to many 4 years ago
  y00500818 ecfa6f1a12 parser one to many 4 years ago
  王涛 58d657a959 !122 update submodule metadef 4 years ago
  y00500818 11be91b6f5 update submodule metadef 4 years ago
  王涛 76a862b1bc update README.md. 4 years ago
  王涛 4dc4e4c051 !68 solve product side 4 years ago
  t00456437 7c25d2c3be solve product side 4 years ago
  王涛 45cee6f977 !65 add pb2json 4 years ago
  taoxiangdong deea2f1d34 add convert pb2json 4 years ago
  王涛 87defa07f8 !64 update roadmap 4 years ago
  王涛 7cd0089118 update README.md. 4 years ago
  王涛 99e435028c !52 parser depend on ge in master branch 4 years ago
  taoxiangdong 496594e1a3 update parser depend on ge 4 years ago
  王涛 b2ce41d5bb !50 update cmake compile dependency 4 years ago
  taoxiangdong 4555f2ded5 update submodule metadef 4 years ago
  taoxiangdong 18d6a45abe update cmake compile dependency 4 years ago
  王涛 5c097bb666 !48 delete parser common convert 4 years ago
  taoxiangdong 112797a552 update parser common convert 4 years ago
  王涛 c6b1f992db !46 update submodule metadef 4 years ago
  taoxiangdong ca6227490e update submodule metadef 4 years ago
  王涛 f60e66c153 !45 update metadef submodule 4 years ago
  taoxiangdong 048cd7588f update submodule metadef 4 years ago
  taoxiangdong b97123e9d5 update master parser src code part2 4 years ago
  王涛 6c37290fef update README.md. 4 years ago
  王涛 cc57aded17 update README.md. 4 years ago
  王涛 27e4a3c31e update README.md. 4 years ago
  王涛 f3d7bd35d4 update README.md. 4 years ago
  taoxiangdong cec81adf13 update submodule metadef 4 years ago
  王涛 5fa9c1276e !43 update src code from yellow zone 4 years ago
  taoxiangdong dcfe59ef7a update master src code from yellow zone 4 years ago
  taoxiangdong fee27781d6 Merge remote-tracking branch 'upstream/development' 4 years ago
  王涛 5ea9437bbe !31 update json cmake 4 years ago
  王涛 be653a1703 !29 update cmake 4 years ago
  王涛 c53d788e91 !27 update submodule 4 years ago
  王涛 763e9d285d !25 update submodule metadef 4 years ago
  王涛 191091a3dd !23 sync from yellow zone 20201020 4 years ago
  王涛 88eb8af4ae !21 update git submodule metadef 4 years ago
  王涛 27f3757531 !19 remove compile cache 4 years ago
  王涛 b7702a4fa4 !17 add metadef submodule 4 years ago
  王涛 85cd977a30 !15 update atc cmake 4 years ago
  王涛 240e4efb44 !13 add build.sh 4 years ago
  王涛 996bf32d3e !11 update cmake 4 years ago
  王涛 9a7eb21ea5 !9 update cmakelists 4 years ago
  王涛 7e69a9b628 !7 update readme for parser 4 years ago
93 changed files with 1104 additions and 4454 deletions
Split View
  1. +1
    -1
      .gitmodules
  2. +25
    -3
      README.md
  3. +24
    -3
      build.sh
  4. +9
    -1
      cmake/external_libs/gflags.cmake
  5. +11
    -2
      cmake/external_libs/json.cmake
  6. +9
    -1
      cmake/external_libs/onnx.cmake
  7. +9
    -1
      cmake/external_libs/protobuf_shared.cmake
  8. +9
    -1
      cmake/external_libs/protobuf_static.cmake
  9. +10
    -1
      cmake/external_libs/protoc.cmake
  10. +1
    -1
      metadef
  11. +1
    -1
      parser/CMakeLists.txt
  12. +8
    -9
      parser/caffe/caffe_custom_parser_adapter.cc
  13. +10
    -10
      parser/caffe/caffe_data_parser.cc
  14. +0
    -3
      parser/caffe/caffe_op_parser.cc
  15. +165
    -109
      parser/caffe/caffe_parser.cc
  16. +6
    -6
      parser/caffe/caffe_parser.h
  17. +2
    -4
      parser/caffe/caffe_reshape_parser.cc
  18. +4
    -6
      parser/common/CMakeLists.txt
  19. +8
    -275
      parser/common/acl_graph_parser_util.cc
  20. +2
    -192
      parser/common/acl_graph_parser_util.h
  21. +3
    -3
      parser/common/convert/pb2json.cc
  22. +10
    -10
      parser/common/convert/pb2json.h
  23. +9
    -9
      parser/common/data_op_parser.cc
  24. +1
    -1
      parser/common/data_op_parser.h
  25. +0
    -155
      parser/common/model_saver.cc
  26. +0
    -55
      parser/common/model_saver.h
  27. +14
    -19
      parser/common/module.mk
  28. +1
    -1
      parser/common/op_def/constant_op.h
  29. +2
    -2
      parser/common/op_def/ir_pb_converter.cc
  30. +1
    -1
      parser/common/op_def/no_op_op.h
  31. +1
    -1
      parser/common/op_def/ref_switch_op.h
  32. +1
    -1
      parser/common/op_def/shape_n_op.cc
  33. +1
    -1
      parser/common/op_def/shape_n_op.h
  34. +1
    -1
      parser/common/op_def/var_is_initialized_op_op.cc
  35. +1
    -1
      parser/common/op_def/var_is_initialized_op_op.h
  36. +1
    -1
      parser/common/op_def/variable_op.cc
  37. +1
    -1
      parser/common/op_def/variable_op.h
  38. +1
    -3
      parser/common/op_map.cc
  39. +5
    -5
      parser/common/op_parser_factory.h
  40. +3
    -3
      parser/common/parser_api.cc
  41. +0
    -1270
      parser/common/parser_fp16_t.cc
  42. +0
    -653
      parser/common/parser_fp16_t.h
  43. +0
    -494
      parser/common/parser_types.cc
  44. +221
    -0
      parser/common/parser_utils.cc
  45. +37
    -0
      parser/common/parser_utils.h
  46. +0
    -83
      parser/common/pass_manager.cc
  47. +0
    -76
      parser/common/pass_manager.h
  48. +4
    -5
      parser/common/pre_checker.cc
  49. +1
    -1
      parser/common/pre_checker.h
  50. +0
    -3
      parser/common/proto/insert_op.proto
  51. +2
    -3
      parser/common/proto_file_parser.cc
  52. +15
    -9
      parser/common/register_tbe.cc
  53. +0
    -212
      parser/common/tbe_plugin_loader.cc
  54. +0
    -62
      parser/common/tbe_plugin_loader.h
  55. +0
    -78
      parser/common/thread_pool.cc
  56. +0
    -83
      parser/common/thread_pool.h
  57. +2
    -2
      parser/func_to_graph/proto_python_rule.mk
  58. +5
    -6
      parser/module.mk
  59. +2
    -3
      parser/onnx/module.mk
  60. +1
    -2
      parser/onnx/onnx_constant_parser.cc
  61. +0
    -1
      parser/onnx/onnx_data_parser.cc
  62. +12
    -9
      parser/onnx/onnx_parser.cc
  63. +0
    -3
      parser/proto/insert_op.proto
  64. +6
    -7
      parser/tensorflow/graph_functiondef.cc
  65. +1
    -1
      parser/tensorflow/graph_insert_trans_op.h
  66. +128
    -153
      parser/tensorflow/graph_optimizer.cc
  67. +9
    -10
      parser/tensorflow/graph_optimizer.h
  68. +1
    -1
      parser/tensorflow/iterator_fusion_pass.cc
  69. +0
    -3
      parser/tensorflow/proto/insert_op.proto
  70. +2
    -2
      parser/tensorflow/scope/scope_pass_manager.cc
  71. +2
    -3
      parser/tensorflow/tensorflow_arg_parser.cc
  72. +0
    -4
      parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc
  73. +2
    -4
      parser/tensorflow/tensorflow_constant_parser.cc
  74. +2
    -5
      parser/tensorflow/tensorflow_data_parser.cc
  75. +0
    -3
      parser/tensorflow/tensorflow_enter_parser.cc
  76. +2
    -7
      parser/tensorflow/tensorflow_fill_parser.cc
  77. +1
    -2
      parser/tensorflow/tensorflow_frameworkop_parser.cc
  78. +4
    -4
      parser/tensorflow/tensorflow_fusion_op_parser.cc
  79. +18
    -19
      parser/tensorflow/tensorflow_fusionop_util.cc
  80. +1
    -1
      parser/tensorflow/tensorflow_fusionop_util.h
  81. +0
    -4
      parser/tensorflow/tensorflow_identity_parser.cc
  82. +0
    -2
      parser/tensorflow/tensorflow_merge_parser.cc
  83. +0
    -1
      parser/tensorflow/tensorflow_no_op_parser.cc
  84. +231
    -226
      parser/tensorflow/tensorflow_parser.cc
  85. +3
    -3
      parser/tensorflow/tensorflow_parser_register.h
  86. +0
    -1
      parser/tensorflow/tensorflow_ref_switch_parser.cc
  87. +3
    -4
      parser/tensorflow/tensorflow_reshape_parser.cc
  88. +0
    -1
      parser/tensorflow/tensorflow_shape_n_parser.cc
  89. +4
    -6
      parser/tensorflow/tensorflow_squeeze_parser.cc
  90. +10
    -11
      parser/tensorflow/tensorflow_util.cc
  91. +1
    -1
      parser/tensorflow/tensorflow_util.h
  92. +0
    -2
      parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc
  93. +0
    -1
      parser/tensorflow/tensorflow_variable_v2_parser.cc

+ 1
- 1
.gitmodules View File

@@ -1,4 +1,4 @@
[submodule "metadef"]
path = metadef
url = https://gitee.com/ascend/metadef.git
branch = development
branch = master

+ 25
- 3
README.md View File

@@ -9,7 +9,7 @@ parser以动态库的方式被调用。

### 源码安装

进行源码编译前,确保系统满足以下要求:
Parser支持由源码编译。进行源码编译前,首先确保你有昇腾910 AI处理器的环境,并确保系统满足以下要求:

- GCC >= 7.3.0
- CMake >= 3.14.0
@@ -22,19 +22,41 @@ parser以动态库的方式被调用。
```
git clone https://gitee.com/ascend/parser.git
cd parser
git submodule init && git submodule update
```

#### 源码编译

在parser根目录执行以下命令编译:
```
// 正在补充
目前parser需要集成到mindspore/graphengine中使用,暂不支持独立编译,解耦独立编译正在开发中,敬请期待;
```



## 贡献

欢迎参与贡献。

## 路标

以下将展示graphengine/parser近期的计划,我们会根据用户的反馈诉求,持续调整计划的优先级。

总体而言,我们会努力在以下几个方面不断改进。

1、完备性:Cast/ConcatV2算子支持输入数据类型为int64的常量折叠;

2、完备性:onnx parser支持一对多映射;

3、架构优化:ATC解耦并迁移至parser;

4、易用性:提供tensorflow训练的checkpoint文件转pb文件的一键式转化工具;

5、易用性:提供一键式本地编译环境构建工具;

6、可维测:ATC转换生成的om模型包含框架信息、cann版本信息和芯片信息等;

热忱希望各位在用户社区加入讨论,并贡献您的建议。

## Release Notes

Release Notes请参考[RELEASE](RELEASE.md)。


+ 24
- 3
build.sh View File

@@ -23,7 +23,7 @@ export BUILD_PATH="${BASEPATH}/build/"
usage()
{
echo "Usage:"
echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c]"
echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c] [-S on|off]"
echo ""
echo "Options:"
echo " -h Print usage"
@@ -33,9 +33,21 @@ usage()
echo " -t Build and execute ut"
echo " -c Build ut with coverage tag"
echo " -v Display build command"
echo " -S Enable download cmake compile dependency from gitee , default off"
echo "to be continued ..."
}

# check value of input is 'on' or 'off'
# usage: check_on_off arg_value arg_name
check_on_off()
{
if [[ "X$1" != "Xon" && "X$1" != "Xoff" ]]; then
echo "Invalid value $1 for option -$2"
usage
exit 1
fi
}

# parse and set options
checkopts()
{
@@ -46,8 +58,9 @@ checkopts()
ENABLE_GE_ST="off"
ENABLE_GE_COV="off"
GE_ONLY="on"
ENABLE_GITEE="off"
# Process the options
while getopts 'ustchj:v' opt
while getopts 'ustchj:vS:' opt
do
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]')
case "${opt}" in
@@ -77,6 +90,11 @@ checkopts()
v)
VERBOSE="VERBOSE=1"
;;
S)
check_on_off $OPTARG S
ENABLE_GITEE="$OPTARG"
echo "enable download from gitee"
;;
*)
echo "Undefined option: ${opt}"
usage
@@ -119,6 +137,10 @@ build_parser()
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON"
fi

if [[ "X$ENABLE_GITEE" = "Xon" ]]; then
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GITEE=ON"
fi

CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_OPEN_SRC=True -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH}"
echo "${CMAKE_ARGS}"
cmake ${CMAKE_ARGS} ..
@@ -186,7 +208,6 @@ generate_package()
done

find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "libc_sec.so" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \;
find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "libregister.a" -exec cp -f {} ${OUTPUT_PATH}/${ACL_PATH} \;

tar -cf parser_lib.tar fwkacllib acllib atc
}


+ 9
- 1
cmake/external_libs/gflags.cmake View File

@@ -11,8 +11,16 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif()

if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz")
set(MD5 "")
else()
set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz")
set(MD5 "")
endif ()

ExternalProject_Add(gflags_build
URL https://github.com/gflags/gflags/archive/v2.2.2.tar.gz
URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${PARSER_DIR}/../../third_party/gflags/src/gflags-2.2.2
CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR>


+ 11
- 2
cmake/external_libs/json.cmake View File

@@ -5,8 +5,17 @@ endif()
include(ExternalProject)

set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include)
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip")
set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7")
set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include")
else()
set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip")
set(MD5 "0dc903888211db3a0f170304cd9f3a89")
set(JSON_INCLUDE_DIR ${JSON_SRC_DIR})
endif ()
ExternalProject_Add(json_build
URL https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip
URL ${REQ_URL}
#URL /home/txd/workspace/cloud_code/pkg/include.zip
SOURCE_DIR ${JSON_SRC_DIR}
CONFIGURE_COMMAND ""
@@ -17,7 +26,7 @@ ExternalProject_Add(json_build


add_library(json INTERFACE)
target_include_directories(json INTERFACE ${JSON_SRC_DIR})
target_include_directories(json INTERFACE ${JSON_INCLUDE_DIR})
add_dependencies(json json_build)

#set(HAVE_JSON TRUE CACHE BOOL "json build add")


+ 9
- 1
cmake/external_libs/onnx.cmake View File

@@ -6,8 +6,16 @@ set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx)
set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto)
file(MAKE_DIRECTORY ${ONNX_PROTO_DIR})

if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz")
set(MD5 "1bdbcecdd68ea8392630467646776e02")
else()
set(REQ_URL "https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz")
set(MD5 "512f2779d6215d4a36f366b6b9acdf1e")
endif ()

ExternalProject_Add(onnx
URL https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz
URL ${REQ_URL}
#URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz
#URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345
#SOURCE_DIR ${ONNX_SRC_DIR}


+ 9
- 1
cmake/external_libs/protobuf_shared.cmake View File

@@ -11,10 +11,18 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif()

if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
endif ()

set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(protobuf_build
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz
URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0
#DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E copy_directory ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0 <SOURCE_DIR>


+ 9
- 1
cmake/external_libs/protobuf_static.cmake View File

@@ -8,11 +8,19 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif()

if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
endif ()

set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static)
ExternalProject_Add(protobuf_static_build
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz
URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${PARSER_DIR}/../../third_party/protobuf/src/protobuf-3.8.0
CONFIGURE_COMMAND ${CMAKE_COMMAND}


+ 10
- 1
cmake/external_libs/protoc.cmake View File

@@ -12,10 +12,19 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif()
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz")
set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(MD5 "3d9e32700639618a4d2d342c99d4507a")
endif ()
set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2")
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack")
ExternalProject_Add(protoc_build
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz
URL ${REQ_URL}
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz
#SOURCE_DIR ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0
CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit cc9de48a7779cf95cab90a23db608421a691fd12
Subproject commit cba1ba3dbc8d933cd08ccd0dff586112c9501b75

+ 1
- 1
parser/CMakeLists.txt View File

@@ -131,7 +131,7 @@ target_compile_options(fmk_parser_stub PRIVATE
)

target_compile_definitions(fmk_parser_stub PRIVATE
$<$<STREQUAL:${PRODUCT_SIDE},host>:FMK_SUPPORT_DUMP>
$<$<OR:$<STREQUAL:${PRODUCT_SIDE},host>,$<STREQUAL:${ENABLE_OPEN_SRC},True>>:FMK_SUPPORT_DUMP>
PROTOBUF_INLINE_NOT_IN_HEADERS=0
REUSE_MEMORY=1
FMK_HOST_INFER


+ 8
- 9
parser/caffe/caffe_custom_parser_adapter.cc View File

@@ -18,17 +18,16 @@
#include <memory>
#include <vector>
#include "common/debug/log.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/ge/ge_util.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/omg_inner_types.h"
#include "framework/omg/parser/parser_types.h"
#include "graph/utils/graph_utils.h"
#include "parser/common/op_parser_factory.h"
#include "register/op_registry.h"

using domi::ParseParamByOpFunc;
using domi::ParseParamFunc;
using domi::ParseParamByOpFunc;
using std::vector;

namespace ge {
@@ -55,8 +54,8 @@ Status CaffeCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPt
}

Status CaffeCustomParserAdapter::ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest) {
GELOGI("Caffe custom op begin to params: layer name = %s, layer type= %s ", op_src.GetName().c_str(),
op_src.GetOpType().c_str());
GELOGI("Caffe custom op begin to params: layer name = %s, layer type= %s ",
op_src.GetName().c_str(), op_src.GetOpType().c_str());
GE_CHECK_NOTNULL(op_dest);

ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType());
@@ -86,7 +85,7 @@ Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr
bool update_in_turn = (static_cast<int64_t >(op->GetAllInputsSize()) == (layer->bottom_size() + layer->blobs_size()));
int start_pos = layer->bottom_size();
for (int i = 0; i < layer->blobs_size(); ++i) {
ge::GeTensorPtr weight = ge::parser::MakeShared<ge::GeTensor>();
ge::GeTensorPtr weight = ge::MakeShared<ge::GeTensor>();
GE_CHECK_NOTNULL(weight);
GE_CHK_STATUS_RET(ConvertWeight(layer->blobs(i), layer->name(), weight), "Convert blobs(%d) for layer %s failed", i,
layer->name().c_str());
@@ -98,14 +97,14 @@ Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr
bias_en = fc_params_src.bias_term(););
auto bias_shape = weight->MutableTensorDesc().GetShape();
// The num 0, 1, 2, 3 represet the dim index.
bool matched = bias_en && bias_shape.GetDimNum() == static_cast<size_t>(ge::parser::DIM_DEFAULT_SIZE) &&
bool matched = bias_en && bias_shape.GetDimNum() == static_cast<size_t>(ge::DIM_DEFAULT_SIZE) &&
bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1 && bias_shape.GetDim(2) == 1;
if (matched) {
weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(3)}));
}
matched = layer->type() == kInnerProduct && i == 0 &&
bias_shape.GetDimNum() == static_cast<size_t>(ge::parser::DIM_DEFAULT_SIZE) &&
bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1;
bias_shape.GetDimNum() == static_cast<size_t>(ge::DIM_DEFAULT_SIZE) && bias_shape.GetDim(0) == 1 &&
bias_shape.GetDim(1) == 1;
if (matched) {
weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(2), bias_shape.GetDim(3)}));
}


+ 10
- 10
parser/caffe/caffe_data_parser.cc View File

@@ -18,15 +18,13 @@
#include <unordered_map>
#include <utility>
#include "common/debug/log.h"
#include "framework/omg/parser/parser_types.h"
#include "common/types.h"
#include "common/util.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "omg/omg_inner_types.h"
#include "parser/common/op_parser_factory.h"

using namespace ge::parser;

namespace ge {
Status CaffeDataParser::GetOutputDesc(const string &name, int dim_size, const std::vector<int64_t> &input_dims,
ge::OpDescPtr &op) {
@@ -50,10 +48,10 @@ Status CaffeDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) {
GE_CHECK_NOTNULL(layer);
GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str());

if (layer->type() == ge::parser::INPUT_TYPE) {
if (layer->type() == ge::INPUT_TYPE) {
GE_CHK_STATUS_RET(ParseParamsForInput(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed",
layer->name().c_str(), layer->type().c_str());
} else if(layer->type() == ge::parser::DUMMY_DATA) {
} else if(layer->type() == ge::DUMMY_DATA) {
GE_CHK_STATUS_RET(ParseParamsForDummyData(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed",
layer->name().c_str(), layer->type().c_str());
} else {
@@ -77,12 +75,14 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l
}
for (int i = 0; i < input_param.shape_size(); i++) {
const domi::caffe::BlobShape &blob_shape = input_param.shape(i);

vector<int64_t> shape;
unordered_map<string, vector<int64_t>> &shape_map = GetParserContext().input_dims;
unordered_map<string, vector<int64_t>> &shape_map = domi::GetContext().input_dims;
std::vector<int64_t> model_dims;
for (auto &blob_shape_dim_temp : blob_shape.dim()) {
model_dims.push_back(blob_shape_dim_temp);
}

string name = layer->name();
GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name));
GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), "Get output desc failed in layer %s",
@@ -90,7 +90,7 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l
}
} else {
// Get from external input
const ge::ParserContext &ctx = GetParserContext();
const ge::OmgContext &ctx = domi::GetContext();
std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
string name = layer->name();
auto search = input_dims.find(name);
@@ -124,7 +124,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete
const domi::caffe::BlobShape &blob_shape = dummy_data_param.shape(i);

vector<int64_t> shape;
unordered_map<string, vector<int64_t>> &shape_map = GetParserContext().input_dims;
unordered_map<string, vector<int64_t>> &shape_map = domi::GetContext().input_dims;
std::vector<int64_t> model_dims;
for (auto &blob_shape_dim_temp : blob_shape.dim()) {
model_dims.push_back(blob_shape_dim_temp);
@@ -137,7 +137,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete
}
} else {
// Get from external input
const ge::ParserContext &ctx = GetParserContext();
const ge::OmgContext &ctx = domi::GetContext();
std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
string name = layer->name();
auto search = input_dims.find(name);


+ 0
- 3
parser/caffe/caffe_op_parser.cc View File

@@ -18,9 +18,6 @@
#include <memory>
#include "parser/common/op_parser_factory.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/omg/parser/parser_types.h"

using namespace ge::parser;

using domi::CAFFE;



+ 165
- 109
parser/caffe/caffe_parser.cc View File

@@ -20,17 +20,16 @@
#include <iostream>
#include <sstream>
#include <memory>
#include <algorithm>
#include "parser/common/convert/pb2json.h"
#include "common/convert/pb2json.h"
#include "common/debug/log.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/ge/ge_util.h"
#include "common/model_saver.h"
#include "common/op_map.h"
#include "common/util.h"
#include "common/util/error_manager/error_manager.h"
#include "common/ge_types.h"
#include "common/string_util.h"
#include "external/graph/operator_factory.h"
#include "external/parser/caffe_parser.h"
#include "external/ge/ge_api_types.h"
#include "framework/common/debug/ge_log.h"
#include "graph/optimize/common/params.h"
#include "graph/utils/graph_utils.h"
@@ -47,8 +46,6 @@
#include "parser/caffe/caffe_op_parser.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/pre_checker.h"
#include "framework/omg/parser/parser_types.h"
#include "parser/common/model_saver.h"
#include "parser/common/acl_graph_parser_util.h"
#include "parser/common/proto_file_parser.h"
#include "register/op_registry.h"
@@ -58,7 +55,7 @@ using domi::caffe::NetParameter;
using domi::ParseParamByOpFunc;
using ge::caffe_op_map;
using ge::CaffeOpParser;
using ge::parser::ModelSaver;
using ge::ModelSaver;
using ge::OpParser;
using ge::OpParserFactory;
using ge::Pb2Json;
@@ -77,7 +74,7 @@ using std::ifstream;
namespace ge {
graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph) {
GE_CHECK_NOTNULL(model_file);
GetParserContext().type = domi::CAFFE;
domi::GetContext().type = domi::CAFFE;
std::map<string, string> options;
options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::CAFFE)));

@@ -86,7 +83,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
(void)acl_graph_parse_util.AclParserInitialize(options);

// Create an empty computegraph
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph");
ge::ComputeGraphPtr compute_graph = ge::MakeShared<ge::ComputeGraph>("tmpGraph");
GE_CHECK_NOTNULL(compute_graph);

graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
@@ -108,10 +105,6 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file,
return ret;
}
GELOGI("Weights parse success. graph: %s", graph.GetName().c_str());
if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) {
GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str());
return ge::FAILED;
}
return ge::SUCCESS;
}
} // namespace ge
@@ -155,15 +148,14 @@ const std::string kRepeated = "repeated";
const std::string kRequired = "required";
const std::string kCustom = "custom";
const std::string kBuiltin = "built-in";
std::vector<std::string> kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT,
ge::parser::NETOUTPUT};
std::vector<std::string> kAddTensorIrSkipNodes = {ge::DATA, ge::YOLODETECTIONOUTPUT, ge::NETOUTPUT};
const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"};
const std::set<std::string> kCaffeProtoLayerCommonField = {"name", "type", "bottom", "top", "phase",
"loss_weight", "param", "blobs", "propagate_down",
"include", "exclude"};
Status CheckPathValid(const char *model_path, const string &custom_proto, string &custom_proto_path,
string &custom_proto_name) {
string path_model = ge::parser::RealPath(model_path);
string path_model = ge::RealPath(model_path);
if (path_model.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {model_path, strerror(errno)});
GELOGE(FAILED, "Invalid path of model: %s", model_path);
@@ -219,7 +211,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(proto_message.input(i));
layer->set_type(ge::parser::INPUT_TYPE);
layer->set_type(ge::INPUT_TYPE);
layer->add_top(proto_message.input(i));

domi::caffe::InputParameter *input_param = layer->mutable_input_param();
@@ -248,7 +240,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(proto_message.input(i));
layer->set_type(ge::parser::INPUT_TYPE);
layer->set_type(ge::INPUT_TYPE);
layer->add_top(proto_message.input(i));

domi::caffe::InputParameter *input_param = layer->mutable_input_param();
@@ -263,7 +255,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo
input_data_flag = true;
}
} else {
const ge::ParserContext &ctx = ge::GetParserContext();
const ge::OmgContext &ctx = domi::GetContext();
std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
for (int i = 0; i < proto_message.input_size(); i++) {
string name = proto_message.input(i);
@@ -278,7 +270,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(name);
layer->set_type(ge::parser::INPUT_TYPE);
layer->set_type(ge::INPUT_TYPE);
layer->add_top(proto_message.input(i));

domi::caffe::InputParameter *input_param = layer->mutable_input_param();
@@ -343,7 +335,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons

Status CaffeModelParser::CustomProtoParse(const char *model_path, const string &custom_proto,
const string &caffe_proto, vector<ge::Operator> &operators) {
string custom_proto_path = ge::parser::RealPath(custom_proto.c_str());
string custom_proto_path = ge::RealPath(custom_proto.c_str());
if (custom_proto_path.empty()) {
GELOGW("Valid custom proto: %s does not exist, skip parsing custom proto", custom_proto.c_str());
return SUCCESS;
@@ -749,27 +741,27 @@ Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection *
}

void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) {
auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name);
if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) {
auto iter_node_name = domi::GetContext().out_nodes_map.find(layer_name);
if (iter_node_name != domi::GetContext().out_nodes_map.end()) {
iter_node_name->second.emplace_back(top_index);
} else {
std::vector<int32_t> index_v;
index_v.emplace_back(top_index);
ge::GetParserContext().out_nodes_map.emplace(layer_name, index_v);
domi::GetContext().out_nodes_map.emplace(layer_name, index_v);
}
ge::GetParserContext().user_out_nodes.push_back(std::make_pair(layer_name, top_index));
domi::GetContext().user_out_nodes.push_back(std::make_pair(layer_name, top_index));
}

Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) {
if (ge::GetParserContext().user_out_nodes_top_vec.empty()) {
if (domi::GetContext().user_out_nodes_top_vec.empty()) {
return SUCCESS;
}

ge::GetParserContext().out_nodes_map.clear();
ge::GetParserContext().user_out_nodes.clear();
domi::GetContext().out_nodes_map.clear();
domi::GetContext().user_out_nodes.clear();
int32_t layer_count = proto_message.layer_size();
const std::vector<string> &user_out_nodes_top_vec =
ge::GetParserContext().user_out_nodes_top_vec;
domi::GetContext().user_out_nodes_top_vec;

for (const auto &top_name : user_out_nodes_top_vec) {
bool find_node_falg = false;
@@ -808,6 +800,10 @@ Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter

Status CaffeModelParser::AddBlobsToMap(const domi::caffe::LayerParameter &layer,
std::map<std::string, std::string> &inplace_blob_name_remapping) {
if (layer.type() == ge::NETOUTPUT) {
return SUCCESS;
}

if (layer.top_size() <= 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E19011", {"opname"}, {layer.name()});
GELOGE(FAILED, "The output size of layer %s needs to be greater than zero.", layer.name().c_str());
@@ -966,9 +962,9 @@ Status CaffeModelParser::AddNode(const domi::caffe::LayerParameter &layer, ge::C
} else {
op_type = layer.type();
// User defined duplicate name operator processing
auto m_iter = ge::GetParserContext().op_conf_map.find(op_type);
auto m_iter = domi::GetContext().op_conf_map.find(op_type);
// User specified configuration item found
if (m_iter != ge::GetParserContext().op_conf_map.end()) {
if (m_iter != domi::GetContext().op_conf_map.end()) {
op_type = m_iter->second;
}
// General layer layer, search optype
@@ -1057,7 +1053,7 @@ Status CaffeModelParser::AddNode(const domi::caffe::LayerParameter &layer, ge::C
Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer) {
GE_CHECK_NOTNULL(op_desc);
// Data node input and output tensordesc added in parserparam
if (op_desc->GetType() == ge::parser::DATA) {
if (op_desc->GetType() == ge::DATA) {
return SUCCESS;
}

@@ -1077,7 +1073,7 @@ Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const dom
}

// yolo v2 YoloDetectionOutput
if (op_desc->GetType() == ge::parser::YOLODETECTIONOUTPUT) {
if (op_desc->GetType() == ge::YOLODETECTIONOUTPUT) {
ge::GeTensorDesc input_tensor;
GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor));
GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor));
@@ -1086,13 +1082,41 @@ Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const dom
"while it's original input num is: %d",
layer.bottom_size());
}

// Netoutput node processing
if (op_desc->GetType() == ge::NETOUTPUT) {
size_t input_output_tensor_num = 0;
if (!domi::GetContext().user_out_nodes.empty()) {
// User specified output
input_output_tensor_num = domi::GetContext().user_out_nodes.size();
} else {
for (auto t_iter = top_blobs_map_.begin(); t_iter != top_blobs_map_.end(); t_iter++) {
auto b_iter = bottom_blobs_map_.find(t_iter->first);
// Find the output node of the network
if (b_iter == bottom_blobs_map_.end()) {
input_output_tensor_num += top_blobs_map_[t_iter->first].size();
}
}
}
// add tensordesc
GELOGD(
"Current op type is NETOUTPUT, add additional input&output num: %zu."
"while it's original input num is: %d, output num is: %d",
input_output_tensor_num, layer.bottom_size(), output_tensor_num);
for (size_t i = 0; i < input_output_tensor_num; i++) {
ge::GeTensorDesc input_tensor;
GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor));
ge::GeTensorDesc output_tensor;
GE_RETURN_IF_ERROR(op_desc->AddOutputDesc(output_tensor));
}
}
return SUCCESS;
}

Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer,
const string &op_type) {
if (std::find(kAddTensorIrSkipNodes.begin(), kAddTensorIrSkipNodes.end(), op_type) != kAddTensorIrSkipNodes.end()) {
op_desc = ge::parser::MakeShared<ge::OpDesc>(layer.name(), op_type);
op_desc = ge::MakeShared<ge::OpDesc>(layer.name(), op_type);
GE_CHECK_NOTNULL(op_desc);
Status ret = AddTensorDescToOpDesc(op_desc, layer);
if (ret != SUCCESS) {
@@ -1224,8 +1248,8 @@ Status CaffeModelParser::AddEdges(ge::ComputeGraphPtr &graph) {

bool CaffeModelParser::IsOutputTop(const string &op_name, const int32_t index) {
bool ret = false;
auto iter = ge::GetParserContext().out_nodes_map.find(op_name);
if (iter != ge::GetParserContext().out_nodes_map.end()) {
auto iter = domi::GetContext().out_nodes_map.find(op_name);
if (iter != domi::GetContext().out_nodes_map.end()) {
std::vector<int32_t> tmp_index_v;
for (int32_t id : iter->second) {
if (index == id) {
@@ -1236,40 +1260,53 @@ bool CaffeModelParser::IsOutputTop(const string &op_name, const int32_t index) {
}
// To prevent specifying network output again in the build phase, need to delete the output node in the map list.
if (ret) {
ge::GetParserContext().out_nodes_map.erase(op_name);
ge::GetParserContext().out_nodes_map.emplace(op_name, tmp_index_v);
domi::GetContext().out_nodes_map.erase(op_name);
domi::GetContext().out_nodes_map.emplace(op_name, tmp_index_v);
}
}
return ret;
}

Status CaffeModelParser::AddUserOutNodesTop() {
Status CaffeModelParser::AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph) {
GE_CHECK_NOTNULL(graph);
ge::NodePtr net_output_node = graph->FindFirstNodeMatchType(ge::NETOUTPUT);
if (net_output_node == nullptr) {
GELOGE(INTERNAL_ERROR, "Can not find netoutput node.");
return INTERNAL_ERROR;
}
uint32_t net_output_num = net_output_node->GetAllInDataAnchorsSize();
int32_t index = 0;
const std::vector<std::pair<std::string, int32_t>> &user_out_nodes = ge::GetParserContext().user_out_nodes;
int net_output_num = user_out_nodes.size();
for (const auto &out_pair : user_out_nodes) {
auto layer_iter = layer_tops_map_.find(out_pair.first);
std::vector<std::pair<std::string, int32_t>> &user_out_nodes = domi::GetContext().user_out_nodes;
for (auto &out_pair : user_out_nodes) {
auto node_iter = node_map.find(out_pair.first);
GELOGI("Add to output, node name: %s", out_pair.first.c_str());
if (layer_iter != layer_tops_map_.end()) {
if (static_cast<uint32_t>(out_pair.second) >= (layer_iter->second).size()) {
if (node_iter != node_map.end()) {
if ((static_cast<uint32_t>(out_pair.second) >= node_iter->second->GetAllOutDataAnchorsSize()) ||
(static_cast<uint32_t>(index) >= net_output_num)) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11016", {"opname", "outputindex", "totlaloutputindex", "inputindex", "totlalinputindex"},
{out_pair.first.c_str(), std::to_string(out_pair.second),
std::to_string((layer_iter->second).size()), std::to_string(index),
std::to_string(node_iter->second->GetAllOutDataAnchorsSize()), std::to_string(index),
std::to_string(net_output_num)});
GELOGE(INTERNAL_ERROR,
"Add op %s to NetOutput faild, current node output index:%d should < %u. NetOutput"
"input_index:%d should < %u.",
out_pair.first.c_str(), out_pair.second, (layer_iter->second).size(), index,
out_pair.first.c_str(), out_pair.second, node_iter->second->GetAllOutDataAnchorsSize(), index,
net_output_num);
return INTERNAL_ERROR;
}

string top_name = layer_iter->second[out_pair.second];
auto top_node_iter = node_map.find(out_pair.first);
if (top_node_iter != node_map.end()) {
ge::GetParserContext().out_top_names.push_back(top_name);
GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str());
GELOGD("Start add edge for user out node: From %s:%d To %s:%d.", node_iter->second->GetName().c_str(),
out_pair.second, net_output_node->GetName().c_str(), index);
ge::OutDataAnchorPtr out_archor_ptr = node_iter->second->GetOutDataAnchor(out_pair.second);
GE_CHECK_NOTNULL(out_archor_ptr);
ge::InDataAnchorPtr in_archor_ptr = net_output_node->GetInDataAnchor(index);
GE_CHECK_NOTNULL(in_archor_ptr);
if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E11013", {"opname1", "opname2"},
{node_iter->second->GetName(), net_output_node->GetName()});
GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", node_iter->second->GetName().c_str(),
net_output_node->GetName().c_str());
return INTERNAL_ERROR;
}
++index;
} else {
@@ -1281,7 +1318,13 @@ Status CaffeModelParser::AddUserOutNodesTop() {
return SUCCESS;
}

Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_message) {
Status CaffeModelParser::AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph) {
GE_CHECK_NOTNULL(graph);
ge::NodePtr node = graph->FindFirstNodeMatchType(ge::NETOUTPUT);

GE_RETURN_WITH_LOG_IF_FALSE(node != nullptr, "Net without output, some phase failed in front.");

int32_t index = 0;
for (int32_t i = 0; i < proto_message.layer_size(); i++) {
const domi::caffe::LayerParameter &layer = proto_message.layer(i);

@@ -1291,7 +1334,6 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes

for (int i = 0; i < layer.top_size(); i++) {
string top = layer.top(i);
string top_origin = top;
// Handling 'inplace' scenarios
if (IsInplaceTopBlob(layer, top)) {
top = RemapTopNameByLayer(layer, top, i);
@@ -1313,9 +1355,21 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes
auto top_node_iter = node_map.find(layer.name());
GELOGI("output in top_blob: %s", layer.name().c_str());
if (top_node_iter != node_map.end()) {
ge::GetParserContext().out_top_names.push_back(top_origin);
ge::GetParserContext().default_out_nodes.push_back(std::make_pair(layer.name(), (int32_t)i));
GELOGI("The top of out node [%s] is [%s]", layer.name().c_str(), top_origin.c_str());
// add edge
// Output node, output index, input node, input index
GELOGD("Start add edge for out node: From %s:%d To %s:%d.", top_node_iter->second->GetName().c_str(), i,
node->GetName().c_str(), index);
ge::OutDataAnchorPtr out_archor_ptr = top_node_iter->second->GetOutDataAnchor(i);
GE_CHECK_NOTNULL(out_archor_ptr);
ge::InDataAnchorPtr in_archor_ptr = node->GetInDataAnchor(index);
GE_CHECK_NOTNULL(in_archor_ptr);
GE_IF_BOOL_EXEC(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS,
ErrorManager::GetInstance().ATCReportErrMessage(
"E11013", {"opname1", "opname2"}, {top_node_iter->second->GetName(), node->GetName()});
GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to to op[%s].",
top_node_iter->second->GetName().c_str(), node->GetName().c_str());
return INTERNAL_ERROR;);
index++;
}
}
}
@@ -1370,7 +1424,7 @@ Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) {

// validate opname
string mode = "^[A-Za-z0-9./_-]+$";
if (!ge::parser::ValidateStr(layer.name(), mode)) {
if (!ge::ValidateStr(layer.name(), mode)) {
ErrorManager::GetInstance().ATCReportErrMessage("E11018", {"opname"}, {layer.name()});
GELOGE(ge::FAILED,
"Parse caffe pbtxt validate op[%s] failed, opname can only contain "
@@ -1399,7 +1453,7 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co
domi::caffe::NetParameter proto_message;

// Get Caffe network model information
if (!ge::parser::ReadProtoFromMem(data, static_cast<int>(size), &proto_message)) {
if (!ge::ReadProtoFromMem(data, static_cast<int>(size), &proto_message)) {
GELOGE(FAILED, "read proto from text ret fail");
return FAILED;
}
@@ -1429,6 +1483,12 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co
CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true;
GELOGE(FAILED, "ParseInput ret fail."));

// build output layer
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(graph->GetName() + "_" + ge::NODE_NAME_NET_OUTPUT);
layer->set_type(ge::NETOUTPUT);

int32_t layer_count = proto_message.layer_size();
std::map<std::string, std::string> inplace_blob_name_remapping;
// Map of operator name and occurrence times
@@ -1444,7 +1504,7 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co
GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, "layer phase is train, skip this layer, name:%s, type:%s.",
layer.name().c_str(), layer.type().c_str());

CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && (input_data_flag == true)), has_error = true;
CHECK_FALSE_EXEC(!((layer.type() == ge::DATA_TYPE) && (input_data_flag == true)), has_error = true;
GELOGE(FAILED, "net %s has input and data layer simultaneously.", proto_message.name().c_str()));

// All layer names cannot be duplicate
@@ -1493,10 +1553,10 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co

GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail.");

if (!(ge::GetParserContext().user_out_nodes.empty())) {
GE_RETURN_WITH_LOG_IF_ERROR(AddUserOutNodesTop(), "Caffe parser add top_name for user out nodes failed.");
if (!(domi::GetContext().user_out_nodes.empty())) {
GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed.");
} else {
GE_RETURN_WITH_LOG_IF_ERROR(AddOutputTop(proto_message), "Caffe parser add top_name for output fail.");
GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail.");
}
GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail.");

@@ -1540,34 +1600,6 @@ void CaffeModelParser::SaveOrigionLayerTops(domi::caffe::LayerParameter &layer)
return;
}

Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &layer) {
string name = layer.name();
if (node_map.find(name) == node_map.end()) {
GELOGE(FAILED, "Node can not be found by layer name: %s", name.c_str());
return FAILED;
}

ge::NodePtr node = node_map[name];
GE_CHECK_NOTNULL(node);

if (node->GetType() == ge::parser::DATA) {
if (layer.top_size() != 1) {
GELOGE(FAILED, "Data layer[%s] top size must be 1, real size: %d", name.c_str(), layer.top_size());
return FAILED;
}

string top_name = layer.top(0);
auto data_top_names = ge::GetParserContext().data_top_names;
if (find(data_top_names.begin(), data_top_names.end(), top_name) != data_top_names.end()) {
GELOGE(FAILED, "Different data can not have same top name: %s.", top_name.c_str());
return FAILED;
}
ge::GetParserContext().data_top_names.push_back(top_name);
}

return SUCCESS;
}

Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &graph) {
bool has_error = false;
GE_CHECK_NOTNULL(model_path);
@@ -1626,20 +1658,25 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap
CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true;
GELOGE(FAILED, "ParseInput ret fail."));

// build output layer
domi::caffe::LayerParameter *layer = proto_message.add_layer();
GE_CHECK_NOTNULL(layer);
layer->set_name(graph->GetName() + "_" + ge::NODE_NAME_NET_OUTPUT);
layer->set_type(ge::NETOUTPUT);

int32_t layer_count = proto_message.layer_size();

if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) {
if (!domi::GetContext().user_out_nodes_top_vec.empty()) {
GELOGW("The out_put info has top_name items.");
GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputNodeTopInfo(proto_message),
"Caffe parser parse output node-top info failed.");
ge::GetParserContext().user_out_nodes_top_vec.clear();
domi::GetContext().user_out_nodes_top_vec.clear();
}

std::map<std::string, std::string> inplace_blob_name_remapping;
// Map of operator name and occurrence times
std::map<std::string, int32_t> layer_name_map;

GetParserContext().data_top_names.clear();
// <layername,paramnames>
std::map<std::string, std::vector<std::string>> layer_params_map;
// same param name set <paramnames,layernames>
@@ -1649,7 +1686,7 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap
GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, "layer phase is train, skip this layer, name:%s, type:%s.",
layer.name().c_str(), layer.type().c_str());

CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && (input_data_flag == true)), has_error = true;
CHECK_FALSE_EXEC(!((layer.type() == ge::DATA_TYPE) && (input_data_flag == true)), has_error = true;
GELOGE(FAILED, "net %s has input and data layer simultaneously.", proto_message.name().c_str()));

// All layer names cannot be duplicate
@@ -1686,11 +1723,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap

GE_RETURN_WITH_LOG_IF_ERROR(AddBlobsToMap(layer, inplace_blob_name_remapping),
"Caffe parser add blobs to map ret fail.");
if (SaveDataLayerTops(layer) != SUCCESS) {
GELOGE(FAILED, "Caffe parse: save data layer tops failed.");
return FAILED;
}
}

// Find a layer with the same param name and save it to graph
GE_RETURN_WITH_LOG_IF_ERROR(FindShareParamLayers(layer_params_map),
"Caffe parser find share param layers map ret fail.");
@@ -1702,12 +1736,13 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap

GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail.");

if (!(ge::GetParserContext().user_out_nodes.empty())) {
GE_RETURN_WITH_LOG_IF_ERROR(AddUserOutNodesTop(), "Caffe parser add top_name for user out nodes failed.");
if (!(domi::GetContext().user_out_nodes.empty())) {
GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed.");
} else {
GE_RETURN_WITH_LOG_IF_ERROR(AddOutputTop(proto_message), "Caffe parser add top_name for output fail.");
GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail.");
}
GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail.");
GE_RETURN_WITH_LOG_IF_ERROR(GetLeafNodeTops(graph), "Caffe parser get out nodes top names failed.");

auto nodes = graph->GetDirectNode();
GELOGI("graph node size = %zu.", nodes.size());
@@ -1800,7 +1835,7 @@ Status CaffeWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge::

// Resolve proto file to netparameter
NetParameter proto;
bool success = ge::parser::ReadProtoFromArray(reinterpret_cast<const char *>(data), static_cast<int>(size), &proto);
bool success = ge::ReadProtoFromArray(reinterpret_cast<const char *>(data), static_cast<int>(size), &proto);
if (!success) {
GELOGE(domi::PARSE_WEIGHTS_FAILED, "ReadProto from Memory fail");
return domi::PARSE_WEIGHTS_FAILED;
@@ -1848,7 +1883,7 @@ Status CaffeWeightsParser::Parse(const char *file, ge::ComputeGraphPtr &graph) {

GELOGD("caffe_proto_path:%s custom_proto_path:%s", caffe_proto_path.c_str(), custom_proto_path.c_str());
string fusion_proto_file;
string custom_proto_file = ge::parser::RealPath(custom_proto_path.c_str());
string custom_proto_file = ge::RealPath(custom_proto_path.c_str());
if (custom_proto_file.empty()) {
GELOGW("custom_proto_path:%s is not existed", custom_proto_path.c_str());
fusion_proto_file = caffe_proto_path;
@@ -1860,7 +1895,7 @@ Status CaffeWeightsParser::Parse(const char *file, ge::ComputeGraphPtr &graph) {
}
}

string fusion_proto_path = ge::parser::RealPath(fusion_proto_file.c_str());
string fusion_proto_path = ge::RealPath(fusion_proto_file.c_str());
GELOGI("Get fusion proto file[%s]-[%s].", fusion_proto_file.c_str(), fusion_proto_path.c_str());
if (fusion_proto_path.empty()) {
GELOGE(FAILED, "Fusion proto file path [%s]-[%s] is not real existed.", fusion_proto_file.c_str(),
@@ -1913,7 +1948,7 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con
google::protobuf::Message *message = proto->New();
GE_CHECK_NOTNULL(message);

if (!ge::parser::ReadProtoFromBinaryFile(weight_path, message)) {
if (!ge::ReadProtoFromBinaryFile(weight_path, message)) {
delete message;
message = nullptr;
ErrorManager::GetInstance().ATCReportErrMessage(
@@ -2303,7 +2338,7 @@ Status CaffeWeightsParser::CheckNodes(ge::ComputeGraphPtr &graph) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (const auto &in_anchor_ptr : node->GetAllInDataAnchors()) {
if (op_desc->GetType() == ge::parser::DATA || op_desc->GetType() == ge::parser::CONSTANT) {
if (op_desc->GetType() == ge::DATA || op_desc->GetType() == ge::CONSTANT) {
continue;
}
auto index = in_anchor_ptr->GetIdx();
@@ -2418,6 +2453,27 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::Co
return SUCCESS;
}

Status CaffeModelParser::GetLeafNodeTops(ge::ComputeGraphPtr &graph) {
  // Record the original caffe top (blob) name of every network output into the
  // parser context, by walking the input data anchors of the NetOutput node and
  // looking each peer node up in layer_tops_map_.
  // Returns SUCCESS even when a top name is missing (only a warning is logged),
  // PARAM_INVALID (via GE_CHECK_NOTNULL) on null graph/anchor/node.
  GE_CHECK_NOTNULL(graph);
  auto netout = graph->FindFirstNodeMatchType(ge::NETOUTPUT);
  GE_CHECK_NOTNULL(netout);
  for (const auto &in_anchor : netout->GetAllInDataAnchors()) {
    auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
    GE_CHECK_NOTNULL(peer_out_data_anchor);
    auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(peer_out_data_node);
    int idx = peer_out_data_anchor->GetIdx();
    string node_name = peer_out_data_node->GetName();
    auto layer_iter = layer_tops_map_.find(node_name);
    if (layer_iter == layer_tops_map_.end()) {
      GELOGW("The out node [%s] can not find its top.", node_name.c_str());
      continue;
    }
    // Guard the vector access: an anchor index outside the recorded tops of
    // this layer would otherwise be undefined behavior.
    if (idx < 0 || static_cast<size_t>(idx) >= layer_iter->second.size()) {
      GELOGW("The out node [%s] anchor index %d is out of range of its tops (size: %zu).", node_name.c_str(), idx,
             layer_iter->second.size());
      continue;
    }
    domi::GetContext().out_top_names.push_back(layer_iter->second[idx]);
    GELOGI("The top of out node [%s] is [%s]", node_name.c_str(), layer_iter->second[idx].c_str());
  }
  return SUCCESS;
}

// Stub: parsing directly from an in-memory protobuf Message is not implemented
// for the caffe model parser; the arguments are ignored and SUCCESS is always
// returned without modifying graph.
Status CaffeModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) {
  return SUCCESS;
}


+ 6
- 6
parser/caffe/caffe_parser.h View File

@@ -279,12 +279,12 @@ class CaffeModelParser : public domi::ModelParser {

/**
* @ingroup domi_omg
* @brief Add top name information to graph
* @param [in|out] proto_message
* @brief Add edge information to graph
* @param [in|out] graph graph for saving model information
* @return SUCCESS add successfully
* @return FAILED add failed
*/
Status AddOutputTop(const domi::caffe::NetParameter &proto_message);
Status AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph);

/**
* @ingroup domi_omg
@@ -324,7 +324,7 @@ class CaffeModelParser : public domi::ModelParser {
Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer,
const string &op_type);

Status AddUserOutNodesTop();
Status AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph);

std::string RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, int index);

@@ -335,6 +335,8 @@ class CaffeModelParser : public domi::ModelParser {
Status ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op,
std::shared_ptr<ge::OpParser> &op_parser);

Status GetLeafNodeTops(ge::ComputeGraphPtr &graph);

void SaveOrigionLayerTops(domi::caffe::LayerParameter &layer);

Status ReorderInput(domi::caffe::NetParameter &net);
@@ -343,8 +345,6 @@ class CaffeModelParser : public domi::ModelParser {

Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message);

Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer);

std::map<std::string, ge::NodePtr> node_map;

// key: blob name, value: layer name and index


+ 2
- 4
parser/caffe/caffe_reshape_parser.cc View File

@@ -17,16 +17,14 @@
#include "parser/caffe/caffe_reshape_parser.h"
#include <vector>
#include "common/debug/log.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/ge/ge_util.h"
#include "common/op/op_parser_util.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/utils/graph_utils.h"
#include "parser/common/op_parser_factory.h"
#include "framework/omg/parser/parser_types.h"
#include "proto/om.pb.h"

using namespace ge::parser;
using domi::CAFFE;

namespace ge {
@@ -109,7 +107,7 @@ Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) {
}

// construct GeTensorPtr
ge::GeTensorPtr constTensor = ge::parser::MakeShared<ge::GeTensor>();
ge::GeTensorPtr constTensor = ge::MakeShared<ge::GeTensor>();
GE_CHECK_NOTNULL(constTensor);
constTensor->SetTensorDesc(const_desc);



+ 4
- 6
parser/common/CMakeLists.txt View File

@@ -8,8 +8,7 @@ set(SRC_LIST
"parser_inner_ctx.cc"
"proto_file_parser.cc"
"acl_graph_parser_util.cc"
"tbe_plugin_loader.cc"
"model_saver.cc"
"../../../ge/common/model_saver.cc"
"../tensorflow/tensorflow_custom_parser_adapter.cc"
"../tensorflow/tensorflow_fusion_custom_parser_adapter.cc"
"../tensorflow/tensorflow_fusion_op_parser.cc"
@@ -20,10 +19,9 @@ set(SRC_LIST
"op_def/op_schema.cc"
"op_def/operator.cc"
"op_map.cc"
"parser_types.cc"
"pass_manager.cc"
"parser_fp16_t.cc"
"thread_pool.cc"
"../../../ge/graph/passes/pass_manager.cc"
"../../../ge/common/thread_pool.cc"
"parser_utils.cc"
)

############ libparser_common.so ############


+ 8
- 275
parser/common/acl_graph_parser_util.cc View File

@@ -18,37 +18,20 @@

#include <dlfcn.h>
#include <cstdlib>
#include <fstream>
#include <regex.h>
#include <ctime>

#include "common/string_util.h"
#include "common/types.h"
#include "common/debug/log.h"
#include "common/ge/tbe_plugin_manager.h"
#include "common/op/ge_op_utils.h"
#include "common/util.h"

#include "ge/ge_api_types.h"
#include "graph/opsproto_manager.h"
#include "omg/parser/parser_inner_ctx.h"
#include "tbe_plugin_loader.h"
#include "framework/common/debug/ge_log.h"
#include "parser/common/register_tbe.h"
#include "framework/omg/parser/parser_types.h"
#include "common/util/error_manager/error_manager.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"

using google::protobuf::io::CodedInputStream;
using google::protobuf::io::FileInputStream;
using google::protobuf::io::ZeroCopyInputStream;
using namespace ge::parser;

namespace {
/// The maximum length of the file.
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1
const int kMaxFileSizeLimit = INT_MAX;
const int kMaxBuffSize = 256;
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M

static string GetSoPath() {
Dl_info dl_info;
if (dladdr(reinterpret_cast<void *>(&GetSoPath), &dl_info) == 0) {
@@ -77,7 +60,7 @@ static void GetOpsProtoPath(string &opsproto_path) {
const char *path_env = std::getenv("ASCEND_OPP_PATH");
if (path_env != nullptr) {
string path = path_env;
string file_path = ge::parser::RealPath(path.c_str());
string file_path = ge::RealPath(path.c_str());
if (file_path.empty()) {
GELOGE(ge::FAILED, "File path %s is invalid.", path.c_str());
return;
@@ -125,7 +108,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node,
void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name) {
output_nodes_name.clear();
if (ge::GetParserContext().out_top_names.empty()) {
if (domi::GetContext().out_top_names.empty()) {
// tf process, no top name.
for (const auto output_node_info : output_nodes_info) {
std::string node_name = output_node_info.first->GetName();
@@ -159,7 +142,7 @@ domi::Status AclGrphParseUtil::SetDefaultOutputNode(ge::Graph &graph) {

AclGrphParseUtil::GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name);
compute_graph->SetGraphOutNodesInfo(output_nodes_info);
ge::GetParserContext().net_out_nodes = output_nodes_name;
domi::GetContext().net_out_nodes = output_nodes_name;
GELOGI("Set graph %s default output node success.", graph.GetName().c_str());
return SUCCESS;
}
@@ -211,7 +194,7 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s
}

// load custom op plugin
TBEPluginLoader::Instance().LoadPluginSo(options);
TBEPluginManager::Instance().LoadPluginSo(options);

// load and save custom op proto for prediction
(void)LoadOpsProtoLib();
@@ -239,254 +222,4 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s
GELOGT(TRACE_STOP, "AclParserInitialize finished");
return SUCCESS;
}
namespace parser {
// Resolve a path to its canonical absolute form via realpath(3).
// Returns an empty string when the pointer is null, the path is longer than
// PATH_MAX, or the path does not exist / is not accessible.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) {
  if (path == nullptr) {
    GELOGE(ge::FAILED, "path pointer is NULL.");
    return "";
  }
  if (strlen(path) >= PATH_MAX) {
    ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)});
    GELOGE(ge::FAILED, "Path[%s] len is too long, it must be less than %d", path, PATH_MAX);
    return "";
  }
  // realpath returns nullptr when the path does not exist or there is no
  // permission; in that case report the empty string to the caller.
  char canonical[PATH_MAX] = {0};
  if (realpath(path, canonical) == nullptr) {
    return "";
  }
  return std::string(canonical);
}

// Get file length
// Returns the file size in bytes, or -1 on any failure: empty path, path that
// cannot be canonicalized, open/stat error, zero-byte file, or a size above
// kMaxFileSizeLimit (2G-1). Failures are also reported to the error manager.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), return -1, "input_file path is null.");

  // Canonicalize first so a nonexistent or unreadable path fails early.
  std::string real_path = RealPath(input_file.c_str());

  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str());
  unsigned long long file_length = 0;
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
                                 ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"},
                                                                                 {input_file, strerror(errno)});
                                 return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno));

  // A zero-length model file is treated as invalid input, not as "empty model".
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0),
                                 ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file});
                                 return -1, "File[%s] size is 0, not valid.", input_file.c_str());

  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit,
                                 ErrorManager::GetInstance().ATCReportErrMessage(
                                     "E19016", {"filepath", "filesize", "maxlen"},
                                     {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)});
                                 return -1, "File[%s] size %lld is out of limit: %d.",
                                 input_file.c_str(), file_length, kMaxFileSizeLimit);
  return static_cast<long>(file_length);
}

// Current wall-clock time as microseconds since the Unix epoch.
// A gettimeofday failure is only logged; the (zero-initialized) timeval is
// still converted and returned.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() {
  struct timeval tv{};
  const int ret = gettimeofday(&tv, nullptr);
  GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret);
  const uint64_t kMicrosPerSecond = 1000000;  // seconds -> microseconds
  return static_cast<uint64_t>(tv.tv_sec) * kMicrosPerSecond + static_cast<uint64_t>(tv.tv_usec);
}

// Parse a protobuf message from an already-opened coded stream.
// Returns false when proto is null or parsing fails.
static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr,
                                 return false, "incorrect parameter. nullptr == proto");

  // Lift protobuf's default 64MB total-bytes limit to kProtoReadBytesLimit
  // (2G-1) so large models parse; warn once past kWarningThreshold (~1G).
  coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit, kWarningThreshold);
  return proto->ParseFromCodedStream(&coded_stream);
}

/** @ingroup domi_common
 * @brief Read all data from binary file
 * @param [in] file_name File path
 * @param [out] buffer The address of the output memory, which needs to be released by the caller (allocated with
 *              new[], zero-initialized)
 * @param [out] length Output memory size in bytes
 * @return false fail (null arguments, invalid path, open error, empty file, or allocation failure)
 * @return true success
 */
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer,
                                                                              int &length) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file is nullptr");
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), return false, "incorrect parameter. buffer is nullptr");

  std::string real_path = RealPath(file_name);
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "file path '%s' not valid", file_name);

  // Open at the end (std::ios::ate) so tellg() directly yields the file size.
  std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate);
  if (!file.is_open()) {
    GELOGE(ge::FAILED, "Read file %s failed.", file_name);
    return false;
  }

  length = static_cast<int>(file.tellg());

  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); return false, "file length <= 0");

  // Rewind before the bulk read.
  file.seekg(0, std::ios::beg);

  *buffer = new(std::nothrow) char[length]();
  GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close(), "new an object failed.");

  // NOTE(review): the result of file.read is not checked; a short read would
  // leave the tail of *buffer zero-filled — confirm this is acceptable.
  file.read(*buffer, length);
  file.close();
  return true;
}

// Read a protobuf message from a binary (serialized) file.
// Validates the path and file size, then parses via ReadProtoFromCodedInputStream.
// On parse failure the error is reported (E19005) and false is returned.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr),
                                 return false,
                                 "Input parameter file or proto is nullptr!");

  std::string real_path = RealPath(file);
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
                                 return false, "pb file path '%s' not valid", file);

  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");

  std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary);
  if (!fs.is_open()) {
    ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"});
    GELOGE(ge::FAILED, "Open real path[%s] failed.", file);
    return false;
  }

  google::protobuf::io::IstreamInputStream istream(&fs);
  google::protobuf::io::CodedInputStream coded_stream(&istream);
  const bool parse_ok = ReadProtoFromCodedInputStream(coded_stream, proto);
  fs.close();

  if (!parse_ok) {
    ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file});
    GELOGE(ge::FAILED, "Parse file[%s] failed.", file);
  }
  return parse_ok;
}

/// @brief Read a protobuf message from an in-memory serialized buffer.
/// @param [in] data Buffer holding the serialized message
/// @param [in] size Buffer size in bytes (must be positive)
/// @param [out] proto Message the parsed content is stored into
/// @return true on success; false on null arguments, non-positive size or parse failure
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) {
  // Fix: also reject negative sizes — a negative size would previously reach
  // CodedInputStream unchecked.
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size <= 0), return false,
                                 "incorrect parameter. proto is nullptr || data is nullptr || size is 0");

  // Fix: CodedInputStream accepts const uint8_t*, so the const_cast +
  // reinterpret_cast pair is unnecessary; static_cast from const void* suffices.
  google::protobuf::io::CodedInputStream coded_stream(static_cast<const uint8_t *>(data), size);
  return ReadProtoFromCodedInputStream(coded_stream, proto);
}

/// @ingroup domi_proto
/// @brief Read a protobuf message from a text-format (pbtxt) file.
/// @param [in] file Path of the proto text file
/// @param [out] message Message the parsed content is stored into
/// @return true on success; false on invalid path, unreadable file or parse error
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file,
                                                                        google::protobuf::Message *message) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), return false,
                                 "incorrect parameter. nullptr == file || nullptr == message");

  std::string real_path = RealPath(file);
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
                                 ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"},
                                                                                 {file, strerror(errno)});
                                 return false, "Path[%s]'s realpath is empty, errmsg[%s]", file,
                                 strerror(errno));

  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");

  std::ifstream fs(real_path.c_str(), std::ifstream::in);

  if (!fs.is_open()) {
    // NOTE: "realpth" is kept as-is — it is the error catalog's key name.
    ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file});
    // Fix: typo "orginal" -> "original" in the log text.
    GELOGE(ge::FAILED,
           "Fail to open proto file real path is '%s' when original file path is '%s'.", real_path.c_str(), file);
    return false;
  }

  google::protobuf::io::IstreamInputStream input(&fs);
  bool ret = google::protobuf::TextFormat::Parse(&input, message);
  GE_IF_BOOL_EXEC(!ret,
                  ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file});
                  // Fix: log with a real error code; the original passed the
                  // bool result (false == 0) as the GELOGE error code.
                  GELOGE(ge::FAILED, "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, "
                         "please check whether the file is a valid protobuf format file.", file));
  fs.close();

  return ret;
}

/// @brief Read a text-format protobuf message from an in-memory buffer.
/// @param [in] data Buffer holding text-format protobuf content
/// @param [in] size Buffer size in bytes (must be non-negative)
/// @param [out] message Message the parsed content is stored into
/// @return true on success; false on null arguments, negative size or parse failure
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size,
                                                                       google::protobuf::Message *message) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), return false,
                                 "incorrect parameter. data is nullptr || message is nullptr");
  // Fix: a negative size wrapped to a huge size_t in the string constructor
  // and read far past the buffer. Size 0 remains accepted (empty message).
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size < 0), return false, "incorrect parameter. size is negative");
  std::string str(data, static_cast<size_t>(size));
  std::istringstream fs(str);

  google::protobuf::io::IstreamInputStream input(&fs);
  bool ret = google::protobuf::TextFormat::Parse(&input, message);
  // Fix: log with a real error code; the original passed the bool result
  // (false == 0) as the GELOGE error code.
  GE_IF_BOOL_EXEC(
      !ret,
      GELOGE(ge::FAILED, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file."));

  return ret;
}

///
/// @brief get the Original Type of FrameworkOp
/// @param [in] node node to query; must not be null
/// @param [out] type receives the node's type; for FRAMEWORKOP nodes, the
///              original framework type stored in ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE
/// @return Status SUCCESS, or INTERNAL_ERROR when node/op desc is invalid or
///         the original-type attribute is missing
///
Status GetOriginalType(const ge::NodePtr &node, string &type) {
  GE_CHECK_NOTNULL(node);
  type = node->GetType();
  // Non-FrameworkOp nodes: the plain node type already is the original type.
  GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  if (!ret) {
    // NOTE(review): on this path `type` still holds FRAMEWORKOP, so the logged
    // value is the wrapper type, not the (missing) original type — confirm
    // whether the message should say "failed" explicitly.
    GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str());
    return INTERNAL_ERROR;
  }
  GELOGD("Get FrameWorkOp original type [%s]", type.c_str());
  return SUCCESS;
}

/// @brief Check whether a string matches a POSIX extended regular expression.
/// @param [in] str String to validate
/// @param [in] mode POSIX ERE pattern
/// @return true when the string matches — or when the pattern itself fails to
///         compile (lenient by design: a bad pattern never rejects input);
///         false only when regexec reports no match.
FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) {
  char ebuff[kMaxBuffSize] = {0};
  regex_t reg;
  const int cflags = REG_EXTENDED | REG_NOSUB;
  int ret = regcomp(&reg, mode.c_str(), cflags);
  if (ret != 0) {
    regerror(ret, &reg, ebuff, kMaxBuffSize);
    GELOGW("regcomp failed, reason: %s", ebuff);
    // Fix: do not call regfree() here. POSIX only defines regfree() for a
    // regex_t that regcomp() compiled successfully; freeing a failed compile
    // is undefined behavior.
    return true;
  }

  ret = regexec(&reg, str.c_str(), 0, nullptr, 0);
  if (ret != 0) {
    regerror(ret, &reg, ebuff, kMaxBuffSize);
    GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff);
    regfree(&reg);
    return false;
  }

  regfree(&reg);
  return true;
}

/// @brief Get the current local time as a compact string.
/// @return Time string in the format %Y%m%d%H%M%S (e.g. 20171122042550),
///         or "" if the time conversion fails.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() {
  std::time_t now = std::time(nullptr);
  // Fix: use the reentrant localtime_r instead of std::localtime, whose
  // shared static buffer makes concurrent callers race with each other.
  struct tm tm_buf = {};
  struct tm *ptm = localtime_r(&now, &tm_buf);
  if (ptm == nullptr) {
    GELOGE(ge::FAILED, "Localtime failed.");
    return "";
  }

  const int kTimeBufferLen = 32;
  char buffer[kTimeBufferLen + 1] = {0};
  // format: 20171122042550
  std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm);
  return std::string(buffer);
}
} // namespace parser
} // namespace ge

+ 2
- 192
parser/common/acl_graph_parser_util.h View File

@@ -19,17 +19,10 @@

#include <map>
#include <string>
#include <google/protobuf/text_format.h>
#include <sstream>

#include "framework/omg/parser/parser_types.h"
#include "register/register_error_codes.h"
#include "common/types.h"
#include "graph/utils/graph_utils.h"

namespace ge {

using google::protobuf::Message;

class AclGrphParseUtil {
public:
AclGrphParseUtil() {}
@@ -45,189 +38,6 @@ class AclGrphParseUtil {
void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
};

namespace parser {
///
/// @ingroup: domi_common
/// @brief: get length of file
/// @param [in] input_file: path of file
/// @return long: File length. If the file length fails to be obtained, the value -1 is returned.
///
extern long GetFileLength(const std::string &input_file);

///
/// @ingroup domi_common
/// @brief Absolute path for obtaining files.
/// @param [in] path of input file
/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned
///
std::string RealPath(const char *path);

///
/// @ingroup domi_common
/// @brief Obtains the absolute time (timestamp) of the current system.
/// @return Timestamp, in microseconds (US)
///
///
uint64_t GetCurrentTimestamp();

///
/// @ingroup domi_common
/// @brief Reads all data from a binary file.
/// @param [in] file_name path of file
/// @param [out] buffer Output memory address, which needs to be released by the caller.
/// @param [out] length Output memory size
/// @return false fail
/// @return true success
///
bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length);

///
/// @ingroup domi_common
/// @brief proto file in binary format
/// @param [in] file path of proto file
/// @param [out] proto memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromBinaryFile(const char *file, Message *proto);

///
/// @ingroup domi_common
/// @brief Reads the proto structure from an array.
/// @param [in] data proto data to be read
/// @param [in] size proto data size
/// @param [out] proto Memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromArray(const void *data, int size, Message *proto);

///
/// @ingroup domi_proto
/// @brief Reads the proto file in the text format.
/// @param [in] file path of proto file
/// @param [out] message Memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromText(const char *file, google::protobuf::Message *message);

bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message);

///
/// @brief get the Original Type of FrameworkOp
/// @param [in] node
/// @param [out] type
/// @return Status
///
domi::Status GetOriginalType(const ge::NodePtr &node, string &type);

///
/// @ingroup domi_common
/// @brief Check whether the file path meets the whitelist verification requirements.
/// @param [in] filePath file path
/// @param [out] result
///
bool ValidateStr(const std::string &filePath, const std::string &mode);

///
/// @ingroup domi_common
/// @brief Obtains the current time string.
/// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555
///
std::string CurrentTimeInStr();

template <typename T, typename... Args>
static inline std::shared_ptr<T> MakeShared(Args &&... args) {
typedef typename std::remove_const<T>::type T_nc;
std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...));
return ret;
}

/// @ingroup math_util
/// @brief check whether int64 multiplication can result in overflow
/// @param [in] a multiplicator
/// @param [in] b multiplicator
/// @return Status
inline domi::Status Int64MulCheckOverflow(int64_t a, int64_t b) {
if (a > 0) {
if (b > 0) {
if (a > (INT64_MAX / b)) {
return domi::FAILED;
}
} else {
if (b < (INT64_MIN / a)) {
return domi::FAILED;
}
}
} else {
if (b > 0) {
if (a < (INT64_MIN / b)) {
return domi::FAILED;
}
} else {
if ((a != 0) && (b < (INT64_MAX / a))) {
return domi::FAILED;
}
}
}
return domi::SUCCESS;
}

/// @ingroup math_util
/// @brief check whether int64 multiplication can result in overflow
/// @param [in] a multiplicator
/// @param [in] b multiplicator
/// @return Status
inline domi::Status CheckInt64Uint32MulOverflow(int64_t a, uint32_t b) {
if (a == 0 || b == 0) {
return domi::SUCCESS;
}
if (a > 0) {
if (a > (INT64_MAX / b)) {
return domi::FAILED;
}
} else {
if (a < (INT64_MIN / b)) {
return domi::FAILED;
}
}
return domi::SUCCESS;
}

#define PARSER_INT64_MULCHECK(a, b) \
if (ge::parser::Int64MulCheckOverflow((a), (b)) != SUCCESS) { \
GELOGW("Int64 %ld and %ld multiplication can result in overflow!", static_cast<int64_t>(a), \
static_cast<int64_t>(b)); \
return INTERNAL_ERROR; \
}

#define PARSER_INT64_UINT32_MULCHECK(a, b) \
if (ge::parser::CheckInt64Uint32MulOverflow((a), (b)) != SUCCESS) { \
GELOGW("Int64 %ld and UINT32 %u multiplication can result in overflow!", static_cast<uint32_t>(a), \
static_cast<uint32_t>(b)); \
return INTERNAL_ERROR; \
}
} // namespace parser
} // namespace ge

/*lint --emacro((773),GE_TIMESTAMP_START)*/
/*lint -esym(773,GE_TIMESTAMP_START)*/
#define PARSER_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::parser::GetCurrentTimestamp()

#define PARSER_TIMESTAMP_END(stage, stage_name) \
do { \
uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \
GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \
(endUsec_##stage - startUsec_##stage)); \
} while (0);

#define PARSER_TIMESTAMP_EVENT_END(stage, stage_name) \
do { \
uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \
(endUsec_##stage - startUsec_##stage)); \
} while (0);

#endif // ACL_GRAPH_PARSE_UTIL_
#endif // ACL_GRAPH_PARSE_UTIL_

+ 3
- 3
parser/common/convert/pb2json.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -115,7 +115,7 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr

case ProtobufFieldDescriptor::TYPE_FLOAT:
char str[kSignificantDigits];
if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1){
if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) {
json[field->name()] = str;
} else {
json[field->name()] = reflection->GetFloat(message, field);
@@ -148,7 +148,7 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) {
uint8_t *value = 0;
value = reinterpret_cast<uint8_t *>(&temp_value);
char str[kSignificantDigits];
if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1){
if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1) {
GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str());
continue;
}


+ 10
- 10
parser/common/convert/pb2json.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,8 +17,8 @@
// File: pb2json.h
// Description: This header file for protobuf message and json interconversion

#ifndef PARSER_COMMON_CONVERT_PB2JSON_H_
#define PARSER_COMMON_CONVERT_PB2JSON_H_
#ifndef GE_COMMON_CONVERT_PB2JSON_H_
#define GE_COMMON_CONVERT_PB2JSON_H_
#include <functional>
#include <memory>
#include <set>
@@ -38,12 +38,12 @@ using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor;
class Pb2Json {
public:
/**
* @ingroup domi_omg
* @brief Transfer protobuf object to JSON object
* @param [out] json Converted JSON object
* @return void success
* @author
*/
* @ingroup domi_omg
* @brief Transfer protobuf object to JSON object
* @param [out] json Converted JSON object
* @return void success
* @author
*/
static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json,
bool enum2str = false);

@@ -65,4 +65,4 @@ class Pb2Json {
};
} // namespace ge

#endif // PARSER_COMMON_CONVERT_PB2JSON_H_
#endif // GE_COMMON_CONVERT_PB2JSON_H_

+ 9
- 9
parser/common/data_op_parser.cc View File

@@ -18,7 +18,7 @@
#include <cstdlib>
#include "common/debug/log.h"
#include "common/op/ge_op_utils.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/math/math_util.h"
#include "common/util.h"
#include "graph/utils/type_utils.h"
#include "omg/omg.h"
@@ -36,7 +36,7 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> &
GE_RETURN_WITH_LOG_IF_FALSE(op != nullptr, "ParseShape failed for data_op, op is null");

const string &data_op_name = op->GetName();
GetParserContext().input_dims.emplace(data_op_name, shape);
domi::GetContext().input_dims.emplace(data_op_name, shape);

int64_t attr_type = 0;
ge::DataType data_type;
@@ -51,7 +51,7 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> &

ge::GeTensorDesc i_tensor_desc;
ge::GeTensorDesc o_tensor_desc;
const unordered_map<string, domiTensorFormat_t> &input_nodes_format_map = GetParserContext().input_nodes_format_map;
const unordered_map<string, domiTensorFormat_t> &input_nodes_format_map = domi::GetContext().input_nodes_format_map;
auto map_iter = input_nodes_format_map.find(data_op_name);
if (map_iter != input_nodes_format_map.end() && map_iter->second == domi::DOMI_TENSOR_NC1HWC0) {
// Input 5D NC1HWC0
@@ -80,9 +80,9 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> &
"Init ND Output Tensor failed");
}
}
i_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format));
i_tensor_desc.SetOriginFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format));
o_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format));
i_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(domi::GetContext().format));
i_tensor_desc.SetOriginFormat(ge::TypeUtils::DomiFormatToFormat(domi::GetContext().format));
o_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(domi::GetContext().format));
if (op->AddInputDesc(i_tensor_desc) != ge::GRAPH_SUCCESS) {
GELOGE(domi::INTERNAL_ERROR, "AddInputDesc failed for op %s.", op->GetName().c_str());
return FAILED;
@@ -128,10 +128,10 @@ Status DataOpParser::InitNDTensor(const vector<int64_t> &shape, ge::DataType dat
}
uint32_t type_size = 0;
if (ge::TypeUtils::GetDataTypeLength(data_type, type_size)) {
PARSER_INT64_UINT32_MULCHECK(size, type_size);
FMK_INT64_UINT32_MULCHECK(size, type_size);
size *= type_size;
} else {
PARSER_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float)));
FMK_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float)));
size *= sizeof(float);
}
ge::TensorUtils::SetSize(tensor_desc, size);
@@ -169,7 +169,7 @@ Status DataOpParser::InitInputTensor(const vector<int64_t> &shape, ge::GeTensorD
if (input.GetShape().GetDim(0) != -1) {
size = input.GetShape().GetShapeSize();
}
PARSER_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float)));
FMK_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float)));
ge::TensorUtils::SetSize(input, size * sizeof(float));

return SUCCESS;


+ 1
- 1
parser/common/data_op_parser.h View File

@@ -21,7 +21,7 @@
#include <vector>
#include "common/debug/log.h"
#include "common/op/attr_value_util.h"
#include "framework/omg/parser/parser_types.h"
#include "common/types.h"
#include "omg/omg_inner_types.h"
#include "proto/om.pb.h"



+ 0
- 155
parser/common/model_saver.cc View File

@@ -1,155 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <sys/stat.h>
#include <fcntl.h>

#include "parser/common/model_saver.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "common/util/error_manager/error_manager.h"
#include "mmpa/mmpa_api.h"

namespace {
const int kFileOpSuccess = 0;
} // namespace

namespace ge {
namespace parser {
const uint32_t kInteval = 2;

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFile(const char *file_path,
const Json &model) {
Status ret = SUCCESS;
if (file_path == nullptr || SUCCESS != CheckPath(file_path)) {
GELOGE(FAILED, "Check output file failed.");
return FAILED;
}
std::string model_str;
try {
model_str = model.dump(kInteval, ' ', false, Json::error_handler_t::ignore);
} catch (std::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()});
GELOGE(FAILED, "Failed to convert JSON to string, reason: %s.", e.what());
return FAILED;
} catch (...) {
ErrorManager::GetInstance().ATCReportErrMessage("E19008");
GELOGE(FAILED, "Failed to convert JSON to string.");
return FAILED;
}

char real_path[PATH_MAX] = {0};
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path) >= PATH_MAX, return FAILED, "file path is too long!");
if (realpath(file_path, real_path) == nullptr) {
GELOGI("File %s does not exist, it will be created.", file_path);
}

// Open file
mode_t mode = S_IRUSR | S_IWUSR;
int32_t fd = mmOpen2(real_path, O_RDWR | O_CREAT | O_TRUNC, mode);
if (fd == EN_ERROR || fd == EN_INVALID_PARAM) {
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file_path, strerror(errno)});
GELOGE(FAILED, "Open file[%s] failed. %s", file_path, strerror(errno));
return FAILED;
}
const char *model_char = model_str.c_str();
uint32_t len = static_cast<uint32_t>(model_str.length());
// Write data to file
mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len);
if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19004", {"file", "errmsg"}, {file_path, strerror(errno)});
// Need to both print the error info of mmWrite and mmClose, so return ret after mmClose
GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno));
ret = FAILED;
}
// Close file
if (mmClose(fd) != EN_OK) {
GELOGE(FAILED, "Close file failed.");
ret = FAILED;
}
return ret;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::CheckPath(const std::string &file_path) {
// Determine file path length
if (file_path.size() >= PATH_MAX) {
GELOGE(FAILED, "Path is too long:%zu", file_path.size());
return FAILED;
}

// Find the last separator
int path_split_pos = static_cast<int>(file_path.size() - 1);
for (; path_split_pos >= 0; path_split_pos--) {
if (file_path[path_split_pos] == '\\' || file_path[path_split_pos] == '/') {
break;
}
}

if (path_split_pos == 0) {
return SUCCESS;
}

// If there is a path before the file name, create the path
if (path_split_pos != -1) {
if (CreateDirectory(std::string(file_path).substr(0, static_cast<size_t>(path_split_pos))) != kFileOpSuccess) {
GELOGE(FAILED, "CreateDirectory failed, file path:%s.", file_path.c_str());
return FAILED;
}
}

return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int ModelSaver::CreateDirectory(const std::string &directory_path) {
GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty.");
auto dir_path_len = directory_path.length();
if (dir_path_len >= PATH_MAX) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19002", {"filepath", "size"}, {directory_path, std::to_string(PATH_MAX)});
GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), PATH_MAX);
return -1;
}
char tmp_dir_path[PATH_MAX] = {0};
for (size_t i = 0; i < dir_path_len; i++) {
tmp_dir_path[i] = directory_path[i];
if ((tmp_dir_path[i] == '\\') || (tmp_dir_path[i] == '/')) {
if (access(tmp_dir_path, F_OK) != 0) {
int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700
if (ret != 0) {
if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Can not create directory %s. Make sure the directory exists and writable.",
directory_path.c_str());
return ret;
}
}
}
}
}
int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700
if (ret != 0) {
if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
return ret;
}
}
return 0;
}

} // namespace parser
} // namespace ge

+ 0
- 55
parser/common/model_saver.h View File

@@ -1,55 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_FILE_SAVER_H_
#define PARSER_COMMON_FILE_SAVER_H_

#include <string>

#include "ge/ge_api_error_codes.h"
#include "register/register_types.h"
#include "nlohmann/json.hpp"

namespace ge {
namespace parser {
using Json = nlohmann::json;
using std::string;

class ModelSaver {
public:
/**
* @ingroup domi_common
* @brief Save JSON object to file
* @param [in] file_path File output path
* @param [in] model json object
* @return Status result
*/
static Status SaveJsonToFile(const char *file_path, const Json &model);

private:
///
/// @ingroup domi_common
/// @brief Check validity of the file path
/// @return Status result
///
static Status CheckPath(const string &file_path);

static int CreateDirectory(const std::string &directory_path);
};
} // namespace parser
} // namespace ge

#endif //PARSER_COMMON_FILE_SAVER_H_

+ 14
- 19
parser/common/module.mk View File

@@ -18,40 +18,35 @@ COMMON_LOCAL_SRC_FILES := \
register_tbe.cc \
parser_api.cc \
parser_inner_ctx.cc \
acl_graph_parser_util.cc\
proto_file_parser.cc \
acl_graph_parser_util.cc \
tbe_plugin_loader.cc \
model_saver.cc \
../../graph/passes/pass_manager.cc \
../../graph/common/omg_util.cc \
../tensorflow/tensorflow_custom_parser_adapter.cc \
../tensorflow/tensorflow_fusion_custom_parser_adapter.cc \
../tensorflow/tensorflow_fusion_op_parser.cc \
../tensorflow/tensorflow_util.cc \
convert/pb2json.cc \
../../common/convert/pb2json.cc \
op_def/ir_pb_converter.cc \
op_def/defs.cc \
op_def/op_schema.cc \
op_def/operator.cc \
op_map.cc \
parser_types.cc \
pass_manager.cc \
parser_fp16_t.cc \
thread_pool.cc \
parser_utils.cc \

FMK_COMMON_SRC_FILES := \
# ../../common/fmk_error_codes.cc \
../../common/auth/cipher.cc \
../../common/context/ctx.cc \
../../graph/passes/pass_manager.cc \
../../graph/common/omg_util.cc \
../../common/types.cc \
../../common/auth/file_saver.cc \
../../common/util.cc \
../../common/model_saver.cc \
../../common/fmk_error_codes.cc \
../../common/fp16_t.cc \
../../common/thread_pool.cc \
../../common/auth/file_saver.cc \
../../common/auth/cipher.cc \
../../common/context/ctx.cc \

LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES)
LOCAL_SRC_FILES += $(FMK_COMMON_SRC_FILES)
#LOCAL_SRC_FILES += $(FMK_COMMON_SRC_FILES)

LOCAL_C_INCLUDES := \
proto/om.proto \
@@ -73,10 +68,9 @@ LOCAL_C_INCLUDES := \
$(TOPDIR)inc/external/graph \
$(TOPDIR)inc/framework \
$(TOPDIR)inc/common/util \
$(TOPDIR)graphengine/ge \
$(TOPDIR)graphengine/ge/common \
$(TOPDIR)parser/parser \
$(TOPDIR)parser \
$(TOPDIR)framework/domi \
$(TOPDIR)framework/domi/common \
$(TOPDIR)framework/domi/parser \
$(TOPDIR)third_party/json/include \
$(TOPDIR)third_party/protobuf/include \
libc_sec/include \
@@ -90,6 +84,7 @@ LOCAL_SHARED_LIBRARIES := \
libc_sec \
liberror_manager \
libregister \
libge_common \

LOCAL_LDFLAGS := -lrt -ldl



+ 1
- 1
parser/common/op_def/constant_op.h View File

@@ -18,7 +18,7 @@
#ifndef DOMI_OP_CONSTANT_OP_H_
#define DOMI_OP_CONSTANT_OP_H_
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/types.h"

namespace ge {
class ConstantOperator : public ParserOperator {


+ 2
- 2
parser/common/op_def/ir_pb_converter.cc View File

@@ -23,7 +23,7 @@
#include "graph/ge_tensor.h"
#include "graph/buffer.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/types.h"
#include "framework/common/util.h"

namespace ge {
@@ -98,7 +98,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(co
GE_CHK_BOOL_RET_STATUS(op.GetSchema(), domi::PARAM_INVALID, "Op schema is null, op type: %s", op.GetType().c_str());
op_def->SetName(op.GetName());
op_def->SetType(op.GetType());
GE_IF_BOOL_EXEC(op.GetType() == ge::parser::YOLO, op_def->SetType(ge::parser::REGION));
GE_IF_BOOL_EXEC(op.GetType() == ge::YOLO, op_def->SetType(ge::REGION));

UpdateTensorForOpDesc(op, op_def);
GELOGD("Convert to op desc: name:%s, input size: %zu, output size:%zu", op_def->GetName().c_str(),


+ 1
- 1
parser/common/op_def/no_op_op.h View File

@@ -18,7 +18,7 @@
#ifndef DOMI_OP_NO_OP_OP_H_
#define DOMI_OP_NO_OP_OP_H_
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/types.h"

namespace ge {
class NoOpOperator : public ParserOperator {


+ 1
- 1
parser/common/op_def/ref_switch_op.h View File

@@ -18,7 +18,7 @@
#ifndef DOMI_OP_REF_SWITCH_H_
#define DOMI_OP_REF_SWITCH_H_
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/types.h"

namespace ge {
class RefSwitchOperator : public ParserOperator {


+ 1
- 1
parser/common/op_def/shape_n_op.cc View File

@@ -17,7 +17,7 @@
// AUTO GEN PLEASE DO NOT MODIFY IT
#include "common/op_def/shape_n_op.h"
#include "graph/debug/ge_attr_define.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/types.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY ShapeNOperator::ShapeNOperator() : ParserOperator("ShapeN") {}


+ 1
- 1
parser/common/op_def/shape_n_op.h View File

@@ -18,7 +18,7 @@
#ifndef DOMI_OP_SHAPE_N_OP_H_
#define DOMI_OP_SHAPE_N_OP_H_
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/types.h"

namespace ge {
class ShapeNOperator : public ParserOperator {


+ 1
- 1
parser/common/op_def/var_is_initialized_op_op.cc View File

@@ -20,7 +20,7 @@
#include <vector>

namespace ge {
VarIsInitializedOpOperator::VarIsInitializedOpOperator() : ParserOperator(ge::parser::VARISINITIALIZEDOP) {}
VarIsInitializedOpOperator::VarIsInitializedOpOperator() : ParserOperator(ge::VARISINITIALIZEDOP) {}

VarIsInitializedOpOperator::~VarIsInitializedOpOperator() {}



+ 1
- 1
parser/common/op_def/var_is_initialized_op_op.h View File

@@ -18,7 +18,7 @@
#ifndef DOMI_OP_VARISINITIALIZEDOP_H_
#define DOMI_OP_VARISINITIALIZEDOP_H_
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/types.h"

namespace ge {
class VarIsInitializedOpOperator : public ParserOperator {


+ 1
- 1
parser/common/op_def/variable_op.cc View File

@@ -19,7 +19,7 @@
#include "graph/debug/ge_attr_define.h"

namespace ge {
VariableOperator::VariableOperator() : ParserOperator(ge::parser::VARIABLE) {}
VariableOperator::VariableOperator() : ParserOperator(ge::VARIABLE) {}

VariableOperator::~VariableOperator() {}



+ 1
- 1
parser/common/op_def/variable_op.h View File

@@ -19,7 +19,7 @@
#define DOMI_OP_VARIABLE_H_
#include <vector>
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/types.h"

namespace ge {
class VariableOperator : public ParserOperator {


+ 1
- 3
parser/common/op_map.cc View File

@@ -20,13 +20,12 @@
#include <string>
#include <vector>

#include "framework/omg/parser/parser_types.h"
#include "framework/common/types.h"
#include "register/op_registry.h"

using std::map;
using std::string;
using std::vector;
using namespace ge::parser;

namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<std::string, std::string> caffe_op_map = {
@@ -98,7 +97,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY map<string, string> tensorflow_
{"VarHandleOp", VARHANDLEOP},
{"VarIsInitializedOp", VARISINITIALIZEDOP},
{"IsVariableInitialized", ISVARIABLEINITIALIZED},
{"ReadVariableOp", READVARIABLEOP},
{"Reshape", RESHAPE},
{"Squeeze", SQUEEZE},
{"NoOp", NOOP},


+ 5
- 5
parser/common/op_parser_factory.h View File

@@ -23,8 +23,8 @@
#include <mutex>
#include <string>
#include <vector>
#include "parser/common/acl_graph_parser_util.h"
#include "framework/omg/parser/parser_types.h"
#include "common/ge/ge_util.h"
#include "common/types.h"
#include "framework/common/debug/ge_log.h"
#include "omg/omg_inner_types.h"
#include "external/register/register.h"
@@ -162,7 +162,7 @@ class CustomParserAdapterRegistrar {
*/
#define REGISTER_OP_PARSER_CREATOR(framework, op_type, clazz) \
std::shared_ptr<OpParser> Creator_##framework##_##op_type##_Op_Parser() { \
std::shared_ptr<clazz> ptr = ge::parser::MakeShared<clazz>(); \
std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \
if (ptr == nullptr) { \
GELOGW("MakeShared failed, result is nullptr."); \
} \
@@ -173,7 +173,7 @@ class CustomParserAdapterRegistrar {

#define REGISTER_FUSION_OP_PARSER_CREATOR(framework, op_type, clazz) \
std::shared_ptr<OpParser> Creator_##framework##_##op_type##_Fusion_Op_Parser() { \
std::shared_ptr<clazz> ptr = ge::parser::MakeShared<clazz>(); \
std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \
if (ptr == nullptr) { \
GELOGW("MakeShared failed, result is nullptr."); \
} \
@@ -187,7 +187,7 @@ class CustomParserAdapterRegistrar {
/// @param [in] clazz CaffeCustomParserAdapter adaptation class
#define REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(framework, clazz) \
std::shared_ptr<OpParser> Creator_##framework##_Op_Parser_Adapter() { \
std::shared_ptr<clazz> ptr = ge::parser::MakeShared<clazz>(); \
std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \
if (ptr == nullptr) { \
GELOGW("MakeShared failed, result is nullptr."); \
} \


+ 3
- 3
parser/common/parser_api.cc View File

@@ -17,7 +17,7 @@
#include "framework/omg/parser/parser_api.h"
#include "common/debug/log.h"

#include "tbe_plugin_loader.h"
#include "common/ge/tbe_plugin_manager.h"
#include "framework/common/debug/ge_log.h"
#include "parser/common/register_tbe.h"
#include "framework/omg/parser/parser_inner_ctx.h"
@@ -36,7 +36,7 @@ Status ParserInitialize(const std::map<std::string, std::string> &options) {
}

// load custom op plugin
TBEPluginLoader::Instance().LoadPluginSo(options);
TBEPluginManager::Instance().LoadPluginSo(options);

std::vector<OpRegistrationData> registrationDatas = domi::OpRegistry::Instance()->registrationDatas;
GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size());
@@ -67,7 +67,7 @@ Status ParserFinalize() {
return SUCCESS;
}

GE_CHK_STATUS(TBEPluginLoader::Instance().Finalize());
GE_CHK_STATUS(TBEPluginManager::Instance().Finalize());
if (parser_initialized) {
parser_initialized = false;
}


+ 0
- 1270
parser/common/parser_fp16_t.cc
File diff suppressed because it is too large
View File


+ 0
- 653
parser/common/parser_fp16_t.h View File

@@ -1,653 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_FP16_T_H_
#define PARSER_COMMON_FP16_T_H_

#include <algorithm>
#include <cmath>
#include <cstdint>

namespace ge {
namespace parser {
/// @ingroup fp16 basic parameter
/// @brief named indexes for tensor dimensions 0..16
using DimIndex = enum {
kDim0 = 0,
kDim1,
kDim2,
kDim3,
kDim4,
kDim5,
kDim6,
kDim7,
kDim8,
kDim9,
kDim10,
kDim11,
kDim12,
kDim13,
kDim14,
kDim15,
kDim16,
};

/// @ingroup fp16 basic parameter
/// @brief named shift amounts and power-of-two magnitudes used by the fp16/fp32/fp64 conversion code
using BitShift = enum {
kBitShift2 = 2,
kBitShift3 = 3,
kBitShift4 = 4,
kBitShift5 = 5,
kBitShift6 = 6,
kBitShift7 = 7,
kBitShift8 = 8,
kBitShift9 = 9,
kBitShift10 = 10,
kBitShift11 = 11,
kBitShift12 = 12,
kBitShift13 = 13,
kBitShift14 = 14,
kBitShift15 = 15,
kBitShift16 = 16,
kBitShift20 = 20,
kBitShift24 = 24,
kBitShift27 = 27,
kBitShift28 = 28,
kBitShift31 = 31,
kBitShift32 = 32,
kBitShift36 = 36,
kBitShift40 = 40,
kBitShift44 = 44,
kBitShift48 = 48,
kBitShift52 = 52,
kBitShift56 = 56,
kBitShift59 = 59,
kBitShift60 = 60,
kBitShift63 = 63,
kBitShift64 = 64,
kBitShift128 = 128,
kBitShift255 = 255,
kBitShift256 = 256,
kBitShift512 = 512,
kBitShift768 = 768,
kBitShift784 = 784,
kBitShift1020 = 1020,
kBitShift1024 = 1024,
kBitShift3136 = 3136,
kBitShift4096 = 4096,
kBitShift6144 = 6144,
kBitShift10240 = 10240,
kBitShift65536 = 65536
};
/// @ingroup fp16 basic parameter
/// @brief fp16 exponent bias
constexpr uint16_t kFp16ExpBias = 15;
/// @ingroup fp16 basic parameter
/// @brief the exponent bit length of fp16 is 5
constexpr uint16_t kFp16ExpLen = 5;
/// @ingroup fp16 basic parameter
/// @brief the mantissa bit length of fp16 is 10
constexpr uint16_t kFp16ManLen = 10;
/// @ingroup fp16 basic parameter
/// @brief bit index of sign in fp16
constexpr uint16_t kFp16SignIndex = 15;
/// @ingroup fp16 basic parameter
/// @brief sign mask of fp16 (1 00000 00000 00000)
constexpr uint16_t kFp16SignMask = 0x8000;
/// @ingroup fp16 basic parameter
/// @brief exponent mask of fp16 ( 11111 00000 00000)
constexpr uint16_t kFp16ExpMask = 0x7C00;
/// @ingroup fp16 basic parameter
/// @brief mantissa mask of fp16 ( 11111 11111)
constexpr uint16_t kFp16ManMask = 0x03FF;
/// @ingroup fp16 basic parameter
/// @brief hide bit of mantissa of fp16( 1 00000 00000)
constexpr uint16_t kFp16ManHideBit = 0x0400;
/// @ingroup fp16 basic parameter
/// @brief maximum value (0111 1011 1111 1111)
constexpr uint16_t kFp16Max = 0x7BFF;
/// @ingroup fp16 basic parameter
/// @brief minimum value (1111 1011 1111 1111)
constexpr uint16_t kFp16Min = 0xFBFF;
/// @ingroup fp16 basic parameter
/// @brief absolute maximum value (0111 1111 1111 1111)
constexpr uint16_t kFp16AbsMax = 0x7FFF;
/// @ingroup fp16 basic parameter
/// @brief maximum exponent value of fp16 is 15(11111)
constexpr uint16_t kFp16MaxExp = 0x001F;
/// @ingroup fp16 basic parameter
/// @brief maximum valid exponent value of fp16 is 14(11110)
constexpr uint16_t kFp16MaxValidExp = 0x001E;
/// @ingroup fp16 basic parameter
/// @brief maximum mantissa value of fp16(11111 11111)
constexpr uint16_t kFp16MaxMan = 0x03FF;
/// @ingroup fp16 basic parameter
/// @brief absolute minimum normal value of fp16
/// (E=1,M=0 D=2^(-14)=0.00006103515625)
/// NOTE(review): 1.0f / (2 << 14) is a float < 1, so initializing a uint16_t
/// with it truncates to 0 — this constant is always 0. Confirm intent; the
/// min-normal fp16 *bit pattern* would be 0x0400.
constexpr uint16_t kFp16MinNormal = 1.0f / (2 << 14);
/// @ingroup fp16 basic operator
/// @brief get sign of fp16
#define FP16_EXTRAC_SIGN(x) (((x) >> 15) & 1)
/// @ingroup fp16 basic operator
/// @brief get exponent of fp16
#define FP16_EXTRAC_EXP(x) (((x) >> 10) & kFp16MaxExp)
/// @ingroup fp16 basic operator
/// @brief get mantissa of fp16 (hide bit restored when the exponent is non-zero)
#define FP16_EXTRAC_MAN(x) ((((x) >> 0) & 0x3FF) | (((((x) >> 10) & 0x1F) > 0 ? 1 : 0) * 0x400))
/// @ingroup fp16 basic operator
/// @brief constructor of fp16 from sign exponent and mantissa
#define FP16_CONSTRUCTOR(s, e, m) (((s) << kFp16SignIndex) | ((e) << kFp16ManLen) | ((m)&kFp16MaxMan))
/// @ingroup fp16 special value judgment
/// @brief whether a fp16 is zero
#define FP16_IS_ZERO(x) (((x)&kFp16AbsMax) == 0)
/// @ingroup fp16 special value judgment
/// @brief whether a fp16 is a denormalized value
#define FP16_IS_DENORM(x) ((((x)&kFp16ExpMask) == 0))
/// @ingroup fp16 special value judgment
/// @brief whether a fp16 is infinite
#define FP16_IS_INF(x) (((x)&kFp16AbsMax) == kFp16ExpMask)
/// @ingroup fp16 special value judgment
/// @brief whether a fp16 is NaN
/// NOTE(review): the macro argument x is not parenthesized here (unlike the
/// macros above); passing a compound expression would mis-evaluate.
#define FP16_IS_NAN(x) (((x & kFp16ExpMask) == kFp16ExpMask) && (x & kFp16ManMask))
/// @ingroup fp16 special value judgment
/// @brief whether a fp16 is invalid (exponent all ones: Inf or NaN)
#define FP16_IS_INVALID(x) ((x & kFp16ExpMask) == kFp16ExpMask)
/// @ingroup fp32 basic parameter
/// @brief fp32 exponent bias
constexpr uint16_t kFp32ExpBias = 127;
/// @ingroup fp32 basic parameter
/// @brief the exponent bit length of float/fp32 is 8
constexpr uint16_t kFp32ExpLen = 8;
/// @ingroup fp32 basic parameter
/// @brief the mantissa bit length of float/fp32 is 23
constexpr uint16_t kFp32ManLen = 23;
/// @ingroup fp32 basic parameter
/// @brief bit index of sign in float/fp32
constexpr uint16_t kFp32SignIndex = 31;
/// @ingroup fp32 basic parameter
/// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000)
constexpr uint32_t kFp32SignMask = 0x80000000u;
/// @ingroup fp32 basic parameter
/// @brief exponent mask of fp32 ( 1111 1111 0000 0000 0000 0000 000)
constexpr uint32_t kFp32ExpMask = 0x7F800000u;
/// @ingroup fp32 basic parameter
/// @brief mantissa mask of fp32 ( 1111 1111 1111 1111 111)
constexpr uint32_t kFp32ManMask = 0x007FFFFFu;
/// @ingroup fp32 basic parameter
/// @brief hide bit of mantissa of fp32 ( 1 0000 0000 0000 0000 000)
constexpr uint32_t kFp32ManHideBit = 0x00800000u;
/// @ingroup fp32 basic parameter
/// @brief absolute maximum value (0 1111 1111 1111 1111 1111 1111 111)
constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu;
/// @ingroup fp32 basic parameter
/// @brief maximum exponent value of fp32 is 255(1111 1111)
constexpr uint32_t kFp32MaxExp = 0xFF;
/// @ingroup fp32 basic parameter
/// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111)
constexpr uint32_t kFp32MaxMan = 0x7FFFFF;
/// @ingroup fp32 special value judgment
/// @brief whether a fp32 is NaN
#define FP32_IS_NAN(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (x & kFp32ManMask))
/// @ingroup fp32 special value judgment
/// @brief whether a fp32 is infinite
#define FP32_IS_INF(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (!(x & kFp32ManMask)))
/// @ingroup fp32 special value judgment
/// @brief whether a fp32 is a denormalized value
#define FP32_IS_DENORM(x) ((((x)&kFp32ExpMask) == 0))
/// @ingroup fp32 basic operator
/// @brief get sign of fp32
#define FP32_EXTRAC_SIGN(x) (((x) >> kFp32SignIndex) & 1)
/// @ingroup fp32 basic operator
/// @brief get exponent of fp32
#define FP32_EXTRAC_EXP(x) (((x)&kFp32ExpMask) >> kFp32ManLen)
/// @ingroup fp32 basic operator
/// @brief get mantissa of fp32 (hide bit restored when the exponent is non-zero)
#define FP32_EXTRAC_MAN(x) (((x)&kFp32ManMask) | (((((x) >> kFp32ManLen) & kFp32MaxExp) > 0 ? 1 : 0) * kFp32ManHideBit))
/// @ingroup fp32 basic operator
/// @brief constructor of fp32 from sign exponent and mantissa
#define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m)&kFp32MaxMan))
/// @ingroup fp64 basic parameter
/// @brief fp64 exponent bias
constexpr uint16_t kFp64ExpBias = 1023;
/// @ingroup fp64 basic parameter
/// @brief the exponent bit length of double/fp64 is 11
constexpr uint16_t kFp64ExpLen = 11;
/// @ingroup fp64 basic parameter
/// @brief the mantissa bit length of double/fp64 is 52
constexpr uint16_t kFp64ManLen = 52;
/// @ingroup fp64 basic parameter
/// @brief bit index of sign in double/fp64 is 63
constexpr uint16_t kFp64SignIndex = 63;
/// @ingroup fp64 basic parameter
/// @brief sign mask of fp64 (1 followed by 63 zero bits)
constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu;
/// @ingroup fp64 basic parameter
/// @brief exponent mask of fp64 (0, then 11 one bits, then 52 zero bits)
constexpr uint64_t kFp64ExpMask = 0x7FF0000000000000LLu;
/// @ingroup fp64 basic parameter
/// @brief mantissa mask of fp64 (low 52 bits all ones)
constexpr uint64_t kFp64ManMask = 0x000FFFFFFFFFFFFFLLu;
/// @ingroup fp64 basic parameter
/// @brief hide bit of mantissa of fp64 (1 followed by 52 zero bits)
constexpr uint64_t kFp64ManHideBit = 0x0010000000000000LLu;
/// @ingroup fp64 basic parameter
/// @brief absolute maximum value (0 followed by 63 one bits)
constexpr uint64_t kFp64AbsMax = 0x7FFFFFFFFFFFFFFFLLu;
/// @ingroup fp64 basic parameter
/// @brief maximum exponent value of fp64 is 2047(1 11111 11111)
constexpr uint64_t kFp64MaxExp = 0x07FF;
/// @ingroup fp64 basic parameter
/// @brief maximum mantissa value of fp64 (52 one bits)
constexpr uint64_t kFp64MaxMan = 0xFFFFFFFFFFFLLu;
/// @ingroup fp64 special value judgment
/// @brief whether a fp64 is NaN
#define FP64_IS_NAN(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (x & kFp64ManMask))
/// @ingroup fp64 special value judgment
/// @brief whether a fp64 is infinite
#define FP64_IS_INF(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (!(x & kFp64ManMask)))
/// @ingroup integer special value judgment
/// @brief maximum positive value of int8_t (0111 1111)
constexpr int8_t kInt8Max = 0x7F;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 8 bits length (1111 1111)
constexpr uint8_t kBitLen8Max = 0xFF;
/// @ingroup integer special value judgment
/// @brief maximum positive value of int16_t (0111 1111 1111 1111)
constexpr int16_t kInt16Max = 0x7FFF;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 16 bits length (1111 1111 1111 1111)
constexpr uint16_t kBitLen16Max = 0xFFFF;
/// @ingroup integer special value judgment
/// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111)
constexpr int32_t kInt32Max = 0x7FFFFFFFu;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111)
constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu;
/// @ingroup integer special value judgment
/// @brief maximum positive value of int64_t
/// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111)
constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFFu;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 64 bits length
/// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111)
constexpr uint64_t kBitLen64Max = 0xFFFFFFFFFFFFFFFFu;

/// @ingroup fp16_t enum
/// @brief rounding mode applied to the last preserved bit
enum TagFp16RoundMode {
kRoundToNearest = 0, // < round to nearest even
kRoundByTruncated, // < round by truncated
kRoundModeReserved,
};

/// @ingroup fp16_t
/// @brief Half precision float
/// bit15: 1 bit SIGN +---+-----+------------+
/// bit14-10: 5 bit EXP | S |EEEEE|MM MMMM MMMM|
/// bit0-9: 10bit MAN +---+-----+------------+
using fp16_t = struct TagFp16 {
// raw IEEE-754 binary16 bit pattern (sign | exponent | mantissa)
uint16_t val;

public:
/// @ingroup fp16_t constructor
/// @brief Constructor without any param(default constructor); initializes to +0.0
TagFp16(void) { val = 0x0u; }

/// @ingroup fp16_t constructor
/// @brief Constructor with an uint16_t value
/// NOTE(review): stores ui_val as the raw bit pattern, whereas
/// operator=(const uint16_t&) is documented below as a value conversion —
/// confirm callers do not confuse the two.
TagFp16(const uint16_t &ui_val) : val(ui_val) {}

/// @ingroup fp16_t constructor
/// @brief Constructor with a fp16_t object(copy constructor)
TagFp16(const TagFp16 &fp) : val(fp.val) {}

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be added
/// @brief Override addition operator to perform fp16_t addition
/// @return Return fp16_t result of adding this and fp
TagFp16 operator+(const TagFp16 fp);

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be subtracted
/// @brief Override subtraction operator to perform fp16_t subtraction
/// @return Return fp16_t result of subtracting fp from this
TagFp16 operator-(const TagFp16 fp);

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be multiplied
/// @brief Override multiplication operator to perform fp16_t multiplication
/// @return Return fp16_t result of multiplying this and fp
TagFp16 operator*(const TagFp16 fp);

/// @ingroup fp16_t math operator divided
/// @param [in] fp fp16_t object to be divided
/// @brief Override division operator to perform fp16_t division
/// @return Return fp16_t result of dividing this by fp
TagFp16 operator/(const TagFp16 fp);

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be added
/// @brief Override compound addition operator to perform fp16_t addition in place
/// @return Return fp16_t result of adding this and fp
TagFp16 operator+=(const TagFp16 fp);

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be subtracted
/// @brief Override compound subtraction operator to perform fp16_t subtraction in place
/// @return Return fp16_t result of subtracting fp from this
TagFp16 operator-=(const TagFp16 fp);

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be multiplied
/// @brief Override compound multiplication operator to perform fp16_t multiplication in place
/// @return Return fp16_t result of multiplying this and fp
TagFp16 operator*=(const TagFp16 fp);

/// @ingroup fp16_t math operator divided
/// @param [in] fp fp16_t object to be divided
/// @brief Override compound division operator to perform fp16_t division in place
/// @return Return fp16_t result of dividing this by fp
TagFp16 operator/=(const TagFp16 fp);

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to perform fp16_t if-equal comparison
/// @return Return boolean result of if-equal comparison of this and fp.
bool operator==(const TagFp16 &fp) const;

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to perform fp16_t not-equal comparison
/// @return Return boolean result of not-equal comparison of this and fp.
bool operator!=(const TagFp16 &fp) const;

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to perform fp16_t greater-than comparison
/// @return Return boolean result of greater-than comparison of this and fp.
bool operator>(const TagFp16 &fp) const;

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to perform fp16_t greater-equal comparison
/// @return Return boolean result of greater-equal comparison of this and fp.
bool operator>=(const TagFp16 &fp) const;

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to perform fp16_t less-than comparison
/// @return Return boolean result of less-than comparison of this and fp.
bool operator<(const TagFp16 &fp) const;

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to perform fp16_t less-equal comparison
/// @return Return boolean result of less-equal comparison of this and fp.
bool operator<=(const TagFp16 &fp) const;

/// @ingroup fp16_t math evaluation operator
/// @param [in] fp fp16_t object to be copied to this fp16_t
/// @brief Override assignment operator to copy fp16_t to this fp16_t
/// @return Return fp16_t result from fp
TagFp16 &operator=(const TagFp16 &fp);

/// @ingroup fp16_t math evaluation operator
/// @param [in] f_val float object to be converted to fp16_t
/// @brief Override assignment operator to convert float to fp16_t
/// @return Return fp16_t result from f_val
TagFp16 &operator=(const float &f_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] d_val double object to be converted to fp16_t
/// @brief Override assignment operator to convert double to fp16_t
/// @return Return fp16_t result from d_val
TagFp16 &operator=(const double &d_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] i_val int8_t value to be converted to fp16_t
/// @brief Override assignment operator to convert int8_t to fp16_t
/// @return Return fp16_t result from i_val
TagFp16 &operator=(const int8_t &i_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] ui_val uint8_t object to be converted to fp16_t
/// @brief Override assignment operator to convert uint8_t to fp16_t
/// @return Return fp16_t result from ui_val
TagFp16 &operator=(const uint8_t &ui_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] i_val int16_t object to be converted to fp16_t
/// @brief Override assignment operator to convert int16_t to fp16_t
/// @return Return fp16_t result from i_val
TagFp16 &operator=(const int16_t &i_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] ui_val uint16_t object to be converted to fp16_t
/// @brief Override assignment operator to convert uint16_t to fp16_t
/// @return Return fp16_t result from ui_val
TagFp16 &operator=(const uint16_t &ui_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] i_val int32_t object to be converted to fp16_t
/// @brief Override assignment operator to convert int32_t to fp16_t
/// @return Return fp16_t result from i_val
TagFp16 &operator=(const int32_t &i_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] ui_val uint32_t object to be converted to fp16_t
/// @brief Override assignment operator to convert uint32_t to fp16_t
/// @return Return fp16_t result from ui_val
TagFp16 &operator=(const uint32_t &ui_val);

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to float/fp32
/// @return Return float/fp32 value of fp16_t
operator float() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to double/fp64
/// @return Return double/fp64 value of fp16_t
operator double() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to int8_t
/// @return Return int8_t value of fp16_t
operator int8_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to uint8_t
/// @return Return uint8_t value of fp16_t
operator uint8_t() const;

/// @ingroup fp16_t conversion
/// @brief Override convert operator to convert fp16_t to int16_t
/// @return Return int16_t value of fp16_t
operator int16_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to uint16_t
/// @return Return uint16_t value of fp16_t
operator uint16_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to int32_t
/// @return Return int32_t value of fp16_t
operator int32_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to uint32_t
/// @return Return uint32_t value of fp16_t
operator uint32_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to int64_t
/// @return Return int64_t value of fp16_t
operator int64_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to uint64_t
/// @return Return uint64_t value of fp16_t
operator uint64_t() const;

/// @ingroup fp16_t judgment method
/// @brief whether this fp16_t is infinite
/// @return Returns 1:+INF -1:-INF 0:not INF
int IsInf();

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to float/fp32
/// @return Return float/fp32 value of fp16_t
float ToFloat() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to double/fp64
/// @return Return double/fp64 value of fp16_t
double ToDouble() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to int8_t
/// @return Return int8_t value of fp16_t
int8_t ToInt8() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to uint8_t
/// @return Return uint8_t value of fp16_t
uint8_t ToUInt8() const;

/// @ingroup fp16_t conversion
/// @brief Convert fp16_t to int16_t
/// @return Return int16_t value of fp16_t
int16_t ToInt16() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to uint16_t
/// @return Return uint16_t value of fp16_t
uint16_t ToUInt16() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to int32_t
/// @return Return int32_t value of fp16_t
int32_t ToInt32() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to uint32_t
/// @return Return uint32_t value of fp16_t
uint32_t ToUInt32() const;
};

/// @ingroup fp16_t public method
/// @param [in] val fp16 bit pattern to decompose
/// @param [out] s sign of fp16_t object
/// @param [out] e exponent of fp16_t object
/// @param [out] m mantissa of fp16_t object
/// @brief Extract the sign, exponent and mantissa of a fp16_t object
void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m);

/// @ingroup fp16_t public method
/// @param [in] negative whether the represented number is negative
/// @param [in|out] man mantissa to negate in place
/// @brief Replace man with its two's complement (bitwise NOT plus one) when
///        negative is set; leave it untouched otherwise.
template<typename T>
void ReverseMan(bool negative, T &man) {
  if (!negative) {
    return;
  }
  man = static_cast<T>(~man) + 1;
}

/// @ingroup fp16_t public method
/// @param [in] e_a exponent of one fp16_t/float number
/// @param [in] m_a mantissa of one fp16_t/float number
/// @param [in] e_b exponent of another fp16_t/float number
/// @param [in] m_b mantissa of another fp16_t/float number
/// @brief Pick the mantissa belonging to the smaller (or tied) exponent —
///        the one that must be shifted right for alignment.
/// @return mantissa whose exponent is not greater than the other one
template<typename T>
T MinMan(const int16_t &e_a, T &m_a, const int16_t &e_b, T &m_b) {
  if (e_a > e_b) {
    return m_b;
  }
  return m_a;
}

/// @ingroup fp16_t public method
/// @param [in] man mantissa to be shifted
/// @param [in] shift number of single-bit right shifts to apply
/// @brief Shift man right by shift bits while keeping the top bit sticky:
///        once the most significant bit is set it stays set, so that
///        information is not lost during exponent alignment.
/// @return the shifted mantissa
template<typename T>
T RightShift(T man, int16_t shift) {
  const int bit_count = static_cast<int>(sizeof(T) * 8);  // 8 bits per byte
  const T top_bit = static_cast<T>(static_cast<T>(1u) << static_cast<unsigned int>(bit_count - 1));
  while (shift > 0) {
    man = static_cast<T>((man & top_bit) | (man >> 1));
    --shift;
  }
  return man;
}

/// @ingroup fp16_t public method
/// @param [in] e_a exponent of one temp fp16_t number
/// @param [in] m_a mantissa of one temp fp16_t number
/// @param [in] e_b exponent of another temp fp16_t number
/// @param [in] m_b mantissa of another temp fp16_t number
/// @brief Align the mantissa carrying the smaller exponent (top-bit-sticky
///        right shift, identical to RightShift) and add it to the other
///        mantissa. T support types: uint16_t/uint32_t/uint64_t
/// @return mantissa sum after exponent alignment
template<typename T>
T GetManSum(int16_t e_a, const T &m_a, int16_t e_b, const T &m_b) {
  if (e_a == e_b) {
    return m_a + m_b;
  }
  const int16_t diff = static_cast<int16_t>(std::abs(e_a - e_b));
  // Sticky right shift of the smaller-exponent mantissa: the most
  // significant bit, once set, is preserved through every shift step.
  const int bit_count = static_cast<int>(sizeof(T) * 8);
  const T top_bit = static_cast<T>(static_cast<T>(1u) << static_cast<unsigned int>(bit_count - 1));
  T aligned = (e_a > e_b) ? m_b : m_a;
  for (int16_t i = 0; i < diff; ++i) {
    aligned = static_cast<T>((aligned & top_bit) | (aligned >> 1));
  }
  return (e_a > e_b) ? (m_a + aligned) : (aligned + m_b);
}

/// @ingroup fp16_t public method
/// @param [in] bit0 whether the last preserved bit is 1 before rounding
/// @param [in] bit1 whether the highest abbreviated (dropped) bit is 1
/// @param [in] bitLeft whether any abbreviated bit below the highest one is 1
/// @param [in] man mantissa of a fp16_t or float number, support types: uint16_t/uint32_t/uint64_t
/// @param [in] shift number of abbreviated bits to drop
/// @brief Drop shift bits from man and round to nearest even: round up only
///        when the highest dropped bit is set and either a lower dropped bit
///        or the kept least-significant bit is also set.
/// @return the rounded mantissa
template<typename T>
T ManRoundToNearest(bool bit0, bool bit1, bool bitLeft, T man, uint16_t shift = 0) {
  const T round_up = (bit1 && (bitLeft || bit0)) ? static_cast<T>(1) : static_cast<T>(0);
  return static_cast<T>((man >> shift) + round_up);
}

/// @ingroup fp16_t public method
/// @param [in] man mantissa of a float number, support types: uint16_t/uint32_t/uint64_t
/// @brief Count the bits needed to represent man: the 1-based position of the
///        highest set bit, or 0 when man is zero.
/// @return bit length of man
template<typename T>
int16_t GetManBitLength(T man) {
  int16_t length = 0;
  for (; man != 0; man >>= 1) {
    ++length;
  }
  return length;
}
} // namespace parser
} // namespace ge
#endif  // PARSER_COMMON_FP16_T_H_

+ 0
- 494
parser/common/parser_types.cc View File

@@ -1,494 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "framework/omg/parser/parser_types.h"


namespace ge{
namespace parser {
const char *DATA = "Data";
const char *AIPPDATA = "AippData";
const char *CONVOLUTION = "Convolution";
const char *CORRELATION = "Correlation";
const char *CORRELATIONV2 = "Correlation_V2";
const char *DECONVOLUTION = "Deconvolution";
const char *POOLING = "Pooling";
const char *ELTWISE = "Eltwise";
const char *RELU = "ReLU";
const char *RELU6 = "ReLU6";
const char *SIGMOID = "Sigmoid";
const char *ABSVAL = "AbsVal";
const char *TANH = "TanH";
const char *PRELU = "PReLU";
const char *BATCHNORM = "BatchNorm";
const char *FUSIONBATCHNORM = "FusionBatchNorm";
const char *SCALE = "Scale";
const char *FULL_CONNECTION = "FullConnection";
const char *SOFTMAX = "Softmax";
const char *PLUS = "Plus";
const char *ACTIVATION = "Activation";
const char *FLATTEN = "Flatten";
const char *ADD = "Add";
const char *SUB = "Sub";
const char *MUL = "Mul";
const char *MATMUL = "MatMul";
const char *RSQRT = "Rsqrt";
const char *BIASADD = "BiasAdd";
const char *RESHAPE = "Reshape";
const char *REFORMAT = "ReFormat";
const char *DEPCONVOLUTION = "ConvolutionDepthwise";
const char *DROPOUT = "Dropout";
const char *DROPOUTGENMASK = "DropOutGenMask";
const char *DROPOUTDOMASK = "DropOutDoMask";
const char *CONCAT = "Concat";
const char *ROIPOOLING = "ROIPooling";
const char *PROPOSAL = "Proposal";
const char *FSRDETECTIONOUTPUT = "FSRDetectionOutput";
const char *DETECTIONPOSTPROCESS = "Detectpostprocess";
const char *LRN = "LRN";
const char *TRANSDATA = "TransData";
const char *PERMUTE = "Permute";
const char *SSDNORMALIZE = "SSDNormalize";
const char *SSDPRIORBOX = "SSDPriorBox";
const char *NETOUTPUT = "NetOutput";
const char *SSDDETECTIONOUTPUT = "SSDDetectionOutput";
const char *REFINEDETDETECTIONOUTPUT = "RefinedetDetectionOutput";
const char *CHANNELAXPY = "ChannelAxpy";
const char *PSROIPOOLING = "PSROIPooling";
const char *POWER = "Power";
const char *POW = "Pow";
const char *ROIALIGN = "ROIAlign";
const char *PYTHON = "Python";
const char *FREESPACEEXTRACT = "FreespaceExtract";
const char *SPATIALTF = "SpatialTransform";
const char *SHAPE = "Shape";
const char *SHAPEN = "ShapeN";
const char *ARGMAX = "ArgMax";
const char *GATHERND = "GatherNd";
const char *GATHER = "Gather";
const char *REALDIV = "RealDiv";
const char *PACK = "Pack";
const char *SLICE = "Slice";
const char *SLICED = "SliceD";
const char *FLOORDIV = "FloorDiv";
const char *SQUEEZE = "Squeeze";
const char *UNSQUEEZE = "Unsqueeze";
const char *STRIDEDSLICE = "StridedSlice";
const char *RANGE = "Range";
const char *RPNPROPOSALS = "RpnProposals";
const char *DECODEBBOX = "DecodeBbox";
const char *PAD = "Pad";
const char *PADV2 = "PadV2";
const char *MIRRORPAD = "MirrorPad";
const char *TILE = "Tile";
const char *SIZE = "Size";
const char *CLIPBOXES = "ClipBoxes";
const char *FASTRCNNPREDICTIONS = "FastrcnnPredictions";
const char *SPLIT = "Split";
const char *SPLITV = "SplitV";
const char *EXPANDDIMS = "ExpandDims";
const char *EMPTY = "Empty";
const char *MEAN = "Mean";
const char *GREATER = "Greater";
const char *SWITCH = "Switch";
const char *SWITCHN = "SwitchN";
const char *MERGE = "Merge";
const char *SYMBOLICGRADIENT = "SymbolicGradient";
const char *REMOTECALL = "RemoteCall";
const char *_IF = "_If";
const char *STATELESSIF = "StatelessIf";
const char *IF = "If";
const char *CASE = "Case";
const char *_WHILE = "_While";
const char *WHILE = "While";
const char *STATELESSWHILE = "StatelessWhile";
const char *FOR = "For";
const char *PARTITIONEDCALL = "PartitionedCall";
const char *STATEFULPARTITIONEDCALL = "StatefulPartitionedCall";
const char *FAKEPARAM = "FakeParam";
const char *TRANSPOSE = "Transpose";
const char *TRANSPOSED = "TransposeD";
const char *CAST = "Cast";
const char *REGION = "Region";
const char *YOLO = "Yolo";
const char *YOLODETECTIONOUTPUT = "YoloDetectionOutput";
const char *FILL = "Fill";
const char *REVERSE = "Reverse";
const char *UNPACK = "Unpack";
const char *YOLO2REORG = "Yolo2Reorg";
const char *REDUCESUM = "ReduceSum";
const char *SUM = "Sum";
const char *CONSTANT = "Const";
const char *RESIZEBILINEAR = "ResizeBilinear";
const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad";
const char *MAXIMUM = "Maximum";
const char *FRAMEWORKOP = "FrameworkOp";
const char *ARG = "_Arg";
const char *FUSEDBATCHNORMGRAD = "FusedBatchNormGrad";
const char *LSTM = "LSTM";
const char *HIGHWAY = "HighWay";
const char *RNN = "RNN";
const char *ATTENTIONDECODER = "AttentionDecoder";
const char *LOGICAL_NOT = "LogicalNot";
const char *LOGICAL_AND = "LogicalAnd";
const char *LOGICAL_OR = "LogicalOr";
const char *EQUAL = "Equal";
const char *NOTEQUAL = "NotEqual";
const char *INTERP = "Interp";
const char *SHUFFLECHANNEL = "ShuffleChannel";
const char *AIPP = "Aipp";
const char *MULTISHAPE = "MultiShape";
const char *RECIPROCAL = "Reciprocal";
const char *SELU = "Selu";
const char *ELU = "Elu";
const char *ACOSH = "Acosh";
const char *ASINH = "Asinh";
const char *MINIMUM = "Minimum";
const char *CLIP = "Clip";
const char *L2NORMALIZE = "L2Normalize";
const char *CROPANDRESIZE = "CropAndResize";
const char *UNUSEDCONST = "UnusedConst";
const char *SPARSETODENSE = "SparseToDense";
const char *NONMAXSUPPRESSION = "NonMaxSuppression";
const char *TOPKV2 = "TopKV2";
const char *INVERTPERMUTATION = "InvertPermutation";
const char *MULTINOMIAL = "Multinomial";
const char *REVERSESEQUENCE = "ReverseSequence";
const char *REDUCEPROD = "ReduceProd";
const char *REDUCEMAX = "ReduceMax";
const char *REDUCEMIN = "ReduceMin";
const char *EXTRACTIMAGEPATCHES = "ExtractImagePatches";
const char *SQRT = "Sqrt";
const char *REDUCEALL = "ReduceAll";
const char *RESIZENEARESTNEIGHBOR = "ResizeNearestNeighbor";
const char *SPACETOBATCHND = "SpaceToBatchND";
const char *BATCHTOSPACEND = "BatchToSpaceND";
const char *ASSERT = "Assert";
const char *GREATEREQUAL = "GreaterEqual";
const char *FLOOR = "Floor";
const char *RANDOMUNIFORM = "RandomUniform";
const char *BATCHMATMUL = "BatchMatMul";
const char *SPACETODEPTH = "SpaceToDepth";
const char *DEPTHTOSPACE = "DepthToSpace";
const char *RINT = "Rint";
const char *ATAN = "Atan";
const char *ATAN2 = "Atan2";
const char *ATANH = "Atanh";
const char *ACOS = "Acos";
const char *ASIN = "Asin";
const char *NEG = "Neg";
const char *LOG = "Log";
const char *TAN = "Tan";
const char *ROUND = "Round";
const char *UPSAMPLE = "Upsample";
const char *FLOORMOD = "FloorMod";
const char *LESS = "Less";
const char *LESSEQUAL = "LessEqual";
const char *ONEHOT = "OneHot";
const char *REFSWITCH = "RefSwitch";
const char *REFMERGE = "RefMerge";
const char *ENTER = "Enter";
const char *REFENTER = "RefEnter";
const char *LOOPCOND = "LoopCond";
const char *NEXTITERATION = "NextIteration";
const char *REFNEXTITERATION = "RefNextIteration";
const char *EXIT = "Exit";
const char *REFEXIT = "RefExit";
const char *CONTROLTRIGGER = "ControlTrigger";
const char *ZEROSLIKE = "ZerosLike";
const char *EXP = "Exp";
const char *WHERE = "Where";
const char *FAKEQUANTWITHMINMAXVARS = "FakeQuantWithMinMaxVars";
const char *SOFTPLUS = "Softplus";
const char *SOFTSIGN = "Softsign";
const char *COSH = "Cosh";
const char *SINH = "Sinh";
const char *SQUAREDDIFFERENCE = "SquaredDifference";
const char *REQUIREDSPACETOBATCHPADDINGS = "RequiredSpaceToBatchPaddings"; // for retinanet scope fusion
const char *SSDPOSTPROCESSOR = "SSDPostProcessor";
const char *RETINANETBOXES = "RetinanetBoxes";
const char *RETINAMULTIANCHORS = "RetinaMultiAnchor";
const char *RETINANETCLIPPEDBOXES = "RetinanetClippedBoxes";
const char *RETINANETFILTEREDDETECTIONS = "RetinanetFilteredDetections";
const char *RETINANETPOSTPROCESSOR = "RetinanetPostProcessor";
const char *RETINANETANCHORS = "RetinanetAnchors";
const char *FASTERRCNNMAP = "FasterRCNNMap";
const char *FASTERRCNNMAP1 = "FasterRCNNMap1";
const char *FASTERRCNNSECONDSTAGEPOSTPROCESSOR = "FasterRCNNSecondStagePostprocessor";
const char *FASTERRCNNROIINTERPOOLING = "FasterRCNNROIInterPooling";
const char *FASTERRCNNFIRSTSTAGEPOSTPROCESSOR = "FasterRCNNFirstStagePostprocessor";
const char *FASTERRCNNGRIDANCHORGENERATOR = "FasterRCNNGridAnchorGenerator";
const char *ROIINTERPOOLING = "ROIInterPooling";
const char *FASTERRCNNCLIPTOWINDOW = "FasterRCNNClipToWindow";
const char *EMBEDLOOKUP = "EmbedLookup";
const char *HASHLOOKUP = "HashLookup";
const char *LSH_PROJ = "LshProject";
const char *SVDF = "SVDF";
const char *SSDANCHORGENERATOR = "SSDAnchorGenerator";
const char *IDENTITY = "Identity";
const char *IDENTITYN = "IdentityN";
const char *PLACEHOLDERWITHDEFAULT = "PlaceholderWithDefault";
const char *SELECT = "Select";
const char *GETSPAN = "GetSpan";
const char *STOPGRADIENT = "StopGradient";
const char *PREVENTGRADIENT = "PreventGradient";
const char *GUARANTEECONST = "GuaranteeConst";
const char *BROADCASTGRADIENTARGS = "BroadcastGradientArgs";
const char *BROADCASTARGS = "BroadcastArgs";
const char *CONFUSIONMATRIX = "ConfusionMatrix";
const char *RANK = "Rank";
const char *PLACEHOLDER = "PlaceHolder";
const char *END = "End";
const char *BASICLSTMCELL = "BasicLSTMCell";
const char *GETNEXT = "GetNext";
const char *INITDATA = "InitData";
const char *REFIDENTITY = "RefIdentity";
const char *BITCAST = "Bitcast";

/***************Ann special operator*************************/
// Op type names for the "Ann" operator family; same contract as above:
// string values are serialized-model identifiers — do not change.
const char *ANN_MEAN = "AnnMean";
const char *ANN_CONVOLUTION = "AnnConvolution";
const char *ANN_DEPCONVOLUTION = "AnnDepthConv";
const char *ANN_FULLCONNECTION = "AnnFullConnection";
const char *ANN_NETOUTPUT = "AnnNetOutput";
const char *ANN_DATA = "AnnData";
const char *ANN_RESHAPE = "AnnReshape";
const char *ANN_ADD = "AnnAdd";
const char *ANN_MUL = "AnnMul";
const char *ANN_SUB = "AnnSub";
const char *ANN_DIV = "AnnDiv";
const char *ANN_DEQUANTIZE = "AnnDequant";
const char *ANN_QUANTIZE = "AnnQuant";
const char *ANN_PAD = "AnnPad";
const char *ANN_RESIZE_BILINEAR = "AnnResizeBilinear";

/***************************************************/
/******************Training operator*************************/
// Op type names used on the training path (gradients, optimizers, variables).
const char *GATHERV2 = "GatherV2";
const char *CONVGRADFILTER = "Conv2DBackpropFilter";
const char *CONV2D = "Conv2D";
const char *CONV2DBACKPROPINPUT = "Conv2DBackpropInput";
const char *FUSEDBATCHNORM = "FusedBatchNorm";
const char *BIASADDGRAD = "BiasAddGrad";
const char *ACTIVATIONGRAD = "ReluGrad";
const char *MAXPOOLWITHARGMAX = "MaxPoolWithArgmax";
const char *MAXPOOLGRADWITHARGMAX = "MaxPoolGradWithArgmax";
const char *SPARSESOFTMAXCROSSENTROPYWITHLOGITS = "SparseSoftmaxCrossEntropyWithLogits";
const char *SNAPSHOT = "Snapshot";
const char *VAR = "Var";
const char *MEANGRAD = "MeanGrad";
const char *TRANSLATE = "Translate";
const char *ADDN = "AddN";
const char *L2LOSS = "L2Loss";
const char *MULTIPLY = "Multiply";
const char *HUBERLOSSGRAD = "HuberLossGrad";
const char *HUBERLOSS = "HuberLoss";
const char *NEGATIVE = "Negative";
const char *SSDCAST = "SSDCast";
const char *SPARSESOFTMAXCROSSENTROPY = "SsdSparseSoftmaxCrossEntropy";
const char *SPARSESOFTMAXCROSSENTROPYGRAD = "SsdSparseSoftmaxCrossEntropyGrad";
const char *SSDSQUEEZEFUSION = "SsdSqueezeFusion";
const char *CONCATFOUR2FIVE = "ConcatFour2Five";
const char *CONCATFIVE2FOUR = "ConcatFive2Four";
const char *SSDREALDIVTILEMUL = "SSDRealdivTileMul";
const char *SSDSUMMULREALDIVMEAN = "SSDSumMulRealdivMean";

// Variable / assignment / optimizer op types.
const char *VARIABLEV2 = "VariableV2";
const char *VARHANDLEOP = "VarHandleOp";
const char *TEMPORARYVARIABLE = "TemporaryVariable";
const char *DESTROYTEMPORARYVARIABLE = "DestroyTemporaryVariable";
const char *VARIABLE = "Variable";
const char *ASSIGN = "Assign";
const char *ASSIGNVARIABLEOP = "AssignVariableOp";
const char *ASSIGNADD = "AssignAdd";
const char *ASSIGNADDVARIABLEOP = "AssignAddVariableOp";
const char *ASSIGNSUB = "AssignSub";
const char *ASSIGNSUBVARIABLEOP = "AssignSubVariableOp";
const char *APPLYMOMENTUM = "ApplyMomentum";
const char *RESOURCEAPPLYMOMENTUM = "ResourceApplyMomentum";
const char *SGD = "SGD";
const char *NOOP = "NoOp";
const char *READVARIABLEOP = "ReadVariableOp";
const char *PARALLELCONCATSTART = "_ParallelConcatStart";
// NOTE: "Constant" here vs CONSTANT = "Const" above — distinct op types.
const char *CONSTANTOP = "Constant";
const char *DEPTHWISECONV2DBACKPROPFILTER = "DepthwiseConv2dNativeBackpropFilter";
// NOTE(review): identifier spells "BACKPORP" (sic); kept as-is because
// renaming would break every caller — the string value is correct.
const char *DEPTHWISECONV2DBACKPORPINPUT = "DepthwiseConv2dNativeBackpropInput";
const char *DEPTHWISECONV2DFORWARDNATIVE = "DepthwiseConv2dNative";
const char *DROPOUTGRAD = "DropOutGrad";
const char *APPLYRMSPROPMIXEDPRECISION = "apply_rms_prop_mixed_precision";
const char *APPLYRMSPROP = "ApplyRMSProp";
const char *RELU6GRAD = "Relu6Grad";
const char *AVGPOOLGRAD = "AvgPoolGrad";
const char *CONCATV2 = "ConcatV2";
const char *CONCATOFFSET = "ConcatOffset";
const char *LAYERNORMGRAD = "LayerNormGrad";
const char *LAYERNORM = "LayerNorm";
const char *LARS = "Lars";
const char *DYNAMICSTITCH = "DynamicStitch";

/***************************************************/
const char *SQUARE = "Square";
// HCCL collective-communication op types.
const char *HCOMBROADCAST = "HcomBroadcast";
const char *HCOMALLGATHER = "HcomAllGather";
const char *HCOMALLREDUCE = "HcomAllReduce";
const char *HCOMREDUCESCATTER = "HcomReduceScatter";
const char *HCOMSEND = "HcomSend";
const char *HCOMRECEIVE = "HcomReceive";
const char *HCOMREMOTEREAD = "HcomRemoteRead";
const char *HCOMREMOTEWRITE = "HcomRemoteWrite";

const char *VARASSIGN = "VarAssign";
const char *VARISINITIALIZEDOP = "VarIsInitializedOp";
// NOTE(review): mixed-case identifier deviates from the UPPER_SNAKE convention
// used everywhere else in this file; kept as-is to avoid breaking callers.
const char *LogTimeStamp = "LogTimeStamp";
const char *ISVARIABLEINITIALIZED = "IsVariableInitialized";
// Stream / memcpy / graph-boundary runtime op types.
const char *STREAMSWITCH = "StreamSwitch";
const char *STREAMSWITCHN = "StreamSwitchN";
const char *STREAMACTIVE = "StreamActive";
const char *MEMCPYASYNC = "MemcpyAsync";
const char *MEMCPYADDRASYNC = "MemcpyAddrAsync";
const char *STREAMMERGE = "StreamMerge";
const char *ENDGRAPH = "EndGraph";
const char *SEND = "Send";
const char *RECV = "Recv";
const char *ENDOFSEQUENCE = "EndOfSequence";

// Label ops (runtime control-flow lowering).
const char *LABELSET = "LabelSet";
const char *LABELGOTO = "LabelGoto";
const char *LABELGOTOEX = "LabelGotoEx";
const char *LABELSWITCH = "LabelSwitch";
const char *LABELSWITCHBYINDEX = "LabelSwitchByIndex";

const char *ATOMICADDRCLEAN = "AtomicAddrClean";

const char *ABS_GRAD = "AbsGrad";
const char *ACCUMULATE_N_V2 = "AccumulateNV2";
const char *ACOS_GRAD = "AcosGrad";
const char *ACOSH_GRAD = "AcoshGrad";
const char *ANY = "Any";
const char *APPROXIMATE_EQUAL = "ApproximateEqual";
const char *ASIN_GRAD = "AsinGrad";
const char *ASINH_GRAD = "AsinhGrad";
const char *ATAN_GRAD = "AtanGrad";
const char *BROADCAST_TO = "BroadcastTo";
const char *ELU_GRAD = "EluGrad";
const char *ADD_V2 = "AddV2";
const char *DATAFORMATDIMMAP = "DataFormatDimMap";
const char *DATAFORMATVECPERMUTE = "DataFormatVecPermute";
const char *BESSELI0E = "BesselI0e";
const char *BESSELI1E = "BesselI1e";
const char *APPLYADADELTA = "ApplyAdadelta";
const char *APPLYADAGRAD = "ApplyAdagrad";
const char *APPLYADAGRADDA = "ApplyAdagradDA";
const char *APPLYADAM = "ApplyAdam";
const char *APPLYADAMAX = "ApplyAdaMax";
const char *APPLYADDSIGN = "ApplyAddSign";
const char *APPLYCENTEREDRMSPROP = "ApplyCenteredRMSProp";
const char *APPLYFTRL = "ApplyFtrl";
const char *APPLYFTRLV2 = "ApplyFtrlV2";
const char *APPLYGRADIENTDESCENT = "ApplyGradientDescent";
const char *APPLYPOWERSIGN = "ApplyPowerSign";
const char *APPLYPROXIMALADAGRAD = "ApplyProximalAdagrad";
const char *APPLYPROXIMALGRADIENTDESCENT = "ApplyProximalGradientDescent";
const char *DEQUANTIZE = "Dequantize";

const char *FOCAL_LOSS = "FocalLoss";
const char *FOCAL_LOSS_GRAD = "FocalLossGrad";
const char *SMOOTHL1_LOSS = "SmoothL1Loss";
// NOTE(review): lowercase "_grad" suffix is inconsistent with FOCAL_LOSS_GRAD
// above; kept as-is to avoid breaking callers.
const char *SMOOTHL1_LOSS_grad = "SmoothL1LossGrad";
const char *REDUCEMEAN = "ReduceMean";
const char *CONCAT_V2 = "ConcatV2";
const char *ONEHOT_V2 = "OneHotV2";
const char *SLICE_V2 = "SliceV2";
const char *TILE_V2 = "TileV2";
const char *SUM_V2 = "SumV2";
// Common type when the operator has the same name
const char *DETECTIONOUTPUT = "DetectionOutput";
// Custom operator
const char *CUSTOMOP = "CustomOp";
const char *CUSTOMOP_NCHW = "CustomOpNchw";
const char *CUSTOMOP_NHWC = "CustomOpNhwc";
const char *CUSTOMOP_NC1HWC0 = "CustomOpNc1hwc0";

// Depthwise 4d_2_6d,6d_2_4d
const char *DEPTHWISEWEIGHT4D26D = "depthwise_weight_4d_2_6d";
const char *DEPTHWISEWEIGHT6D24D = "depthwise_weight_6d_2_4d";

const char *SQRTGRAD = "SqrtGrad";
const char *SIGMOIDGRAD = "SigmoidGrad";

const char *TRANSSHAPE = "TransShape";

// Horovod operator
const char *HVDCALLBACKALLREDUCE = "HorovodAllreduce";
const char *HVDCALLBACKALLGATHER = "HorovodAllgather";
const char *HVDCALLBACKBROADCAST = "HorovodBroadcast";
const char *HVDWAIT = "HorovodWait";

///
/// @brief Magic number of model file
///
const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // ASCII "DOMI" read MSB-first

///
/// @brief Model head length, in bytes
///
const uint32_t MODEL_FILE_HEAD_LEN = 256;

const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0

///
/// @ingroup domi_omg
/// @brief alpha default value
///
const float ALPHA_DEFAULT_VALUE = 1.0;

///
/// @ingroup domi_omg
/// @brief beta default value
///
const float BETA_DEFAULT_VALUE = 0.0;

///
/// @ingroup domi_omg
/// @brief Input node type
///
const std::string INPUT_TYPE = "Input";
const std::string DUMMY_DATA = "DummyData";

// for fusion op plugin: attribute key recording the pre-fusion op type
const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type";

// Attribute keys for serialized input/output tensor descriptors.
const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc";
const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc";

///
/// @ingroup domi_omg
/// @brief DATA node type
///
const std::string DATA_TYPE = "Data";

///
/// @ingroup domi_omg
/// @brief Frame operator type
///
const std::string FRAMEWORK_OP_TYPE = "FrameworkOp";

///
/// @ingroup domi_omg
/// @brief Convolution node type
///
const std::string NODE_NAME_NET_OUTPUT = "Node_Output";
} // namespace parser
} // namespace ge

+ 221
- 0
parser/common/parser_utils.cc View File

@@ -0,0 +1,221 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser_utils.h"
#include "external/ge/ge_api_types.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"
#include "graph/anchor.h"
#include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "register/op_registry.h"

namespace ge {
namespace {
///
/// @brief Prepare a node newly copied into @p compute_graph while expanding
///        @p node into a subgraph: set its owner graph, give it a unique
///        "PartitionedCall_*" name, and record the original op name so
///        data-dump tooling can trace provenance.
/// @param node           original node being expanded (source of the name)
/// @param compute_graph  graph that now owns @p new_node
/// @param new_node       freshly added node to fix up
/// @return SUCCESS, or FAILED when the owner graph cannot be set
///
Status HandleNewOp(const NodePtr &node, const ComputeGraphPtr &compute_graph, const NodePtr &new_node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(new_node);
  if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Set owner graph for node:%s failed.", new_node->GetName().c_str());
    return FAILED;
  }
  auto op_desc = new_node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  // Monotonic counter keeps generated names unique across concurrent expansions.
  static std::atomic_long new_node_index(0);
  auto new_name = "PartitionedCall_" + new_node->GetName() + "_" + std::to_string(new_node_index++);
  op_desc->SetName(new_name);
  // No std::move needed: the vector is a prvalue temporary already.
  const bool ret = ge::AttrUtils::SetListStr(op_desc,
                                             ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES,
                                             std::vector<std::string>{node->GetName()});
  if (!ret) {
    // Best-effort attribute: losing it only degrades dump traceability.
    GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), op_desc->GetName().c_str());
  }
  GELOGD("Handle new op[%s] for node[%s] success.", new_node->GetName().c_str(), node->GetName().c_str());
  return SUCCESS;
}
}  // namespace

///
/// @brief Expand every node that has a registered ParseOpToGraphFunc (a
///        one-op-to-many-ops parser) into its generated subgraph, in place.
/// @param graph  graph to scan and rewrite
/// @return SUCCESS, or FAILED when a registered parser or the expansion fails
///
Status ParserUtils::ExpandOneToManyGraph(Graph &graph) {
  GELOGD("Begin run ParserUtils::ExpandOneToManyGraph.");
  ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
  // Fix: the original dereferenced compute_graph without a null check.
  GE_CHECK_NOTNULL(compute_graph);
  for (const auto &n : compute_graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(n);
    std::string ori_type;
    // Optional attribute: for FrameworkOp nodes the lookup key is the
    // framework-original type, otherwise the node type alone suffices.
    (void)AttrUtils::GetStr(n->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, ori_type);
    domi::ParseOpToGraphFunc parse_op_to_graph_func =
        domi::OpRegistry::Instance()->GetParseOpToGraphFunc(n->GetType(), ori_type);
    if (parse_op_to_graph_func == nullptr) {
      GELOGD("node:%s type:%s ori type:%s has no parse_op_to_graph_func.",
             n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str());
      continue;
    }
    GELOGI("node:%s type:%s ori type:%s has registered one to many parser func.",
           n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str());
    Graph subgraph("one_to_many_graph");
    Operator op = OpDescUtils::CreateOperatorFromNode(n);
    Status ret = parse_op_to_graph_func(op, subgraph);
    if (ret != SUCCESS) {
      GELOGE(FAILED, "Get one to many graph failed for op:%s.", op.GetName().c_str());
      return FAILED;
    }
    ret = ExpandNodeToSubgraph(subgraph, n, graph);
    if (ret != SUCCESS) {
      GELOGE(FAILED, "Expand one to many graph failed for op:%s.", op.GetName().c_str());
      return FAILED;
    }
  }
  GELOGD("run ParserUtils::ExpandOneToManyGraph success.");
  return SUCCESS;
}

///
/// @brief Splice the parser-generated @p subgraph into @p graph in place of
///        @p node: copy the subgraph nodes over, rewire data/control edges,
///        remove the original node, then re-sort the graph.
/// @param subgraph  graph produced by the one-to-many parser func
/// @param node      original node being replaced
/// @param graph     owning graph that is rewritten
/// @return SUCCESS or FAILED
///
Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph) {
  ComputeGraphPtr sub_compute_graph = GraphUtils::GetComputeGraph(subgraph);
  GE_CHECK_NOTNULL(sub_compute_graph);
  ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
  GE_CHECK_NOTNULL(compute_graph);

  // add subgraph node to graph. Data nodes mark the subgraph's formal inputs
  // and are collected so HandleInputContext can splice the real producers in.
  std::vector<NodePtr> input_nodes;
  for (const auto &n : sub_compute_graph->GetDirectNode()) {
    auto new_node = compute_graph->AddNode(n);
    GE_CHECK_NOTNULL(new_node);
    if (HandleNewOp(node, compute_graph, new_node) != SUCCESS) {
      GELOGE(FAILED, "Handle new op[%s] for node[%s] failed.", new_node->GetName().c_str(), node->GetName().c_str());
      return FAILED;
    }

    if (new_node->GetType() == "Data") {
      input_nodes.emplace_back(new_node);
    }
  }

  // handle input context: reconnect the original node's producers to the
  // consumers of each subgraph Data node (removes the Data nodes).
  Status ret = HandleInputContext(node, input_nodes, compute_graph);
  if (ret != SUCCESS) {
    GELOGE(FAILED, "run ParserUtils::HandleInputContext failed.");
    return FAILED;
  }

  // handle output context: redirect the original node's consumers to the
  // subgraph's registered output nodes.
  std::vector<std::pair<NodePtr, int32_t>> out_node_index = sub_compute_graph->GetGraphOutNodesInfo();
  ret = HandleOutputContext(node, out_node_index);
  if (ret != SUCCESS) {
    GELOGE(FAILED, "run ParserUtils::HandleOutputContext failed.");
    return FAILED;
  }

  // All edges were moved above, so the bare node can be dropped.
  graphStatus graph_status = GraphUtils::RemoveNodeWithoutRelink(compute_graph, node);
  if (graph_status != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Remove node:%s failed.", node->GetName().c_str());
    return FAILED;
  }
  graph_status = compute_graph->TopologicalSorting();
  if (graph_status != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Topological sorting failed.");
    return FAILED;
  }
  return SUCCESS;
}

///
/// @brief Rewire the inputs of the expanded node: for each subgraph Data node
///        (formal input i), connect the original producer of @p node's input i
///        directly to every consumer of that Data node, mirror @p node's
///        incoming control edges onto those consumers, then delete the Data
///        node.
/// @param node           original node being replaced
/// @param input_nodes    Data nodes collected from the spliced-in subgraph
/// @param compute_graph  graph owning all of the above
/// @return SUCCESS or FAILED
///
Status ParserUtils::HandleInputContext(const NodePtr &node,
                                       const std::vector<NodePtr> &input_nodes,
                                       const ComputeGraphPtr &compute_graph) {
  GE_CHECK_NOTNULL(node);
  for (const auto &in_n : input_nodes) {
    GE_CHECK_NOTNULL(in_n);
    int index = 0;
    if (!AttrUtils::GetInt(in_n->GetOpDesc(), ATTR_NAME_INDEX, index)) {
      GELOGE(FAILED, "Get attr index of node:%s failed.", in_n->GetName().c_str());
      return FAILED;
    }
    GELOGD("Begin to handle input node:%s with index:%d.", in_n->GetName().c_str(), index);
    // get node's in data anchor and peer out anchor
    auto node_in_anchor = node->GetInDataAnchor(index);
    GE_CHECK_NOTNULL(node_in_anchor);
    auto src_out_anchor = node_in_anchor->GetPeerOutAnchor();
    GE_CHECK_NOTNULL(src_out_anchor);
    auto data_out_anchor = in_n->GetOutDataAnchor(0);
    GE_CHECK_NOTNULL(data_out_anchor);
    // Fix: detach the producer->node edge exactly once. The original removed
    // it inside the peer loop, which fails on the second iteration when the
    // Data node has more than one consumer.
    graphStatus ret = GraphUtils::RemoveEdge(src_out_anchor, node_in_anchor);
    if (ret != GRAPH_SUCCESS) {
      GELOGE(FAILED, "remove node in anchor and peer out anchor failed.");
      return FAILED;
    }
    for (const auto &peer_in_anchor : data_out_anchor->GetPeerInDataAnchors()) {
      // move the data edge from the Data node to the real producer
      ret = GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor);
      if (ret != GRAPH_SUCCESS) {
        GELOGE(FAILED, "remove data out anchor and peer in anchor failed.");
        return FAILED;
      }
      ret = GraphUtils::AddEdge(src_out_anchor, peer_in_anchor);
      if (ret != GRAPH_SUCCESS) {
        GELOGE(FAILED, "link node's peer out anchor and data's peer in anchor failed.");
        return FAILED;
      }

      // mirror the original node's incoming control edges onto the consumer
      if (node->GetInControlAnchor() != nullptr) {
        for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) {
          graphStatus ctrl_ret = GraphUtils::AddEdge(out_anchor, peer_in_anchor->GetOwnerNode()->GetInControlAnchor());
          if (ctrl_ret != GRAPH_SUCCESS) {
            GELOGE(FAILED, "add control edge failed.");
            return FAILED;
          }
        }
      }
    }
    // the Data placeholder has no remaining edges; drop it
    ret = GraphUtils::RemoveNodeWithoutRelink(compute_graph, in_n);
    if (ret != GRAPH_SUCCESS) {
      GELOGE(FAILED, "remove node:%s failed.", in_n->GetName().c_str());
      return FAILED;
    }
  }
  return SUCCESS;
}

///
/// @brief Rewire the outputs of the expanded node: for output slot i of
///        @p node, reconnect every consumer to the subgraph's registered
///        output anchor out_node_index[i].
/// @param node            original node being replaced
/// @param out_node_index  (output node, output anchor index) pairs from the
///                        subgraph, ordered by @p node's output slots
/// @return SUCCESS or FAILED
///
Status ParserUtils::HandleOutputContext(const NodePtr &node,
                                        const std::vector<std::pair<NodePtr, int32_t>> &out_node_index) {
  GE_CHECK_NOTNULL(node);
  GELOGD("The size of out node is %zu", out_node_index.size());
  for (size_t index = 0; index < out_node_index.size(); index++) {
    auto node_out_anchor = node->GetOutDataAnchor(static_cast<int>(index));
    if (node_out_anchor == nullptr) {
      // slot unused by any consumer; nothing to rewire
      continue;
    }

    NodePtr out_node = out_node_index[index].first;
    int32_t out_index = out_node_index[index].second;
    // Fix: out_index is int32_t — the original printed it with %zu (UB).
    GELOGD("Begin to handle output node:%s[%d] with index:%zu", out_node->GetName().c_str(), out_index, index);
    auto src_out_anchor = out_node->GetOutDataAnchor(out_index); // get out node's out anchor.
    GE_CHECK_NOTNULL(src_out_anchor);
    for (const auto &dest_in_anchor : node_out_anchor->GetPeerInDataAnchors()) {
      graphStatus ret = GraphUtils::RemoveEdge(node_out_anchor, dest_in_anchor);
      if (ret != GRAPH_SUCCESS) {
        GELOGE(FAILED, "remove node's out anchor and peer in anchor failed.");
        return FAILED;
      }
      ret = GraphUtils::AddEdge(src_out_anchor, dest_in_anchor);
      if (ret != GRAPH_SUCCESS) {
        GELOGE(FAILED, "link node's peer out anchor and out node's out anchor failed.");
        return FAILED;
      }
    }
  }
  return SUCCESS;
}
} // namespace ge

+ 37
- 0
parser/common/parser_utils.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_PARSER_UTILS_H_
#define PARSER_COMMON_PARSER_UTILS_H_

#include "graph/graph.h"
#include "graph/node.h"
#include "external/ge/ge_api_error_codes.h"

namespace ge {
///
/// @brief Static helpers that expand "one op parses to many ops" nodes:
///        a node with a registered ParseOpToGraphFunc is replaced in place
///        by the subgraph that function produces.
///
class ParserUtils {
 public:
  /// Scan @p graph and expand every node with a registered one-to-many
  /// parser func. Returns SUCCESS or FAILED.
  static Status ExpandOneToManyGraph(Graph &graph);

 private:
  /// Splice @p subgraph into @p graph in place of @p node.
  static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph);
  /// Reconnect @p node's producers to the consumers of the subgraph's Data nodes.
  static Status HandleInputContext(const NodePtr &node,
                                   const std::vector<NodePtr> &input_nodes,
                                   const ComputeGraphPtr &compute_graph);
  /// Reconnect @p node's consumers to the subgraph's registered output anchors.
  static Status HandleOutputContext(const NodePtr &node, const std::vector<std::pair<NodePtr, int32_t>> &out_node_index);
};
} // namespace ge
#endif // PARSER_COMMON_PARSER_UTILS_H_

+ 0
- 83
parser/common/pass_manager.cc View File

@@ -1,83 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/pass_manager.h"
#include "framework/omg/parser/parser_types.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/debug/log.h"
#include "graph/utils/node_utils.h"
#include "omg/omg_inner_types.h"

namespace ge {
namespace parser {
// Accessor: the registered (name, pass) pairs, in registration order.
const vector<std::pair<std::string, GraphPass *>> &PassManager::GraphPasses() const { return names_to_graph_passes_; }

// Register a pass under @p pass_name. Ownership transfers to this manager:
// the pass is deleted in ~PassManager. GE_CHECK_NOTNULL returns early on null.
Status PassManager::AddPass(const string &pass_name, GraphPass *pass) {
  GE_CHECK_NOTNULL(pass);
  names_to_graph_passes_.emplace_back(pass_name, pass);
  return SUCCESS;
}

// Run all registered passes on @p graph (delegates to the static overload).
// Returns SUCCESS, NOT_CHANGED, or the first failing pass's status.
Status PassManager::Run(const ComputeGraphPtr &graph) {
  GE_CHECK_NOTNULL(graph);
  return Run(graph, names_to_graph_passes_);
}

// Run each (name, pass) pair on the root graph and then on every subgraph.
// A pass returning SUCCESS means "graph changed"; NOT_CHANGED is tolerated;
// any other status aborts. Overall result is NOT_CHANGED only if no pass
// changed anything anywhere.
Status PassManager::Run(const ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &names_to_passes) {
  GE_CHECK_NOTNULL(graph);
  bool not_changed = true;

  for (auto &pass_pair : names_to_passes) {
    const auto &pass = pass_pair.second;
    const auto &pass_name = pass_pair.first;
    GE_CHECK_NOTNULL(pass);

    // Times the whole pass (root graph + all subgraphs); ended below.
    PARSER_TIMESTAMP_START(PassRun);
    Status status = pass->Run(graph);
    if (status == SUCCESS) {
      not_changed = false;
    } else if (status != NOT_CHANGED) {
      GELOGE(status, "Pass Run failed on graph %s", graph->GetName().c_str());
      return status;
    }
    for (const auto &subgraph : graph->GetAllSubgraphs()) {
      GE_CHECK_NOTNULL(subgraph);
      // Reset per-graph pass state before re-running on a subgraph.
      GE_CHK_STATUS_RET(pass->ClearStatus(), "pass clear status failed for subgraph %s", subgraph->GetName().c_str());
      // NOTE(review): uses the ROOT graph's name, so all subgraph timing rows
      // share one label — possibly intended subgraph->GetName(); confirm.
      string subgraph_pass_name = pass_name + "::" + graph->GetName();
      PARSER_TIMESTAMP_START(PassRunSubgraph);
      status = pass->Run(subgraph);
      PARSER_TIMESTAMP_END(PassRunSubgraph, subgraph_pass_name.c_str());
      if (status == SUCCESS) {
        not_changed = false;
      } else if (status != NOT_CHANGED) {
        GELOGE(status, "Pass Run failed on subgraph %s", subgraph->GetName().c_str());
        return status;
      }
    }
    PARSER_TIMESTAMP_END(PassRun, pass_name.c_str());
  }

  return not_changed ? NOT_CHANGED : SUCCESS;
}

// Destructor: the manager owns every registered pass (see AddPass) and
// deletes them all here.
PassManager::~PassManager() {
  for (auto &pass_pair : names_to_graph_passes_) {
    auto &pass = pass_pair.second;
    GE_DELETE_NEW_SINGLE(pass);
  }
}
} // namespace parser
} // namespace ge

+ 0
- 76
parser/common/pass_manager.h View File

@@ -1,76 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_PASS_MANAGER_H_
#define PARSER_COMMON_PASS_MANAGER_H_

#include <vector>

#include "inc/graph_pass.h"

using std::vector;

namespace ge {
namespace parser {
///
/// @ingroup domi_omg
/// @brief pass manager
/// @author
///
///
/// @ingroup domi_omg
/// @brief pass manager: holds an ordered list of named GraphPasses, runs them
///        over a graph and its subgraphs, and owns (deletes) the passes.
///
class PassManager {
 public:
  ///
  /// get graph passes
  /// @return registered (name, pass) pairs in registration order
  ///
  const vector<std::pair<std::string, GraphPass *>> &GraphPasses() const;

  ///
  /// Add graph pass
  /// @param [in] pass_name label used in logs/timing
  /// @param [in] pass Pass to be added, it will be destroyed when pass manager destroys.
  ///
  Status AddPass(const string &pass_name, GraphPass *pass);

  ///
  /// Optimize graph with added pass
  /// @param [inout] graph graph to be optimized
  /// @return SUCCESS optimize successfully
  /// @return NOT_CHANGED not optimized
  /// @return others optimize failed
  ///
  Status Run(const ge::ComputeGraphPtr &graph);

  ///
  /// Optimize graph with specified pass
  /// @param [inout] graph graph to be optimized
  /// @param [in] passes passes to be used
  /// @return SUCCESS optimize successfully
  /// @return NOT_CHANGED not optimized
  /// @return others optimized failed
  ///
  static Status Run(const ge::ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &passes);

  ~PassManager();

 private:
  // Owned passes; deleted in the destructor.
  vector<std::pair<std::string, GraphPass *>> names_to_graph_passes_;
};
} // namespace parser
} // namespace ge
#endif // PARSER_COMMON_PASS_MANAGER_H_

+ 4
- 5
parser/common/pre_checker.cc View File

@@ -23,7 +23,6 @@
#include "framework/common/debug/ge_log.h"
#include "omg/omg.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/model_saver.h"
#include "register/op_registry.h"

namespace ge {
@@ -56,7 +55,7 @@ void PreChecker::Init() {
fmk_op_types_ = nullptr;

// Currently only Caffe and tensorflow are supported
domi::FrameworkType fmk_type = GetParserContext().type;
domi::FrameworkType fmk_type = domi::GetContext().type;
if (fmk_type == domi::CAFFE)
fmk_op_types_ = &caffe_op_map;
else if (fmk_type == domi::TENSORFLOW)
@@ -119,8 +118,8 @@ FMK_FUNC_HOST_VISIBILITY Status PreChecker::CheckType(OpId id, bool is_tensorflo

// If the user explicitly specifies the mapping relationship of the operator type through
// the -- OP_name_map parameter, the type specified by the user is used.
auto op_map_iter = GetParserContext().op_conf_map.find(type);
if (op_map_iter != GetParserContext().op_conf_map.end()) {
auto op_map_iter = domi::GetContext().op_conf_map.find(type);
if (op_map_iter != domi::GetContext().op_conf_map.end()) {
type = op_map_iter->second;
}

@@ -233,7 +232,7 @@ Status PreChecker::Save(string file) {
}

// Save JSON data to a file
GE_RETURN_WITH_LOG_IF_ERROR(ge::parser::ModelSaver::SaveJsonToFile(file.c_str(), model), "Save failed.");
GE_RETURN_WITH_LOG_IF_ERROR(ModelSaver::SaveJsonToFile(file.c_str(), model), "Save failed.");

return SUCCESS;
}


+ 1
- 1
parser/common/pre_checker.h View File

@@ -19,7 +19,7 @@

#include <string>
#include <vector>
#include "framework/omg/parser/parser_types.h"
#include "common/types.h"
#include "omg/omg_inner_types.h"

namespace ge {


+ 0
- 3
parser/common/proto/insert_op.proto View File

@@ -45,9 +45,6 @@ message AippOpParams {
// 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。
uint32 related_input_rank = 2;
// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;
// input_edge_idx参数为可选,类型为整型,配置范围为>=0。
// 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。
// 配置值 <= Data算子输出边的个数。


+ 2
- 3
parser/common/proto_file_parser.cc View File

@@ -27,7 +27,6 @@
#include "common/types.h"
#include "common/util.h"
#include "common/debug/log.h"
#include "parser/common/acl_graph_parser_util.h"
#include "ge/ge_api_types.h"
#include "framework/common/debug/ge_log.h"

@@ -159,7 +158,7 @@ bool SaveIdentifierOpMapInfo(const string &line, std::map<int, std::pair<string
}

bool CheckRealPath(const char *file_path) {
string dest_path = ge::parser::RealPath(file_path);
string dest_path = ge::RealPath(file_path);
if (dest_path.empty()) {
GELOGW("Path [%s] is not real existed.", file_path);
return false;
@@ -185,7 +184,7 @@ Status ProtoFileParser::CreatProtoFile() {
fusion_proto_path += "/" + CreatTmpName(kTmpFileNameLen);
}

int fd = open(fusion_proto_path.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP);
int fd = open(fusion_proto_path.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0640);
if (fd < kOpenRetValue) {
GELOGE(FAILED, "creat tmp proto file[%s] failed.", fusion_proto_path.c_str());
return FAILED;


+ 15
- 9
parser/common/register_tbe.cc View File

@@ -19,7 +19,7 @@
#include <memory>
#include <string>
#include "common/debug/log.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/ge/ge_util.h"
#include "common/op/ge_op_utils.h"
#include "common/op_map.h"
#include "common/util.h"
@@ -38,6 +38,8 @@ FMK_FUNC_HOST_VISIBILITY OpRegistrationTbe *OpRegistrationTbe::Instance() {
}

bool OpRegistrationTbe::Finalize(const OpRegistrationData &reg_data, bool is_train) {
ge::OpTypeContainer::Instance()->Register(reg_data.GetOmOptype());

static std::map<domi::FrameworkType, std::map<std::string, std::string> *> op_map = {{CAFFE, &caffe_op_map}};
if (is_train) {
op_map[domi::TENSORFLOW] = &tensorflow_train_op_map;
@@ -55,7 +57,8 @@ bool OpRegistrationTbe::Finalize(const OpRegistrationData &reg_data, bool is_tra
continue;
} else {
(*fmk_op_map)[tmp] = reg_data.GetOmOptype();
GELOGD("First register in parser initialize, original type: %s, om_optype: %s, imply type: %s.", tmp.c_str(),
GELOGD("First register in parser initilize, original type: %s, om_optype: %s, imply type: %s.",
tmp.c_str(),
reg_data.GetOmOptype().c_str(), TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str());
}
}
@@ -79,7 +82,7 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData &reg_data) {
return false;
}
std::shared_ptr<TensorFlowCustomParserAdapter> tf_parser_adapter =
ge::parser::MakeShared<TensorFlowCustomParserAdapter>();
ge::MakeShared<TensorFlowCustomParserAdapter>();
if (tf_parser_adapter == nullptr) {
GELOGE(PARAM_INVALID, "Create tf parser adapter failed.");
return false;
@@ -94,20 +97,22 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData &reg_data) {
return false;
}
GELOGI("Register fusion custom op parser: %s", reg_data.GetOmOptype().c_str());
std::shared_ptr<TensorFlowFusionCustomParserAdapter> tf_fusion_parser_adapter =
ge::parser::MakeShared<TensorFlowFusionCustomParserAdapter>();
std::shared_ptr<TensorFlowFusionCustomParserAdapter>
tf_fusion_parser_adapter = ge::MakeShared<TensorFlowFusionCustomParserAdapter>();
if (tf_fusion_parser_adapter == nullptr) {
GELOGE(PARAM_INVALID, "Create tf fusion parser adapter failed.");
return false;
}
OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar(
domi::TENSORFLOW, reg_data.GetOmOptype(),
[=]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, true);
[=]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; },
true);
}
} else {
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(reg_data.GetFrameworkType());
if (factory == nullptr) {
GELOGE(INTERNAL_ERROR, "Get op parser factory for %s failed.",
GELOGE(INTERNAL_ERROR,
"Get op parser factory for %s failed.",
TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str());
return false;
}
@@ -119,12 +124,13 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData &reg_data) {

PARSER_CREATOR_FN func = CustomParserAdapterRegistry::Instance()->GetCreateFunc(reg_data.GetFrameworkType());
if (func == nullptr) {
GELOGE(INTERNAL_ERROR, "Get custom parser adapter failed for fmk type %s.",
GELOGW("Get custom parser adapter failed for fmk type %s.",
TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str());
return false;
}
OpParserFactory::Instance(reg_data.GetFrameworkType())->RegisterCreator(reg_data.GetOmOptype(), func);
GELOGD("Register custom parser adapter for op %s of fmk type %s success.", reg_data.GetOmOptype().c_str(),
GELOGD("Register custom parser adapter for op %s of fmk type %s success.",
reg_data.GetOmOptype().c_str(),
TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str());
}
return true;


+ 0
- 212
parser/common/tbe_plugin_loader.cc View File

@@ -1,212 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tbe_plugin_loader.h"

#include <dirent.h>
#include <sys/stat.h>
#include <unistd.h>
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <string>

#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/string_util.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/utils/type_utils.h"
#include "parser/common/acl_graph_parser_util.h"

namespace ge {
std::map<string, string> TBEPluginLoader::options_ = {};

namespace {
const std::string FRAMEWORK_TYPE = "ge.frameworkType";
}

// Returns the process-wide TBEPluginLoader singleton (constructed on first use).
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEPluginLoader &TBEPluginLoader::Instance() {
  static TBEPluginLoader loader_instance;
  return loader_instance;
}

// Closes every dlopen() handle collected so far and clears the cache.
// Returns FAILED if any dlclose() call fails; otherwise SUCCESS.
Status TBEPluginLoader::ClearHandles_() {
  Status status = SUCCESS;
  for (void *const so_handle : handles_vec_) {
    if (dlclose(so_handle) != 0) {
      GELOGW("Failed to close handle: %s", dlerror());
      status = FAILED;
    }
  }
  handles_vec_.clear();
  return status;
}

// Releases all plugin shared objects loaded by LoadPluginSo().
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status TBEPluginLoader::Finalize() {
  return ClearHandles_();
}

// Scans the custom-op plugin directories and dlopen()s every parser .so found
// there (lib_caffe_parser.so is filtered out during the scan). Handles are
// cached in handles_vec_ and released later by Finalize().
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginLoader::LoadPluginSo(
    const std::map<string, string> &options) {
  options_ = options;

  std::string plugin_path;
  GetCustomOpPath(plugin_path);

  // Collect candidate .so files under the plugin path.
  vector<string> so_files;
  string caffe_parser_path;
  GetPluginSoFileList(plugin_path, so_files, caffe_parser_path);

  if (so_files.empty()) {
    GELOGW("Can not find any plugin file in plugin_path: %s", plugin_path.c_str());
  }

  GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted.");

  for (auto so_file : so_files) {
    StringUtils::Trim(so_file);

    void *so_handle = dlopen(so_file.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE);
    if (so_handle == nullptr) {
      GELOGW("dlopen failed, plugin name:%s. Message(%s).", so_file.c_str(), dlerror());
      continue;
    }
    const bool already_loaded =
        std::find(handles_vec_.begin(), handles_vec_.end(), so_handle) != handles_vec_.end();
    if (already_loaded) {
      GELOGI("Plugin so has already been loaded, no need to load again.");
    } else {
      // Keep the handle open for the life of the process; Finalize() closes it.
      GELOGI("Plugin load %s success.", so_file.c_str());
      handles_vec_.push_back(so_handle);
    }
  }
}

void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) {
GELOGI("Enter get custom op path schedule");
std::string fmk_type;
domi::FrameworkType type = domi::TENSORFLOW;
auto it = options_.find(FRAMEWORK_TYPE);
if (it != options_.end()) {
type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10));
}
fmk_type = ge::TypeUtils::FmkTypeToSerialString(type);
GELOGI("Framework type is %s.", fmk_type.c_str());

const char *path_env = std::getenv("ASCEND_OPP_PATH");
if (path_env != nullptr) {
std::string path = path_env;
customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type);
GELOGI("Get custom so path from env : %s", path_env);
return;
}
std::string path_base = GetPath();
GELOGI("path_base is %s", path_base.c_str());
path_base = path_base.substr(0, path_base.rfind('/'));
path_base = path_base.substr(0, path_base.rfind('/') + 1);
customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type);
}

// Returns the directory (with trailing '/') containing this shared library,
// or an empty string when the path cannot be resolved.
string TBEPluginLoader::GetPath() {
  Dl_info dl_info;
  if (dladdr(reinterpret_cast<void *>(&TBEPluginLoader::GetPath), &dl_info) == 0) {
    GELOGW("Failed to read so path!");
    return string();
  }
  const string so_path = dl_info.dli_fname;
  if (so_path.length() >= PATH_MAX) {
    GELOGW("File path is too long!");
    return string();
  }
  char resolved[PATH_MAX] = {0};
  if (realpath(so_path.c_str(), resolved) == nullptr) {
    GELOGW("Failed to get realpath of %s", so_path.c_str());
    return string();
  }
  // Strip the file name, keeping the trailing '/'.
  const string canonical = resolved;
  return canonical.substr(0, canonical.rfind('/') + 1);
}

void TBEPluginLoader::GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path) {
// Support to split multiple so directories by ":"
vector<string> v_path = StringUtils::Split(path, ':');
for (size_t i = 0; i < v_path.size(); ++i) {
FindParserSo(v_path[i], file_list, caffe_parser_path);
GELOGI("CustomOpLib full name = %s", v_path[i].c_str());
}
}

// Recursively walks 'path', appending parser plugin .so files to file_list.
// lib_caffe_parser.so is recorded separately via caffe_parser_path.
// Non-existent paths and non-directories are skipped with a warning.
void TBEPluginLoader::FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path) {
  // Resolve to an absolute path first; bail out when it does not exist.
  const string real_path = ge::parser::RealPath(path.c_str());
  if (real_path.empty()) {
    GELOGW("RealPath is empty.");
    return;
  }

  struct stat stat_buf;
  if ((stat(real_path.c_str(), &stat_buf) != 0) || (!S_ISDIR(stat_buf.st_mode))) {
    GELOGW("%s is not a dir.", real_path.c_str());
    return;
  }

  DIR *dir = opendir(real_path.c_str());
  if (dir == nullptr) {
    GELOGW("Open directory %s failed.", real_path.c_str());
    return;
  }

  const string so_suff = ".so";
  const string caffe_parser_so_suff = "lib_caffe_parser.so";
  const string aicpu_so_suff = "_aicpu.so";
  const string aicpu_host_so_suff = "_online.so";
  struct dirent *entry = nullptr;
  while ((entry = readdir(dir)) != nullptr) {
    const string name = entry->d_name;
    if (name == "." || name == "..") {
      continue;
    }
    string full_name = real_path + "/" + name;
    const bool ends_with_so =
        name.size() >= so_suff.size() && name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0;
    if (ends_with_so) {
      ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff,
                        aicpu_host_so_suff);
    } else {
      // Not a .so: descend in case it is a subdirectory.
      FindParserSo(full_name, file_list, caffe_parser_path);
    }
  }
  closedir(dir);
}

// Routes one .so path: lib_caffe_parser.so is remembered in caffe_parser_path,
// every other .so goes into file_list for later dlopen().
// NOTE(review): aicpu_so_suff and aicpu_host_so_suff are accepted but unused here.
void TBEPluginLoader::ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name,
                                        const string &caffe_parser_so_suff, const string &aicpu_so_suff,
                                        const string &aicpu_host_so_suff) {
  const bool is_caffe_parser =
      full_name.size() >= caffe_parser_so_suff.size() &&
      full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(),
                        caffe_parser_so_suff) == 0;
  if (is_caffe_parser) {
    caffe_parser_path = full_name;
  } else {
    // Save parser so path into the file_list vector.
    file_list.push_back(full_name);
  }
}
} // namespace ge

+ 0
- 62
parser/common/tbe_plugin_loader.h View File

@@ -1,62 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_TBE_PLUGIN_LOADER_H_
#define PARSER_COMMON_TBE_PLUGIN_LOADER_H_

#include <dlfcn.h>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>

#include "external/ge/ge_api_error_codes.h"
#include "external/register/register.h"

namespace ge {
// Collection of dlopen() handles owned by the loader.
using SoHandlesVec = std::vector<void *>;

// Singleton that discovers and dlopen()s TBE custom-op parser plugin .so files,
// keeping their handles open until Finalize() is called.
class TBEPluginLoader {
 public:
  // Closes every loaded plugin handle; returns FAILED if any dlclose() fails.
  Status Finalize();

  // Get TBEPluginLoader singleton instance.
  static TBEPluginLoader& Instance();

  // Loads all parser plugin .so files found under the custom-op path.
  // 'options' may carry "ge.frameworkType" to select the built-in subdirectory.
  void LoadPluginSo(const std::map<string, string> &options);

  // Returns the directory (with trailing '/') containing this shared library,
  // or an empty string on failure.
  static string GetPath();

 private:
  TBEPluginLoader() = default;
  ~TBEPluginLoader() = default;
  // Closes and forgets all handles in handles_vec_.
  Status ClearHandles_();
  // Routes one .so path: lib_caffe_parser.so -> caffe_parser_path, others -> file_list.
  static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name,
                                const string &caffe_parser_so_suff, const string &aicpu_so_suff,
                                const string &aicpu_host_so_suff);
  // Builds the plugin search path from $ASCEND_OPP_PATH or the library location.
  static void GetCustomOpPath(std::string &customop_path);
  // Splits 'path' on ':' and scans each directory for plugin .so files.
  static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path);
  // Recursive directory walk collecting parser .so files.
  static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path);

  // Handles kept open for the process lifetime; closed in ClearHandles_().
  SoHandlesVec handles_vec_;
  // Options captured at LoadPluginSo() time (e.g. framework type).
  static std::map<string, string> options_;
};
} // namespace ge

#endif //PARSER_COMMON_TBE_PLUGIN_LOADER_H_

+ 0
- 78
parser/common/thread_pool.cc View File

@@ -1,78 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/thread_pool.h"

#include <atomic>
#include <functional>
#include <queue>
#include <stdexcept>
#include <utility>
#include <vector>

#include "register/register_types.h"

namespace ge {
// Spawns max(size, 1) worker threads, each running ThreadFunc on this pool.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false) {
  // Guarantee at least one worker even when size == 0.
  idle_thrd_num_ = (size < 1) ? 1 : size;
  for (uint32_t idx = 0; idx < idle_thrd_num_; ++idx) {
    pool_.emplace_back(ThreadFunc, this);
  }
}

// Stops the pool: signals shutdown, wakes all workers, and joins them.
// Join errors are swallowed so destruction never throws.
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::~ThreadPool() {
  is_stoped_.store(true);
  {
    // Notify under the lock so a worker between its predicate check and wait
    // cannot miss the wake-up.
    std::unique_lock<std::mutex> lock{m_lock_};
    cond_var_.notify_all();
  }

  for (std::thread &thd : pool_) {
    if (thd.joinable()) {
      try {
        thd.join();
      } catch (const std::system_error &) {
        GELOGW("system_error");
      } catch (...) {
        GELOGW("exception");
      }
    }
  }
}

// Worker loop executed by each pool thread: blocks until a task is queued or
// shutdown is requested, then runs tasks one at a time until the pool stops.
void ThreadPool::ThreadFunc(ThreadPool *thread_pool) {
  if (thread_pool == nullptr) {
    return;
  }
  while (!thread_pool->is_stoped_) {
    std::function<void()> task;
    {
      std::unique_lock<std::mutex> lock{thread_pool->m_lock_};
      // Sleep until shutdown is flagged or a task is available.
      thread_pool->cond_var_.wait(
          lock, [thread_pool] { return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); });
      // Drain remaining tasks on shutdown; exit only when the queue is empty.
      if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) {
        return;
      }
      // Pop the next task while still holding the lock.
      task = std::move(thread_pool->tasks_.front());
      thread_pool->tasks_.pop();
    }
    // Run the task outside the lock; track the idle-thread count around it.
    --thread_pool->idle_thrd_num_;
    task();
    ++thread_pool->idle_thrd_num_;
  }
}
} // namespace ge

+ 0
- 83
parser/common/thread_pool.h View File

@@ -1,83 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_THREAD_POOL_H_
#define PARSER_COMMON_THREAD_POOL_H_

#include <atomic>
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <queue>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>

#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "external/ge/ge_api_error_codes.h"
#include "graph/types.h"
#include "parser/common/acl_graph_parser_util.h"

namespace ge {
// A unit of work executed by a pool thread.
using ThreadTask = std::function<void()>;

// Fixed-size thread pool. Tasks are submitted via commit() and executed by
// worker threads running ThreadFunc(); the destructor stops and joins workers.
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ThreadPool {
 public:
  // Starts 'size' worker threads (at least one).
  explicit ThreadPool(uint32_t size = 4);
  ~ThreadPool();

  // Queues func(args...) for execution and returns a future for its result.
  // On failure (pool stopped, or allocation failure) returns a
  // default-constructed, invalid future — callers should check validity.
  template <class Func, class... Args>
  auto commit(Func &&func, Args &&... args) -> std::future<decltype(func(args...))> {
    GELOGD("commit run task enter.");
    using retType = decltype(func(args...));
    std::future<retType> fail_future;
    if (is_stoped_.load()) {
      GELOGE(ge::FAILED, "thread pool has been stopped.");
      return fail_future;
    }

    // Wrap the call in a packaged_task shared with the worker lambda.
    auto bindFunc = std::bind(std::forward<Func>(func), std::forward<Args>(args)...);
    auto task = ge::parser::MakeShared<std::packaged_task<retType()>>(bindFunc);
    if (task == nullptr) {
      GELOGE(ge::FAILED, "Make shared failed.");
      return fail_future;
    }
    std::future<retType> future = task->get_future();
    {
      std::lock_guard<std::mutex> lock{m_lock_};
      tasks_.emplace([task]() { (*task)(); });
    }
    // Wake one idle worker to pick up the new task.
    cond_var_.notify_one();
    GELOGD("commit run task end");
    return future;
  }

  // Worker entry point; loops until the pool is stopped and the queue drained.
  static void ThreadFunc(ThreadPool *thread_pool);

 private:
  std::vector<std::thread> pool_;       // Worker threads.
  std::queue<ThreadTask> tasks_;        // Pending tasks, guarded by m_lock_.
  std::mutex m_lock_;                   // Protects tasks_.
  std::condition_variable cond_var_;    // Signals task arrival / shutdown.
  std::atomic<bool> is_stoped_;         // Set by the destructor to stop workers.
  std::atomic<uint32_t> idle_thrd_num_; // Count of workers not running a task.
};
} // namespace ge

#endif // PARSER_COMMON_THREAD_POOL_H_

+ 2
- 2
parser/func_to_graph/proto_python_rule.mk View File

@@ -2,7 +2,7 @@ include $(BUILD_SYSTEM)/base_rules.mk
FUNCTION_TO_GRAPH_OUT_TIMESTAMP := $(HOST_OUT_ROOT)/func_to_graph/.timestamp
PROTO_SRC_DIR = parser/parser/func_to_graph/proto
PROTO_SRC_DIR = framework/domi/parser/func_to_graph/proto
PY_PROTO_BUILD_DIR = $(HOST_OUT_ROOT)/tmp/function_to_graph/proto
$(warning PRIVATE_PROTOC is $(PRIVATE_PROTOC))
@@ -14,4 +14,4 @@ $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP): $(PRIVATE_PROTOC)
$(LOCAL_BUILT_MODULE): $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP)
mkdir -p $@
cp -rf $(PY_PROTO_BUILD_DIR)/* $@
cp -rf $(PY_PROTO_BUILD_DIR)/* $@

+ 5
- 6
parser/module.mk View File

@@ -1,6 +1,6 @@

LOCAL_PATH := $(call my-dir)
include $(LOCAL_PATH)/stub/Makefile
include $(LOCAL_PATH)/../stub/Makefile
COMMON_LOCAL_C_INCLUDES := \
proto/om.proto \
proto/insert_op.proto \
@@ -39,9 +39,7 @@ COMMON_LOCAL_C_INCLUDES := \
$(TOPDIR)inc/external/graph \
$(TOPDIR)inc/external/parser \
$(TOPDIR)inc/framework \
$(TOPDIR)parser/parser \
$(TOPDIR)parser \
$(TOPDIR)graphengine/ge \
$(TOPDIR)framework/domi/parser \
libc_sec/include \
third_party/protobuf/include \
third_party/json/include \
@@ -115,6 +113,7 @@ LOCAL_SHARED_LIBRARIES := \
libparser_common \
libgraph \
libregister \
libge_common \
lib_caffe_parser \

LOCAL_LDFLAGS := -lrt
@@ -134,8 +133,8 @@ endif

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)

LOCAL_SRC_FILES := ../../out/parser/lib64/stub/tensorflow_parser.cc
LOCAL_SRC_FILES += ../../out/parser/lib64/stub/caffe_parser.cc
LOCAL_SRC_FILES := ../../../out/ge/lib64/stub/tensorflow_parser.cc
LOCAL_SRC_FILES += ../../../out/ge/lib64/stub/caffe_parser.cc


LOCAL_SHARED_LIBRARIES :=


+ 2
- 3
parser/onnx/module.mk View File

@@ -29,9 +29,7 @@ LOCAL_C_INCLUDES := \
$(TOPDIR)inc/external \
$(TOPDIR)inc/external/graph \
$(TOPDIR)inc/framework \
$(TOPDIR)parser \
$(TOPDIR)parser/parser \
$(TOPDIR)graphengine/ge \
$(TOPDIR)framework/domi/parser \
libc_sec/include \
third_party/protobuf/include \
third_party/json/include \
@@ -45,6 +43,7 @@ LOCAL_SHARED_LIBRARIES := \
libparser_common \
libgraph \
libregister \
libge_common \

LOCAL_LDFLAGS := -lrt



+ 1
- 2
parser/onnx/onnx_constant_parser.cc View File

@@ -17,7 +17,7 @@
#include "onnx_constant_parser.h"
#include <map>
#include <vector>
#include "parser/common/acl_graph_parser_util.h"
#include "common/ge/ge_util.h"
#include "common/util.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/ge_tensor.h"
@@ -30,7 +30,6 @@ using ge::onnx::TensorProto;
using domi::ONNX;
using GeShape = ge::GeShape;
using GeTensorDesc = ge::GeTensorDesc;
using namespace ge::parser;

namespace ge {
Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count) {


+ 0
- 1
parser/onnx/onnx_data_parser.cc View File

@@ -22,7 +22,6 @@
#include "parser/onnx/onnx_util.h"

using domi::ONNX;
using namespace ge::parser;

namespace ge {
Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) {


+ 12
- 9
parser/onnx/onnx_parser.cc View File

@@ -18,25 +18,24 @@
#include <algorithm>
#include <iostream>
#include "common/convert/pb2json.h"
#include "common/model_saver.h"
#include "common/util.h"
#include "external/graph/operator_factory.h"
#include "external/register/register_error_codes.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "framework/omg/parser/parser_types.h"
#include "omg/parser/parser_factory.h"
#include "onnx_op_parser.h"
#include "onnx_util.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/pre_checker.h"
#include "parser/common/acl_graph_parser_util.h"
#include "parser/common/model_saver.h"
#include "parser/common/parser_utils.h"
#include "parser/onnx/onnx_util.h"
#include "register/op_registry.h"

namespace ge {
namespace {
std::map<std::string, std::string> kOnnxOpMap = {
{ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT},
{ge::kOpTypeInput, ge::DATA}, {ge::kOpTypeConstant, ge::CONSTANT},
};
}

@@ -256,9 +255,11 @@ Status OnnxModelParser::SetOperatorInputs() {
for (auto in_iter = inputs_map_.begin(); in_iter != inputs_map_.end(); in_iter++) {
auto out_iter = outputs_map_.find(in_iter->first);
if (out_iter == outputs_map_.end()) {
GELOGE(INTERNAL_ERROR, "Unknown input: %s:%d in node: %s", in_iter->first.c_str(), in_iter->second[0].second,
GELOGW("Unknown input: %s:%d for node: %s, which maybe option input.",
in_iter->first.c_str(),
in_iter->second[0].second,
in_iter->second[0].first.c_str());
return INTERNAL_ERROR;
continue;
}

std::vector<std::pair<std::string, int>> &input_node_indexs = in_iter->second;
@@ -438,7 +439,7 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {

// 1. Get graph from onnx model file.
ge::onnx::ModelProto onnx_model;
if (!ge::parser::ReadProtoFromBinaryFile(file, &onnx_model)) {
if (!ge::ReadProtoFromBinaryFile(file, &onnx_model)) {
GELOGE(PARAM_INVALID, "Read onnx model file failed.");
return FAILED;
}
@@ -535,6 +536,8 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) {
}
graph.SetInputs(input_ops).SetOutputs(output_indexs);

GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph));

UpdateFormat(graph);

GELOGI("Onnx model parser success.");
@@ -552,12 +555,12 @@ Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) {
}

ge::onnx::ModelProto onnx_model;
GE_RETURN_WITH_LOG_IF_FALSE(ge::parser::ReadProtoFromBinaryFile(model_file, &onnx_model),
GE_RETURN_WITH_LOG_IF_FALSE(ge::ReadProtoFromBinaryFile(model_file, &onnx_model),
"ReadProtoFromBinaryFile failed, file:%s.", model_file);
ge::onnx::GraphProto graph_proto = onnx_model.graph();
nlohmann::json j;
ge::Pb2Json::Message2Json(graph_proto, std::set<std::string>(), j, true);
return ge::parser::ModelSaver::SaveJsonToFile(json_file, j);
return ge::ModelSaver::SaveJsonToFile(json_file, j);
}

ge::DataType OnnxModelParser::ConvertToGeDataType(const uint32_t type) {


+ 0
- 3
parser/proto/insert_op.proto View File

@@ -45,9 +45,6 @@ message AippOpParams {
// 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。
uint32 related_input_rank = 2;
// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;
// input_edge_idx参数为可选,类型为整型,配置范围为>=0。
// 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。
// 配置值 <= Data算子输出边的个数。


+ 6
- 7
parser/tensorflow/graph_functiondef.cc View File

@@ -18,8 +18,7 @@
#include <iostream>
#include "common/fmk_error_codes.h"
#include "graph/debug/ge_attr_define.h"
#include "framework/omg/parser/parser_types.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/types.h"
#include "common/types_map.h"
#include "common/util.h"
#include "graph/op_desc.h"
@@ -219,7 +218,7 @@ domi::Status GraphToFunctionDef::RecordResult(ge::ComputeGraphPtr graph,

string op_name = anchor->GetOwnerNode()->GetName() + "_" + to_string(anchor->GetIdx()) + "_retval";
ge::OpDescPtr op = nullptr;
GE_MAKE_SHARED(op = std::make_shared<ge::OpDesc>(op_name, ge::parser::NETOUTPUT), return FAILED);
GE_MAKE_SHARED(op = std::make_shared<ge::OpDesc>(op_name, ge::NETOUTPUT), return FAILED);
graphStatus status = op->AddInputDesc(ge::GeTensorDesc());
if (status != GRAPH_SUCCESS) {
GELOGE(FAILED, "Add input desc for op:%s failed.", op->GetName().c_str());
@@ -282,7 +281,7 @@ domi::Status GraphToFunctionDef::RecordArg(ge::ComputeGraphPtr graph, const vect
string op_name = anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName() + "_" +
to_string(anchor->GetPeerOutAnchor()->GetIdx()) + "_arg";
ge::OpDescPtr op = nullptr;
GE_MAKE_SHARED(op = std::make_shared<ge::OpDesc>(op_name, ge::parser::DATA), return FAILED);
GE_MAKE_SHARED(op = std::make_shared<ge::OpDesc>(op_name, ge::DATA), return FAILED);
graphStatus status = op->AddOutputDesc(ge::GeTensorDesc());
if (status != GRAPH_SUCCESS) {
GELOGE(FAILED, "Add output desc for op:%s failed.", op->GetName().c_str());
@@ -330,7 +329,7 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph

for (const ge::NodePtr &node : graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
if (node->GetOpDesc()->GetType() == ge::parser::DATA) {
if (node->GetOpDesc()->GetType() == ge::DATA) {
int64_t index = 0;

int64_t type = 1;
@@ -351,7 +350,7 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph
continue;
}

if (node->GetOpDesc()->GetType() == ge::parser::NETOUTPUT) {
if (node->GetOpDesc()->GetType() == ge::NETOUTPUT) {
int64_t index = 0;
int64_t type = 1;

@@ -475,7 +474,7 @@ domi::Status GraphToFunctionDef::BuildFunctionDef(ge::ComputeGraphPtr &graph, co
GE_CHECK_NOTNULL(library);
GE_CHECK_NOTNULL(call_node_def);
// Current date / time base on the current system
string now_time = ge::parser::CurrentTimeInStr();
string now_time = ge::CurrentTimeInStr();
static int i = 0;
const string name = name_in + now_time + to_string(i);
i++;


+ 1
- 1
parser/tensorflow/graph_insert_trans_op.h View File

@@ -21,7 +21,7 @@
#include <vector>
#include "common/fmk_types.h"
#include "common/op/ge_op_utils.h"
#include "framework/omg/parser/parser_types.h"
#include "common/types.h"
#include "graph/compute_graph.h"
#include "graph/node.h"
#include "graph/types.h"


+ 128
- 153
parser/tensorflow/graph_optimizer.cc View File

@@ -23,13 +23,13 @@
#include "cce/cce.h"
#include "cce/dnn.h"
#include "common/debug/log.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/math/math_util.h"
#include "common/op/ge_op_utils.h"
#include "common/op_map.h"
#include "common/types.h"
#include "common/types_map.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "framework/omg/parser/parser_types.h"
#include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_tensor.h"
@@ -39,7 +39,6 @@
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph_functiondef.h"
#include "parser/common/acl_graph_parser_util.h"
#include "proto/tensorflow/attr_value.pb.h"
#include "register/op_registry.h"

@@ -92,137 +91,117 @@ const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal";

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<string, OpSupportTranInfo> g_OpSupportTranInfo = {};

TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportUndefined)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportUndefined)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::MUL,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::MUL,
std::vector<ge::Format>({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC,
ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}),
InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::L2LOSS,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::L2LOSS,
std::vector<ge::Format>({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NC1HWC0, ge::FORMAT_NHWC,
ge::FORMAT_HWCN}), // inputformats
ge::DT_FLOAT, ge::FORMAT_NC1HWC0, ge::DT_FLOAT)

TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONVGRADFILTER, InFmtSupportUndefined, InDtSupportUndefined,
ge::FORMAT_FRACTAL_Z, ge::DT_FLOAT)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONV2DBACKPROPINPUT, InFmtSupportUndefined, InDtSupportUndefined,
ge::FORMAT_NC1HWC0, ge::DT_FLOAT16)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::BIASADDGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CONVGRADFILTER, InFmtSupportUndefined, InDtSupportUndefined, ge::FORMAT_FRACTAL_Z,
ge::DT_FLOAT)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::BIASADD, ge::FORMAT_NCHW, ge::DT_FLOAT, ge::FORMAT_NCHW, ge::DT_FLOAT)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CONV2DBACKPROPINPUT, InFmtSupportUndefined, InDtSupportUndefined,
ge::FORMAT_NC1HWC0, ge::DT_FLOAT16)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::BIASADDGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::DT_FLOAT)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::BIASADD, ge::FORMAT_NCHW, ge::DT_FLOAT, ge::FORMAT_NCHW, ge::DT_FLOAT)

TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ACTIVATION, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0,
ge::DT_FLOAT16)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ACTIVATIONGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0,
ge::DT_FLOAT16)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::SOFTMAX, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ACTIVATION, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ACTIVATIONGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0,
ge::DT_FLOAT16)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SOFTMAX, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::SOFTMAX, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SOFTMAX, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput)

TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DBACKPROPFILTER, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DEPTHWISECONV2DBACKPROPFILTER, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16,
ge::FORMAT_C1HWNCoC0, ge::DT_FLOAT)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DBACKPORPINPUT, InFmtSupportUndefined, InDtSupportUndefined,
OutFmtSupportAsInput, OutDtSupportUndefined)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DFORWARDNATIVE, InFmtSupportUndefined, InDtSupportUndefined,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DEPTHWISECONV2DBACKPORPINPUT, InFmtSupportUndefined, InDtSupportUndefined,
OutFmtSupportAsInput, OutDtSupportUndefined)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::FUSEDBATCHNORM, InFmtSupportUndefined, InDtSupportUndefined,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DEPTHWISECONV2DFORWARDNATIVE, InFmtSupportUndefined, InDtSupportUndefined,
OutFmtSupportAsInput, OutDtSupportUndefined)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::FUSEDBATCHNORMGRAD, InFmtSupportUndefined, InDtSupportUndefined,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::FUSEDBATCHNORM, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput,
OutDtSupportUndefined)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::FUSEDBATCHNORMGRAD, InFmtSupportUndefined, InDtSupportUndefined,
OutFmtSupportAsInput, OutDtSupportUndefined)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONV2D, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CONV2D, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput,
OutDtSupportUndefined)

TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::RESHAPE, ge::FORMAT_NHWC, InDtSupportAll, ge::FORMAT_NHWC,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, InFmtSupport5D, ge::DT_FLOAT16,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::RESHAPE, ge::FORMAT_NHWC, InDtSupportAll, ge::FORMAT_NHWC, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, InFmtSupport5D, ge::DT_FLOAT16,
OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MAXIMUM_GRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::APPLYRMSPROP,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::APPLYRMSPROP,
std::vector<ge::Format>({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC,
ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}),
ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DROPOUTDOMASK, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::LOG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQRTGRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SIGMOIDGRAD, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SIGMOID, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ARGMAX, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::AVGPOOLGRAD, InFmtSupport5D, ge::DT_FLOAT16, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::NEG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::RECIPROCAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DROPOUTDOMASK, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQUARE, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::LOG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SQRTGRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SIGMOIDGRAD, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SUB, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SIGMOID, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::ARGMAX, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::AVGPOOLGRAD, InFmtSupport5D, ge::DT_FLOAT16, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SUM, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::NEG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::RECIPROCAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SQUARE, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SUB, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SUM, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MATMUL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GATHERV2, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::GATHERV2, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::GREATEREQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GREATEREQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::REALDIV, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQRT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::STRIDEDSLICE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::TILE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::REALDIV, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SQRT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::STRIDEDSLICE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TILE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TFRELU6, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::RELU6GRAD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::EQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GREATER, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SELECT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::RELU6GRAD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::EQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::GREATER, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SELECT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_BATCH_MATMUL, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::TRANSPOSE, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput,
TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TRANSPOSE, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput,
OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::STREAMMERGE,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::STREAMMERGE,
std::vector<ge::Format>({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}),
InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput)
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::MEMCPYASYNC,
TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::MEMCPYASYNC,
std::vector<ge::Format>({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}),
InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput)

bool GetCceTbeTransInfo(string opType, OpSupportTranInfo &opSupportInfo) {
static bool fmtInited = false;
GE_IF_BOOL_EXEC(
!fmtInited, fmtInited = true;
if (domi::OpRegistry().Instance()->GetImplyType(ge::parser::DEPTHWISEWEIGHT4D26D) == domi::ImplyType::TVM) {
auto it = g_OpSupportTranInfo.find(string("TBE:") + ge::parser::MUL);
if (it != g_OpSupportTranInfo.end()) {
auto &fmts = it->second.inputFormats;
auto itFmt = std::find(fmts.begin(), fmts.end(), ge::FORMAT_NC1HWC0);
fmts.erase(itFmt);
}
})
!fmtInited, fmtInited = true;
if (domi::OpRegistry().Instance()->GetImplyType(ge::DEPTHWISEWEIGHT4D26D) == domi::ImplyType::TVM) {
auto it = g_OpSupportTranInfo.find(string("TBE:") + ge::MUL);
if (it != g_OpSupportTranInfo.end()) {
auto &fmts = it->second.inputFormats;
auto itFmt = std::find(fmts.begin(), fmts.end(), ge::FORMAT_NC1HWC0);
fmts.erase(itFmt);
}
})
string cceTbeOpType = "TBE";
GE_IF_BOOL_EXEC(domi::OpRegistry().Instance()->GetImplyType(opType) == domi::ImplyType::BUILDIN,
cceTbeOpType = "CCE";)
@@ -807,7 +786,7 @@ Status CreateNodeDefBytes(ge::NodePtr n, string originalType, map<string, PIOLis
for (uint32_t j = 0; j < ge_desc->GetShape().GetDimNum(); ++j) {
tmp_dim = ge_desc->GetShape().GetDim(j);
GE_CHECK_GE(tmp_dim, 0);
PARSER_INT64_MULCHECK(real_size, tmp_dim);
FMK_INT64_MULCHECK(real_size, tmp_dim);
real_size *= tmp_dim;
}
ge::TensorUtils::SetSize(*ge_desc, real_size * size_type);
@@ -1198,7 +1177,7 @@ Status CreateFuncDefBytes(ge::NodePtr n, string original_type, string func_bin_p

char *buf = nullptr;
int32_t len = 0;
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::parser::ReadBytesFromBinaryFile(file.c_str(), &buf, len), return false,
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::ReadBytesFromBinaryFile(file.c_str(), &buf, len), return false,
"read bytes file error!");

GELOGI("len =%d\n", len);
@@ -1229,7 +1208,7 @@ Status ParserGraphOptimizer::MakeTfProtoDef() {
CreateIOListFuncMap(mOpIOListFuncMap);

for (ge::NodePtr n : graph_->GetDirectNode()) {
if (n->GetType() != ge::parser::FRAMEWORKOP) continue;
if (n->GetType() != ge::FRAMEWORKOP) continue;
std::string original_type;
GE_LOGI_IF(ge::AttrUtils::GetStr(n->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type) != true,
"get original type failed.");
@@ -1290,9 +1269,9 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>
bool hasGetNext = false;
for (auto node : graph_->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue);
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::FRAMEWORK_OP_TYPE, continue);
string type = "";
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
GE_CHK_STATUS_RET(GetOriginalType(node, type));
if (type == "IteratorGetNext") {
hasGetNext = true;
break;
@@ -1300,9 +1279,9 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>
}
for (auto node : graph_->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue)
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::FRAMEWORK_OP_TYPE, continue)
string type = "";
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
GE_CHK_STATUS_RET(GetOriginalType(node, type));
if (type == "IteratorGetNext") {
vector<NodePtr> temp_node_cluser;
for (auto in_anchor : node->GetAllInDataAnchors()) {
@@ -1338,9 +1317,9 @@ Status ParserGraphOptimizer::FindFmkNodeCluser(unordered_map<string, vector<Node
GE_CHECK_NOTNULL(node);
OpDescPtr temp_node_desc_ptr = node->GetOpDesc();
GE_CHECK_NOTNULL(temp_node_desc_ptr);
GE_IF_BOOL_EXEC(temp_node_desc_ptr->GetType() == ge::parser::DATA_TYPE, continue);
GE_IF_BOOL_EXEC(temp_node_desc_ptr->GetType() == ge::DATA_TYPE, continue);

if (temp_node_desc_ptr->GetType() == ge::parser::FRAMEWORK_OP_TYPE &&
if (temp_node_desc_ptr->GetType() == ge::FRAMEWORK_OP_TYPE &&
(temp_node_desc_ptr->GetName().find(RRTVAL_NODE_NAME_SUFFIX) == string::npos)) {
temp_node_cluser.push_back(node);
} else {
@@ -1421,7 +1400,7 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) {
return FAILED);

std::string type = "";
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(nodes[0], type));
GE_CHK_STATUS_RET(GetOriginalType(nodes[0], type));
(void)AttrUtils::SetStr(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);

(void)AttrUtils::SetZeroCopyBytes(
@@ -1431,7 +1410,7 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) {
fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length()));

(void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type);
(void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, domi::GetContext().type);

// reconstruct fusion_node and edges
GE_CHK_STATUS_RET(RebuildOutputAnchors(output_anchors, fusion_node_opdef),
@@ -1481,19 +1460,17 @@ Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge
}

InControlAnchorPtr node_in_control = node->GetInControlAnchor();
GE_IF_BOOL_EXEC(
node_in_control != nullptr, for (auto peer_out_anchor
: node_in_control->GetPeerOutControlAnchors()) {
vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_out_anchor->GetOwnerNode());
GE_IF_BOOL_EXEC(iter == nodes.end(), input_control_anchors.emplace_back(node_in_control));
});
GE_IF_BOOL_EXEC(node_in_control != nullptr, for (auto peer_out_anchor
: node_in_control->GetPeerOutControlAnchors()) {
vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_out_anchor->GetOwnerNode());
GE_IF_BOOL_EXEC(iter == nodes.end(), input_control_anchors.emplace_back(node_in_control));
});
OutControlAnchorPtr node_out_control = node->GetOutControlAnchor();
GE_IF_BOOL_EXEC(
node_out_control != nullptr, for (auto peer_in_control_anchor
: node_out_control->GetPeerInControlAnchors()) {
vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_in_control_anchor->GetOwnerNode());
GE_IF_BOOL_EXEC(iter == nodes.end(), output_control_anchors.emplace_back(node_out_control));
});
GE_IF_BOOL_EXEC(node_out_control != nullptr, for (auto peer_in_control_anchor
: node_out_control->GetPeerInControlAnchors()) {
vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_in_control_anchor->GetOwnerNode());
GE_IF_BOOL_EXEC(iter == nodes.end(), output_control_anchors.emplace_back(node_out_control));
});
}
return SUCCESS;
}
@@ -1518,19 +1495,18 @@ Status ParserGraphOptimizer::LinkInnerAnchor(unordered_map<string, ge::NodePtr>
}

InControlAnchorPtr node_in_control = node->GetInControlAnchor();
GE_IF_BOOL_EXEC(
node_in_control != nullptr, for (auto peer_out_ctl_anchor
: node_in_control->GetPeerOutControlAnchors()) {
GE_IF_BOOL_EXEC(node_map.count(peer_out_ctl_anchor->GetOwnerNode()->GetName()) == 0, continue);
NodePtr src_ctrl = node_map[peer_out_ctl_anchor->GetOwnerNode()->GetName()];
GE_IF_BOOL_EXEC(
ge::GraphUtils::AddEdge(src_ctrl->GetOutControlAnchor(), dst->GetInControlAnchor()) != GRAPH_SUCCESS,
GELOGE(FAILED,
"LinkInnerAnchor Link control anchor failed, src node: "
"%s, dst node: %s.",
src_ctrl->GetName().c_str(), dst->GetName().c_str());
return FAILED);
});
GE_IF_BOOL_EXEC(node_in_control != nullptr, for (auto peer_out_ctl_anchor
: node_in_control->GetPeerOutControlAnchors()) {
GE_IF_BOOL_EXEC(node_map.count(peer_out_ctl_anchor->GetOwnerNode()->GetName()) == 0, continue);
NodePtr src_ctrl = node_map[peer_out_ctl_anchor->GetOwnerNode()->GetName()];
GE_IF_BOOL_EXEC(
ge::GraphUtils::AddEdge(src_ctrl->GetOutControlAnchor(), dst->GetInControlAnchor()) != GRAPH_SUCCESS,
GELOGE(FAILED,
"LinkInnerAnchor Link control anchor failed, src node: "
"%s, dst node: %s.",
src_ctrl->GetName().c_str(), dst->GetName().c_str());
return FAILED);
});
}
return SUCCESS;
}
@@ -1881,24 +1857,24 @@ OpDescPtr ParserGraphOptimizer::CreateTranslateOp(enum ge::Format inFormat, enum
static uint32_t transop_count = 0;
OpDescPtr op_def = nullptr;
std::stringstream sstmp;
sstmp << "translate_" << ge::parser::TRANSDATA << "_" << transop_count++;
GE_MAKE_SHARED(op_def = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::parser::TRANSLATE), op_def = nullptr;
sstmp << "translate_" << ge::TRANSDATA << "_" << transop_count++;
GE_MAKE_SHARED(op_def = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::TRANSLATE), op_def = nullptr;
return op_def);
GELOGI(
"create translate op:%s, input format:%s, input datatype:%s, output "
"format:%s, output datatype:%s.",
op_def->GetName().c_str(), ge::TypeUtils::FormatToSerialString(inFormat).c_str(),
ge::TypeUtils::DataTypeToSerialString(inDatatype).c_str(), ge::TypeUtils::FormatToSerialString(outFormat).c_str(),
ge::TypeUtils::DataTypeToSerialString(outDatatype).c_str());
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_INPUT_FORMAT, inFormat), return nullptr,
"SetInt ATTR_NAME_INPUT_FORMAT failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_INPUT_DATATYPE, inDatatype), return nullptr,
"SetInt ATTR_NAME_INPUT_DATATYPE failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_OUTPUT_FORMAT, outFormat), return nullptr,
"SetInt ATTR_NAME_INPUT_DATATYPE failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_OUTPUT_DATATYPE, outDatatype), return nullptr,
"SetInt ATTR_NAME_INPUT_DATATYPE failed.");
"create translate op:%s, input format:%s, input datatype:%s, output "
"format:%s, output datatype:%s.",
op_def->GetName().c_str(), ge::TypeUtils::FormatToSerialString(inFormat).c_str(),
ge::TypeUtils::DataTypeToSerialString(inDatatype).c_str(), ge::TypeUtils::FormatToSerialString(outFormat).c_str(),
ge::TypeUtils::DataTypeToSerialString(outDatatype).c_str());
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_INPUT_FORMAT, inFormat),
return nullptr, "SetInt ATTR_NAME_INPUT_FORMAT failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_INPUT_DATATYPE, inDatatype),
return nullptr, "SetInt ATTR_NAME_INPUT_DATATYPE failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_OUTPUT_FORMAT, outFormat),
return nullptr, "SetInt ATTR_NAME_INPUT_DATATYPE failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_OUTPUT_DATATYPE, outDatatype),
return nullptr, "SetInt ATTR_NAME_INPUT_DATATYPE failed.");
if (inDatatype != ge::DT_FLOAT16) {
GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddInputDesc(GeTensorDesc(GeShape(), inFormat)), return nullptr,
"create translate op:add input desc fail.");
@@ -1920,17 +1896,17 @@ OpDescPtr ParserGraphOptimizer::CreatePermuteOp(enum ge::Format input_format, en
static uint32_t transop_count = 0;

std::stringstream sstmp;
sstmp << "transdata_" << ge::parser::PERMUTE << "_" << transop_count++;
sstmp << "transdata_" << ge::PERMUTE << "_" << transop_count++;

OpDescPtr op_desc = nullptr;
GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::parser::PERMUTE), op_desc = nullptr;
GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::PERMUTE), op_desc = nullptr;
return op_desc);
GELOGI("create permute op:%s", op_desc->GetName().c_str());

GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), return nullptr,
"SetInt ATTR_NAME_INPUT_FORMAT failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), return nullptr,
"SetInt ATTR_NAME_OUTPUT_FORMAT failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format),
return nullptr, "SetInt ATTR_NAME_INPUT_FORMAT failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format),
return nullptr, "SetInt ATTR_NAME_OUTPUT_FORMAT failed.");

GE_IF_BOOL_EXEC(input_format == FORMAT_NCHW, (void)AttrUtils::SetInt(op_desc, "NCHW_to_NHWC", (int64_t)1));
GE_IF_BOOL_EXEC(input_format == FORMAT_NHWC, (void)AttrUtils::SetInt(op_desc, "NHWC_to_NCHW", (int64_t)1));
@@ -1947,11 +1923,10 @@ OpDescPtr ParserGraphOptimizer::CreateCastOp(enum ge::DataType input_data_type,
enum ge::Format format) {
static uint32_t transop_count = 0;
std::stringstream sstmp;
sstmp << "transdata_" << ge::parser::CAST << "_" << transop_count++;
sstmp << "transdata_" << ge::CAST << "_" << transop_count++;

OpDescPtr op_desc = nullptr;
GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::parser::CAST), op_desc = nullptr;
return op_desc);
GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::CAST), op_desc = nullptr; return op_desc);
GELOGI("create cast op:%s, input datatype:%s, out datatype:%s", op_desc->GetName().c_str(),
ge::TypeUtils::DataTypeToSerialString(input_data_type).c_str(),
ge::TypeUtils::DataTypeToSerialString(output_data_type).c_str());
@@ -1975,10 +1950,10 @@ OpDescPtr ParserGraphOptimizer::CreateCastOp(enum ge::DataType input_data_type,
OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format) {
static uint32_t transop_count = 0;
std::stringstream sstmp;
sstmp << "transdata_" << ge::parser::TRANSDATA << "_" << transop_count++;
sstmp << "transdata_" << ge::TRANSDATA << "_" << transop_count++;

OpDescPtr op_desc = nullptr;
GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::parser::TRANSDATA), op_desc = nullptr;
GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::TRANSDATA), op_desc = nullptr;
return op_desc);

GELOGI("create transdata op:%s, input format:%s.", op_desc->GetName().c_str(),
@@ -1989,10 +1964,10 @@ OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format)
output_format = FORMAT_NCHW;
}

GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), return nullptr,
"SetInt of ATTR_NAME_INPUT_FORMAT failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), return nullptr,
"SetInt of ATTR_NAME_OUTPUT_FORMAT failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format),
return nullptr, "SetInt of ATTR_NAME_INPUT_FORMAT failed.");
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format),
return nullptr, "SetInt of ATTR_NAME_OUTPUT_FORMAT failed.");
GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddInputDesc(GeTensorDesc(GeShape(), input_format)), return nullptr,
"create transdata op:add input desc fail.");
GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddOutputDesc(GeTensorDesc(GeShape(), output_format)), return nullptr,
@@ -2000,4 +1975,4 @@ OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format)

return op_desc;
}
} // namespace ge
} // namespace domi

+ 9
- 10
parser/tensorflow/graph_optimizer.h View File

@@ -20,7 +20,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "framework/omg/parser/parser_types.h"
#include "common/types.h"
#include "graph/anchor.h"
#include "graph/compute_graph.h"
#include "graph/node.h"
@@ -46,9 +46,8 @@ class ParserGraphOptimizer {
domi::Status FusionFmkop();

inline bool IsHCOMOp(const string &op_type) {
return (op_type == ge::parser::HCOMALLREDUCE) || (op_type == ge::parser::HCOMALLGATHER) ||
(op_type == ge::parser::HCOMBROADCAST) || (op_type == ge::parser::HCOMSEND) ||
(op_type == ge::parser::HCOMRECEIVE) || (op_type == "HcomReduceScatter");
return (op_type == ge::HCOMALLREDUCE) || (op_type == ge::HCOMALLGATHER) || (op_type == ge::HCOMBROADCAST) ||
(op_type == ge::HCOMSEND) || (op_type == ge::HCOMRECEIVE) || (op_type == "HcomReduceScatter");
}

void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; }
@@ -104,11 +103,11 @@ class ParserGraphOptimizer {
domi::Status UpdateGraph(vector<ge::NodePtr> &nodes);

domi::Status InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge::NodePtr> &nodes,
vector<ge::InDataAnchorPtr> &input_anchors, vector<ge::OutDataAnchorPtr> &output_anchors,
map<ge::OutDataAnchorPtr, vector<ge::InDataAnchorPtr>> &output_in_map,
vector<ge::InControlAnchorPtr> &input_control_anchors,
vector<ge::OutControlAnchorPtr> &output_control_anchors,
unordered_map<string, ge::NodePtr> &node_map);
vector<ge::InDataAnchorPtr> &input_anchors, vector<ge::OutDataAnchorPtr> &output_anchors,
map<ge::OutDataAnchorPtr, vector<ge::InDataAnchorPtr>> &output_in_map,
vector<ge::InControlAnchorPtr> &input_control_anchors,
vector<ge::OutControlAnchorPtr> &output_control_anchors,
unordered_map<string, ge::NodePtr> &node_map);

domi::Status LinkInnerAnchor(unordered_map<string, ge::NodePtr> &node_map);

@@ -124,5 +123,5 @@ class ParserGraphOptimizer {

domi::Status MakeTfProtoDef();
};
} // namespace ge
} // namespace domi
#endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_

+ 1
- 1
parser/tensorflow/iterator_fusion_pass.cc View File

@@ -19,7 +19,7 @@
#include <memory>

#include "common/debug/log.h"
#include "framework/omg/parser/parser_types.h"
#include "common/types.h"
#include "common/util.h"
#include "graph_optimizer.h"
#include "framework/common/ge_inner_error_codes.h"


+ 0
- 3
parser/tensorflow/proto/insert_op.proto View File

@@ -45,9 +45,6 @@ message AippOpParams {
// 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。
uint32 related_input_rank = 2;
// related_input_name is optional and the top name of data node which inserts aipp
string related_input_name = 6;
// input_edge_idx参数为可选,类型为整型,配置范围为>=0。
// 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。
// 配置值 <= Data算子输出边的个数。


+ 2
- 2
parser/tensorflow/scope/scope_pass_manager.cc View File

@@ -15,7 +15,7 @@
*/

#include "parser/tensorflow/scope/scope_pass_manager.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/ge/ge_util.h"
#include "common/util.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
@@ -25,7 +25,7 @@
namespace ge {
shared_ptr<ScopeGraph> ScopePassManager::BuildScopeGraph(domi::tensorflow::GraphDef *graph_def) {
GE_CHK_BOOL_EXEC(graph_def != nullptr, return nullptr, "graph_def is nullptr");
scope_graph_ = ge::parser::MakeShared<ScopeGraph>();
scope_graph_ = ge::MakeShared<ScopeGraph>();
if (scope_graph_ == nullptr) {
GELOGE(FAILED, "Scope graph make shared failed.");
return nullptr;


+ 2
- 3
parser/tensorflow/tensorflow_arg_parser.cc View File

@@ -17,7 +17,6 @@
#include "common/debug/log.h"
#include "parser/common/op_def/arg_op.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/compute_graph.h"
#include "graph/ge_tensor.h"
#include "parser/common/op_parser_factory.h"
@@ -45,7 +44,7 @@ Status ParseParams(const Message *op_src, ArgOpOperator *op) {
"trans output_attr_value failed, op: %s", node->name().c_str());

domi::tensorflow::AttrValue_ListValue attr_list = output_attr_value.list();
GetParserContext().format =
domi::GetContext().format =
static_cast<domi::tagDomiTensorFormat>(attr_list.func(0).attr().at(kSerializeFormat).i());
} else {
/// _Arg constructed from inference function do not has input_tensor_dec
@@ -65,5 +64,5 @@ Status ParseParams(const Message *op_src, ArgOpOperator *op) {
return SUCCESS;
}

DOMI_REGISTER_TENSORFLOW_PARSER(ge::parser::ARG, ArgOpOperator).SetParseParamsFn(ParseParams);
DOMI_REGISTER_TENSORFLOW_PARSER(ge::ARG, ArgOpOperator).SetParseParamsFn(ParseParams);
} // namespace ge

+ 0
- 4
parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc View File

@@ -16,7 +16,6 @@

#include "tensorflow_auto_mapping_parser_adapter.h"

#include "framework/omg/parser/parser_types.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "parser/common/op_parser_factory.h"
@@ -25,9 +24,6 @@


using domi::TENSORFLOW;
using namespace ge::parser;

using ge::parser::PLACEHOLDERWITHDEFAULT;

namespace ge {
namespace {


+ 2
- 4
parser/tensorflow/tensorflow_constant_parser.cc View File

@@ -19,7 +19,7 @@
#include <memory>
#include <vector>
#include "common/debug/log.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/ge/ge_util.h"
#include "common/op/ge_op_utils.h"
#include "parser/common/op_def/constant_op.h"
#include "parser/common/op_def/ir_pb_converter.h"
@@ -27,12 +27,10 @@
#include "graph/ge_tensor.h"
#include "graph/utils/attr_utils.h"
#include "parser/common/op_parser_factory.h"
#include "framework/omg/parser/parser_types.h"
#include "register/tensor_assign.h"

using domi::tensorflow::NodeDef;
using domi::TENSORFLOW;
using ge::parser::CONSTANTOP;

namespace ge {
Status TensorFlowConstantParser::ParseDType(const domi::tensorflow::NodeDef *node, ConstantOperator *op) {
@@ -68,7 +66,7 @@ Status TensorFlowConstantParser::ParseValue(const domi::tensorflow::NodeDef *nod

const domi::tensorflow::TensorProto &tensor = attr_value.tensor();

GeTensorPtr weight = ge::parser::MakeShared<ge::GeTensor>();
GeTensorPtr weight = ge::MakeShared<ge::GeTensor>();
GE_CHECK_NOTNULL(weight);
int64_t dataType = 0;
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(opDesc, TENSORFLOW_ATTR_DTYPE, dataType), INTERNAL_ERROR,


+ 2
- 5
parser/tensorflow/tensorflow_data_parser.cc View File

@@ -19,14 +19,11 @@
#include "common/debug/log.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "parser/common/op_parser_factory.h"
#include "framework/omg/parser/parser_types.h"

using domi::tensorflow::AttrValue;
using domi::tensorflow::NodeDef;
using domi::TENSORFLOW;
using ge::parser::DATA;

namespace ge {
namespace {
@@ -100,7 +97,7 @@ Status TensorFlowDataParser::ParseInputFromModel(const Message *op_src, ge::OpDe
Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge::OpDescPtr &op_def) {
GE_CHECK_NOTNULL(op_def);
(void)op_src;
const ge::ParserContext &ctx = GetParserContext();
const ge::OmgContext &ctx = domi::GetContext();
std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
// User not designate the input_shape
std::string name = op_def->GetName();
@@ -134,7 +131,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge:
}

Status TensorFlowDataParser::CheckInputShape(const std::string &name) {
const ge::ParserContext &ctx = GetParserContext();
const ge::OmgContext &ctx = domi::GetContext();
if (!ctx.is_dynamic_input) {
for (uint32_t i = 0; i < user_input_dims_v.size(); i++) {
// if input_shape has some placeholders, user should designate them.


+ 0
- 3
parser/tensorflow/tensorflow_enter_parser.cc View File

@@ -19,11 +19,8 @@
#include "framework/common/debug/log.h"
#include "graph/debug/ge_attr_define.h"
#include "parser/common/op_parser_factory.h"
#include "framework/omg/parser/parser_types.h"

using domi::TENSORFLOW;
using ge::parser::ENTER;
using ge::parser::REFENTER;

namespace ge {
Status TensorFlowEnterParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) {


+ 2
- 7
parser/tensorflow/tensorflow_fill_parser.cc View File

@@ -20,11 +20,6 @@
#include "parser/common/op_def/fill_op.h"
#include "common/util.h"
#include "parser/tensorflow/tensorflow_parser_register.h"
#include "framework/omg/parser/parser_types.h"

using ge::parser::ALPHA_DEFAULT_VALUE;
using ge::parser::BETA_DEFAULT_VALUE;
using ge::parser::FILL;

namespace ge {
/*
@@ -58,8 +53,8 @@ domi::Status ParseParams(const NodeDef *node, FillOperator *op) {

op->DataType(type);

op->Alpha(ge::parser::ALPHA_DEFAULT_VALUE);
op->Beta(ge::parser::BETA_DEFAULT_VALUE);
op->Alpha(ge::ALPHA_DEFAULT_VALUE);
op->Beta(ge::BETA_DEFAULT_VALUE);

return domi::SUCCESS;
}


+ 1
- 2
parser/tensorflow/tensorflow_frameworkop_parser.cc View File

@@ -18,15 +18,14 @@
#include "parser/common/op_def/frameworkop_op.h"
#include "framework/common/debug/ge_log.h"
#include "parser/common/op_parser_factory.h"
#include "framework/omg/parser/parser_types.h"
#include "parser/tensorflow/tensorflow_op_parser.h"
#include "parser/tensorflow/tensorflow_parser_register.h"
#include "proto/tensorflow/tensor_shape.pb.h"

using domi::tensorflow::TensorShapeProto;
using domi::tensorflow::AttrValue;
using ge::FRAMEWORKOP;
using domi::TENSORFLOW;
using ge::parser::FRAMEWORKOP;

namespace ge {
Status ParseParams(const Message *op_src, FrameworkOpOperator *op) {


+ 4
- 4
parser/tensorflow/tensorflow_fusion_op_parser.cc View File

@@ -17,11 +17,11 @@
#include "parser/tensorflow/tensorflow_fusion_op_parser.h"
#include <memory>
#include "common/debug/log.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/fp16_t.h"
#include "common/ge/ge_util.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "omg/omg.h"
#include "parser/common/parser_fp16_t.h"
#include "parser/tensorflow/tensorflow_op_parser.h"
#include "register/tensor_assign.h"

@@ -115,7 +115,7 @@ Status TensorFlowFusionOpParser::ParseHalfFromConst(const NodeDef *node_def, flo
auto val_vec = tensor.half_val();
int32_t val_size = val_vec.size();
if (index < val_size) {
ge::parser::fp16_t fp16_value = static_cast<parser::fp16_t>(val_vec.Get(index));
fp16_t fp16_value = static_cast<fp16_t>(val_vec.Get(index));
param = fp16_value.ToFloat();
} else {
GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index:%d, not supported.", index);
@@ -132,7 +132,7 @@ Status TensorFlowFusionOpParser::ParseWeightFromConst(const NodeDef *node_def, g
GE_CHECK_NOTNULL(node_def);
TensorProto tensor;
GE_CHK_STATUS_RET(GetTensorFromNode(node_def, tensor), "get tensor failed.");
weight = ge::parser::MakeShared<ge::GeTensor>();
weight = ge::MakeShared<ge::GeTensor>();
GE_CHECK_NOTNULL(weight);
domi::tensorflow::DataType data_type = tensor.dtype();
GE_CHK_STATUS_RET(


+ 18
- 19
parser/tensorflow/tensorflow_fusionop_util.cc View File

@@ -20,7 +20,6 @@
#include "common/op/ge_op_utils.h"
#include "framework/common/debug/ge_log.h"
#include "parser/tensorflow/tensorflow_parser.h"
#include "framework/omg/parser/parser_types.h"

#include <iostream>
#include <cstdlib>
@@ -114,21 +113,21 @@ static map<string, string> tensorflow_fusionop_map = {

// <Types of fusion operators, Number of children operators>
static map<string, vector<int>> tensorflow_fusionop_children_nums_map = {
{ge::parser::CLIPBOXES, {8}},
{ge::parser::FASTRCNNPREDICTIONS, {118, 119, 120, 123, 125}},
{ge::parser::RPNPROPOSALS, {75, 85, 97}},
{ge::parser::DECODEBBOX, {24, 28}},
{ge::parser::ROIALIGN, {82, 83, 84}},
{ge::parser::FUSIONBATCHNORM, {8}},
{ge::parser::GETSPAN, {81, 71, 91}}, // The pbtxt only has 62 nodes when test GetSpan sub net. However the
{ge::parser::HUBERLOSSGRAD, {8, 9, 10, 20, 21}},
{CLIPBOXES, {8}},
{FASTRCNNPREDICTIONS, {118, 119, 120, 123, 125}},
{RPNPROPOSALS, {75, 85, 97}},
{DECODEBBOX, {24, 28}},
{ROIALIGN, {82, 83, 84}},
{FUSIONBATCHNORM, {8}},
{GETSPAN, {81, 71, 91}}, // The pbtxt only has 62 nodes when test GetSpan sub net. However the
{HUBERLOSSGRAD, {8, 9, 10, 20, 21}},
};

// <Types of fusion operators, Name of children operators(Remove the prefixes and/)>
static map<string, vector<string>> tensorflow_fusionop_children_names_map = {
{ge::parser::FUSIONBATCHNORM, {"add/y", "add", "Rsqrt", "mul", "mul_1", "mul_2", "sub", "add_1"}},
{ge::parser::GETSPAN, {}},
{ge::parser::HUBERLOSSGRAD, {}},
{FUSIONBATCHNORM, {"add/y", "add", "Rsqrt", "mul", "mul_1", "mul_2", "sub", "add_1"}},
{GETSPAN, {}},
{HUBERLOSSGRAD, {}},
};

// ----------------------------Index table of input and output of fusion operator--------------
@@ -138,23 +137,23 @@ static map<string, vector<string>> tensorflow_fusionop_children_names_map = {
// Generally, the old index is 0. If the new index value is kFusionDisableIndex, the edge can be ignored.
// If it is control edge input, the index is graph::kControlSlot(-1).
static map<string, vector<std::pair<string, vector<int32_t>>>> tensorflow_fusionop_inputs_map = {
{ge::parser::FUSIONBATCHNORM,
{FUSIONBATCHNORM,
{{"mul_1", {0, kFusionDisableIndex}},
{"mul", {1, 1}},
{"sub", {2, kFusionDisableIndex}},
{"mul_2", {3, kFusionDisableIndex}},
{"add", {4, kFusionDisableIndex}}}},
{ge::parser::GETSPAN, {{"transpose", {0}}, {"TensorArray", {1}}, {"transpose_1", {2}}}},
{ge::parser::HUBERLOSSGRAD, {{"Sub_1_grad/Neg", {1}}, {"Abs_grad/Sign", {0}}}},
{GETSPAN, {{"transpose", {0}}, {"TensorArray", {1}}, {"transpose_1", {2}}}},
{HUBERLOSSGRAD, {{"Sub_1_grad/Neg", {1}}, {"Abs_grad/Sign", {0}}}},
};

static map<string, vector<std::pair<string, vector<int32_t>>>> tensorflow_fusionop_outputs_map = {
{ge::parser::FUSIONBATCHNORM, {{"add_1", {0}}}},
{ge::parser::GETSPAN, {{"while/Exit_1", {0}}, {"while/Exit_2", {1}}}},
{ge::parser::HUBERLOSSGRAD, {{"Abs_grad/mul", {0}}}},
{FUSIONBATCHNORM, {{"add_1", {0}}}},
{GETSPAN, {{"while/Exit_1", {0}}, {"while/Exit_2", {1}}}},
{HUBERLOSSGRAD, {{"Abs_grad/mul", {0}}}},
};
map<string, vector<std::pair<string, uint32_t>>> tensorflow_fusionop_input_const_weight_index_map = {
{ge::parser::FUSIONBATCHNORM, {{"mul", 0}, {"sub", 1}, {"mul_2", 2}, {"add", 3}}},
{FUSIONBATCHNORM, {{"mul", 0}, {"sub", 1}, {"mul_2", 2}, {"add", 3}}},
};

// Can a string be converted to an integer


+ 1
- 1
parser/tensorflow/tensorflow_fusionop_util.h View File

@@ -22,7 +22,7 @@
#include <vector>
#include "common/debug/log.h"
#include "common/string_util.h"
#include "framework/omg/parser/parser_types.h"
#include "common/types.h"
#include "common/util.h"
#include "omg/omg_inner_types.h"
#include "proto/tensorflow/graph.pb.h"


+ 0
- 4
parser/tensorflow/tensorflow_identity_parser.cc View File

@@ -17,15 +17,11 @@
#include "common/op/ge_op_utils.h"
#include "common/op_def/ir_pb_converter.h"
#include "parser/common/op_parser_factory.h"
#include "framework/omg/parser/parser_types.h"

#include "parser/tensorflow/tensorflow_identity_parser.h"

using domi::TENSORFLOW;
using ge::parser::IDENTITY;
using ge::parser::READVARIABLEOP;

namespace ge {
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, IDENTITY, TensorFlowIdentityParser);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, READVARIABLEOP, TensorFlowIdentityParser);
} // namespace ge

+ 0
- 2
parser/tensorflow/tensorflow_merge_parser.cc View File

@@ -20,10 +20,8 @@
#include "framework/common/util.h"
#include "graph/debug/ge_attr_define.h"
#include "parser/common/op_parser_factory.h"
#include "framework/omg/parser/parser_types.h"

using domi::TENSORFLOW;
using ge::parser::MERGE;

namespace ge {
Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) {


+ 0
- 1
parser/tensorflow/tensorflow_no_op_parser.cc View File

@@ -22,7 +22,6 @@
#include "parser/common/op_parser_factory.h"

using domi::TENSORFLOW;
using namespace ge::parser;

namespace ge {
Status TensorFlowNoOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {


+ 231
- 226
parser/tensorflow/tensorflow_parser.cc
File diff suppressed because it is too large
View File


+ 3
- 3
parser/tensorflow/tensorflow_parser_register.h View File

@@ -25,7 +25,7 @@
#include "framework/omg/parser/op_parser.h"
#include "parser/common/op_def/ir_pb_converter.h"
#include "parser/common/op_def/operator.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/ge/ge_util.h"
#include "parser/common/op_parser_factory.h"
#include "parser/tensorflow/tensorflow_op_parser.h"
#include "proto/tensorflow/node_def.pb.h"
@@ -72,7 +72,7 @@ class TensorflowParserBuilder : public TensorflowWeightParserBuilder {
}

bool Finalize() override {
auto op_parser_adapter = ge::parser::MakeShared<TensorflowOpParserAdapter<Param>>(*this);
auto op_parser_adapter = ge::MakeShared<TensorflowOpParserAdapter<Param>>(*this);
if (op_parser_adapter == nullptr) {
GELOGE(FAILED, "Op parser adapter is null.");
}
@@ -102,7 +102,7 @@ class TensorflowOpParserAdapter : public TensorFlowOpParser {
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override {
const domi::tensorflow::NodeDef *node = static_cast<const domi::tensorflow::NodeDef *>(op_src);
GE_CHECK_NOTNULL(node);
std::shared_ptr<Param> param = ge::parser::MakeShared<Param>();
std::shared_ptr<Param> param = ge::MakeShared<Param>();
if (param == nullptr) {
GELOGE(domi::FAILED, "Param is null");
return domi::FAILED;


+ 0
- 1
parser/tensorflow/tensorflow_ref_switch_parser.cc View File

@@ -26,7 +26,6 @@ using domi::tensorflow::DT_FLOAT;
using domi::tensorflow::AttrValue;
using domi::tensorflow::NodeDef;
using domi::TENSORFLOW;
using namespace ge::parser;

namespace ge {
// AUTO GEN PLEASE DO NOT MODIFY IT


+ 3
- 4
parser/tensorflow/tensorflow_reshape_parser.cc View File

@@ -22,10 +22,9 @@
#include "graph/utils/type_utils.h"
#include "parser/common/op_parser_factory.h"
#include "parser/tensorflow/tensorflow_util.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/math/math_util.h"

using domi::TENSORFLOW;
using namespace ge::parser;

namespace ge {
Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) {
@@ -48,7 +47,7 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att
GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;);
real_size *= tmp_dim;
}
PARSER_INT64_MULCHECK(real_size, size_type);
FMK_INT64_MULCHECK(real_size, size_type);
ge::TensorUtils::SetSize(ge_desc, real_size * size_type);
ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
GELOGI("after translate tf_desc, datatype: %s, format: %s, real size: %u, size_type: %u",
@@ -68,7 +67,7 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr
domi::tensorflow::AttrValue output_attr_value;

GE_IF_BOOL_EXEC(
GetParserContext().train_flag == true,
domi::GetContext().train_flag == true,

ge::GeTensorDesc input_desc;
ge::GeTensorDesc output_desc;


+ 0
- 1
parser/tensorflow/tensorflow_shape_n_parser.cc View File

@@ -26,7 +26,6 @@ using domi::tensorflow::AttrValue;
using domi::tensorflow::DataType;
using domi::tensorflow::DT_FLOAT;
using domi::tensorflow::DT_INT32;
using namespace ge::parser;

namespace {
const std::string kShapeAttrDtype = "out_type";


+ 4
- 6
parser/tensorflow/tensorflow_squeeze_parser.cc View File

@@ -22,16 +22,14 @@
#include "framework/common/op/attr_value_util.h"
#include "framework/common/op/op_parser_util.h"
#include "framework/common/util.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/utils/type_utils.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/math/math_util.h"

using domi::tensorflow::AttrValue;
using std::vector;
using std::shared_ptr;
using domi::TENSORFLOW;
using namespace ge::parser;

namespace ge {
Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) {
@@ -52,10 +50,10 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
tmp_dim = ge_desc.GetShape().GetDim(j);
GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;);
PARSER_INT64_MULCHECK(real_size, tmp_dim);
FMK_INT64_MULCHECK(real_size, tmp_dim);
real_size *= tmp_dim;
}
PARSER_INT64_MULCHECK(real_size, size_type);
FMK_INT64_MULCHECK(real_size, size_type);
ge::TensorUtils::SetSize(ge_desc, real_size * size_type);
ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
GELOGD("after translate tf_desc, datatype: %s, format: %s, real size: %u, size_type: %u",
@@ -112,7 +110,7 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr
domi::tensorflow::AttrValue output_attr_value;

GE_IF_BOOL_EXEC(
GetParserContext().train_flag == true, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc;
domi::GetContext().train_flag == true, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc;

if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) {
GE_CHK_BOOL_RET_STATUS(ParseDesc(input_attr_value, input_desc) == SUCCESS, FAILED, "parse input desc failed");


+ 10
- 11
parser/tensorflow/tensorflow_util.cc View File

@@ -15,25 +15,25 @@
*/

#include "parser/tensorflow/tensorflow_util.h"
#include <cstdint>
#include <cstdlib>
#include <cstdint>
#include <iostream>
#include <memory>
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "framework/common/op/ge_op_utils.h"
#include "framework/omg/parser/parser_types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_tensor.h"
#include "graph/utils/type_utils.h"
#include "parser/tensorflow/tensorflow_op_parser.h"
#include "common/math/math_util.h"

using domi::tensorflow::DT_INVALID;

namespace ge {
using AttrValueMap = ::google::protobuf::Map<string, domi::tensorflow::AttrValue>;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrValue(
const domi::tensorflow::NodeDef *node_def, const string &attr_name, domi::tensorflow::AttrValue &attr_value) {
const domi::tensorflow::NodeDef *node_def, const string &attr_name, domi::tensorflow::AttrValue &attr_value) {
GE_CHECK_NOTNULL(node_def);
const google::protobuf::Map<std::string, domi::tensorflow::AttrValue> &attr = node_def->attr();
const google::protobuf::Map<std::string, domi::tensorflow::AttrValue>::const_iterator it = attr.find(attr_name);
@@ -46,7 +46,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrVa
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::CheckAttrHasType(
const domi::tensorflow::AttrValue &attr_value, const string &type) {
const domi::tensorflow::AttrValue &attr_value, const string &type) {
uint32_t num_set = 0;
#define VALIDATE_FIELD(name, type_string, oneof_case) \
do { \
@@ -118,7 +118,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType(
const NodeDef *node_src, const string &attr_src, domi::tensorflow::DataType &data_type) {
const NodeDef *node_src, const string &attr_src, domi::tensorflow::DataType &data_type) {
GE_CHECK_NOTNULL(node_src);

string node_name = node_src->name();
@@ -138,7 +138,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Pa
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromAttrValueList(
ge::GeTensorDesc &ge_desc, const domi::tensorflow::AttrValue_ListValue &a_list, int32_t i, int32_t &tf_datatype) {
ge::GeTensorDesc &ge_desc, const domi::tensorflow::AttrValue_ListValue &a_list, int32_t i, int32_t &tf_datatype) {
const std::string SERIALIZE_FORMAT = "serialize_format";
const std::string SERIALIZE_DATATYPE = "serialize_datatype";
const std::string SERIALIZE_SHAPE = "serialize_shape";
@@ -162,7 +162,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromA
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::TransTensorDescriptor(
const domi::tensorflow::AttrValue &attr_value, ParserOperator *op, const uint32_t io, const string &type) {
const domi::tensorflow::AttrValue &attr_value, ParserOperator *op, const uint32_t io, const string &type) {
GE_CHECK_NOTNULL(op);
if (!attr_value.has_list()) {
return PARAM_INVALID;
@@ -191,9 +191,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr
// The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate.
// Here, special treatment is given to the two operators.
// Adjust shape to fit resnet50 network only.
GE_IF_BOOL_EXEC((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape());
break;);
GE_IF_BOOL_EXEC((type == ge::parser::MEAN) && (tmp_dim == 0), vector<int64_t> data_dim = {tmp_dim};
GE_IF_BOOL_EXEC((type == ge::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape()); break;);
GE_IF_BOOL_EXEC((type == ge::MEAN) && (tmp_dim == 0), vector<int64_t> data_dim = {tmp_dim};
ge_desc.SetShape(ge::GeShape(data_dim)); break;);
}
ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
@@ -215,7 +214,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr
return SUCCESS;
}
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr(
const string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *node_def) {
const string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *node_def) {
GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null.");
node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value));
}


+ 1
- 1
parser/tensorflow/tensorflow_util.h View File

@@ -26,7 +26,7 @@
#include "external/graph/attr_value.h"
#include "external/graph/graph.h"
#include "external/graph/operator.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/types.h"
#include "framework/omg/omg_inner_types.h"
#include "graph/compute_graph.h"
#include "graph/ge_tensor.h"


+ 0
- 2
parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc View File

@@ -22,8 +22,6 @@
#include "parser/tensorflow/tensorflow_op_parser.h"
#include "parser/tensorflow/tensorflow_parser_register.h"

using namespace ge::parser;

namespace ge {
Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *op) {
GE_CHECK_NOTNULL(op_src);


+ 0
- 1
parser/tensorflow/tensorflow_variable_v2_parser.cc View File

@@ -32,7 +32,6 @@
using domi::tensorflow::AttrValue;
using domi::tensorflow::NodeDef;
using domi::tensorflow::TensorShapeProto;
using namespace ge::parser;

namespace ge {
const std::string SERIALIZE_FORMAT = "serialize_format";


Loading…
Cancel
Save