diff --git a/CMakeLists.txt b/CMakeLists.txt index 6dd5e1e1..021e7798 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,8 +16,11 @@ endif() if(DEFINED ENV{D_PKG_SERVER}) set(GE_PB_PKG $ENV{D_PKG_SERVER}) - message("Download packages from PKG server") -endif() + message("Download packages from DPKG server") +elseif(DEFINED ENV{MSLIBS_SERVER}) + set(GE_PB_PKG "http://$ENV{MSLIBS_SERVER}:8081") + message("Download packages from MSPKG server") +endif () set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) @@ -37,7 +40,7 @@ set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR} option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) if (ENABLE_OPEN_SRC) - set(HI_PYTHON python3.7) + set(HI_PYTHON python3) include(cmake/external_libs/protobuf_shared.cmake) include(cmake/external_libs/protobuf_static.cmake) @@ -71,7 +74,7 @@ if (ENABLE_OPEN_SRC) set(STATIC_ACL_LIB ${GE_LIB_PATH}) find_module(slog libslog.so ${GE_LIB_PATH}) find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) - find_module(msprof libmsprof.so ${GE_LIB_PATH}) + find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) find_module(hccl libhccl.so ${GE_LIB_PATH}) find_module(adump_server libadump_server.a ${GE_LIB_PATH}) find_module(runtime libruntime.so ${GE_LIB_PATH}) @@ -80,20 +83,19 @@ if (ENABLE_OPEN_SRC) find_module(error_manager liberror_manager.so ${GE_LIB_PATH}) find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) find_module(error_manager_static liberror_manager.a ${GE_LIB_PATH}) - find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) + find_module(msprofiler_fwk libmsprofiler_fwk.a ${GE_LIB_PATH}) #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) else() find_module(slog libslog.so ${ASCEND_ATC_DIR} ${ASCEND_DRIVER_COMMON_DIR}) find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR} ${ASCEND_RUNTIME_DIR}) find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR} ${ASCEND_RUNTIME_DIR}) if(PLATFORM STREQUAL "train") - find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) - find_module(msprofiler libmsprofiler.a ${ASCEND_RUNTIME_DIR}) + find_module(msprofiler_fwk libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) if(PRODUCT STREQUAL "flr3") message(FATAL_ERROR "This platform is not supported in train mode, build terminated") @@ -106,20 +108,17 @@ if (ENABLE_OPEN_SRC) find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) - #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) + #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) if(PRODUCT STREQUAL "flr3") - find_module(msprof libmsprof.so ${ASCEND_DRIVER_SHARE_DIR}) elseif(PRODUCT STREQUAL "flr1") find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) - find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) elseif(PRODUCT STREQUAL "flr2") # flr2 ascend_hal_stub limsprof ? else() find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) - find_module(msprof libmsprof.so ${ASCEND_DRIVER_DIR}) endif() elseif(PLATFORM STREQUAL "all") - find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) + find_module(msprofiler libmsprofiler.a ${ASCEND_DRIVER_COMMON_DIR}) find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) @@ -127,14 +126,14 @@ if (ENABLE_OPEN_SRC) find_module(resource libresource.so ${ASCEND_ATC_DIR}) find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) find_module(error_manager_static liberror_manager.a ${ASCEND_ACL_DIR}) - find_module(msprofiler libmsprofiler.a ${ASCEND_ACL_DIR}) + find_module(msprofiler_fwk libmsprofiler_fwk.a ${ASCEND_ACL_DIR}) find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) else() - message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") + message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") endif() - if (ENABLE_GE_COV OR ENABLE_GE_UT) + if (ENABLE_GE_COV OR ENABLE_GE_UT) add_subdirectory(tests) endif() diff --git a/build.sh b/build.sh index 3c9a537e..a112fdaa 100644 --- a/build.sh +++ b/build.sh @@ -23,7 +23,7 @@ export BUILD_PATH="${BASEPATH}/build/" usage() { echo "Usage:" - echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c] [-S on|off]" + echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c] [-S on|off] [-M]" echo "" echo "Options:" echo " -h Print usage" @@ -35,6 +35,7 @@ usage() echo " -p Build inference or train" echo " -v Display build command" echo " -S Enable enable download cmake compile dependency from gitee , default off" + echo " -M build MindSpore mode" echo "to be continued ..." } @@ -62,8 +63,9 @@ checkopts() PLATFORM="" PRODUCT="normal" ENABLE_GITEE="off" + MINDSPORE_MODE="off" # Process the options - while getopts 'ustchj:p:g:vS:' opt + while getopts 'ustchj:p:g:vS:M' opt do OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') case "${opt}" in @@ -104,6 +106,9 @@ checkopts() ENABLE_GITEE="$OPTARG" echo "enable download from gitee" ;; + M) + MINDSPORE_MODE="on" + ;; *) echo "Undefined option: ${opt}" usage @@ -113,8 +118,8 @@ checkopts() } checkopts "$@" -git submodule update --init metadef -git submodule update --init parser +#git submodule update --init metadef +#git submodule update --init parser mk_dir() { local create_dir="$1" # the target to make @@ -150,7 +155,13 @@ build_graphengine() if [[ "X$ENABLE_GITEE" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GITEE=ON" fi - CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_OPEN_SRC=True -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH} -DPLATFORM=${PLATFORM} -DPRODUCT=${PRODUCT}" + + if [[ "X$MINDSPORE_MODE" = "Xoff" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_OPEN_SRC=True -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH} -DPLATFORM=${PLATFORM} -DPRODUCT=${PRODUCT}" + else + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_D=ON -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH}" + fi + echo "${CMAKE_ARGS}" cmake ${CMAKE_ARGS} .. if [ $? -ne 0 ] @@ -169,6 +180,9 @@ build_graphengine() elif [ "X$ENABLE_GE_UT" = "Xon" ] then TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest" + elif [ "X$MINDSPORE_MODE" = "Xon" ] + then + TARGET="ge_common graph" elif [ "x${PLATFORM}" = "xall" ] then # build all the target @@ -314,7 +328,13 @@ generate_package() fi } -if [[ "X$ENABLE_GE_UT" = "Xoff" ]]; then +if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then generate_package + echo "---------------- GraphEngine package archive generated ----------------" +elif [ "X$MINDSPORE_MODE" = "Xon" ] +then + cd "${OUTPUT_PATH}" + find ./ -name graphengine_lib.tar -exec rm {} \; + tar -cf graphengine_lib.tar lib fi -echo "---------------- GraphEngine package archive generated ----------------" + diff --git a/cmake/external_libs/gflags.cmake b/cmake/external_libs/gflags.cmake index f3f0f0ef..50cfb2bc 100755 --- a/cmake/external_libs/gflags.cmake +++ b/cmake/external_libs/gflags.cmake @@ -23,6 +23,7 @@ ExternalProject_Add(gflags_build URL ${REQ_URL} #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz #SOURCE_DIR ${GE_CODE_DIR}/../../third_party/gflags/src/gflags-2.2.2 + TLS_VERIFY OFF CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags BUILD_COMMAND $(MAKE) INSTALL_COMMAND $(MAKE) install diff --git a/cmake/external_libs/gtest.cmake b/cmake/external_libs/gtest.cmake index 96ea84b4..c5edcd72 100755 --- a/cmake/external_libs/gtest.cmake +++ b/cmake/external_libs/gtest.cmake @@ -10,7 +10,10 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") endif() -if (ENABLE_GITEE) +if (GE_PB_PKG) + set(REQ_URL "${GE_PB_PKG}/libs/gtest/release-1.8.0.tar.gz") + set(MD5 "") +elseif (ENABLE_GITEE) set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") set(MD5 "") else() @@ -22,8 +25,9 @@ set (gtest_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack- set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") ExternalProject_Add(gtest_build URL ${REQ_URL} + TLS_VERIFY OFF CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest - -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON + -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON BUILD_COMMAND $(MAKE) INSTALL_COMMAND $(MAKE) install EXCLUDE_FROM_ALL TRUE diff --git a/cmake/external_libs/json.cmake b/cmake/external_libs/json.cmake index ce473d4b..3c1cd012 100755 --- a/cmake/external_libs/json.cmake +++ b/cmake/external_libs/json.cmake @@ -5,10 +5,14 @@ endif() include(ExternalProject) set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) -if (ENABLE_GITEE) - set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") - set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") - set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include") +if (GE_PB_PKG) + set(REQ_URL "${GE_PB_PKG}/libs/ge_nlohmann_json/include.zip") + set(MD5 "0dc903888211db3a0f170304cd9f3a89") + set(JSON_INCLUDE_DIR ${JSON_SRC_DIR}) +#elseif (ENABLE_GITEE) +# set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") +# set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") +#set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include") else() set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") set(MD5 "0dc903888211db3a0f170304cd9f3a89") @@ -18,6 +22,7 @@ ExternalProject_Add(json_build URL ${REQ_URL} #URL /home/txd/workspace/cloud_code/pkg/include.zip SOURCE_DIR ${JSON_SRC_DIR} + TLS_VERIFY OFF CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND "" diff --git a/cmake/external_libs/onnx.cmake b/cmake/external_libs/onnx.cmake index 9dadb544..1ee80d2d 100755 --- a/cmake/external_libs/onnx.cmake +++ b/cmake/external_libs/onnx.cmake @@ -6,7 +6,10 @@ set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx) set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) -if (ENABLE_GITEE) +if (GE_PB_PKG) + set(REQ_URL "${GE_PB_PKG}/libs/onnx/onnx-1.6.0.tar.gz") + set(MD5 "512f2779d6215d4a36f366b6b9acdf1e") +elseif (ENABLE_GITEE) set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") set(MD5 "1bdbcecdd68ea8392630467646776e02") else() @@ -19,6 +22,7 @@ ExternalProject_Add(onnx #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 #SOURCE_DIR ${ONNX_SRC_DIR} + TLS_VERIFY OFF CONFIGURE_COMMAND "" BUILD_COMMAND "" #INSTALL_COMMAND "" diff --git a/cmake/external_libs/protobuf_shared.cmake b/cmake/external_libs/protobuf_shared.cmake index c9c6b7d9..6334c8a3 100755 --- a/cmake/external_libs/protobuf_shared.cmake +++ b/cmake/external_libs/protobuf_shared.cmake @@ -26,6 +26,7 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") ExternalProject_Add(protobuf_build URL ${REQ_URL} + TLS_VERIFY OFF CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} diff --git a/cmake/external_libs/protobuf_static.cmake b/cmake/external_libs/protobuf_static.cmake index 6f3e1f53..e4bbb9a0 100755 --- a/cmake/external_libs/protobuf_static.cmake +++ b/cmake/external_libs/protobuf_static.cmake @@ -27,6 +27,7 @@ ExternalProject_Add(protobuf_static_build URL ${REQ_URL} #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 + TLS_VERIFY OFF CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} diff --git a/cmake/external_libs/protoc.cmake b/cmake/external_libs/protoc.cmake index 0d162c0d..9ea1aced 100755 --- a/cmake/external_libs/protoc.cmake +++ b/cmake/external_libs/protoc.cmake @@ -30,6 +30,7 @@ ExternalProject_Add(protoc_build URL ${REQ_URL} #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz #SOURCE_DIR ${GE_CODE_DIR}/../third_party/protobuf/src/protobuf-3.8.0 + TLS_VERIFY OFF CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc /cmake BUILD_COMMAND $(MAKE) INSTALL_COMMAND $(MAKE) install diff --git a/cmake/external_libs/securec.cmake b/cmake/external_libs/securec.cmake index 0bd62ab2..0f8b6d3a 100755 --- a/cmake/external_libs/securec.cmake +++ b/cmake/external_libs/securec.cmake @@ -10,11 +10,20 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") endif() +if (GE_PB_PKG) + set(REQ_URL "${GE_PB_PKG}/libs/securec/v1.1.10.tar.gz") + set(MD5 "") +else() + set(REQ_URL "https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz") + set(MD5 "") +endif () + ExternalProject_Add(c_sec_build - URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz - #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz + URL ${REQ_URL} + #URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz #SOURCE_DIR ${GE_CODE_DIR}/../libc_sec PATCH_COMMAND patch -p1 < ${GE_CODE_DIR}/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch + TLS_VERIFY OFF CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index 88a5c52f..85a1bd18 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -60,6 +60,8 @@ set(TRAIN_SRC_LIST "common/dump/dump_manager.cc" "common/dump/dump_properties.cc" "common/dump/dump_op.cc" + "common/profiling/ge_profiling.cc" + "common/profiling/ge_runner_profiling.cc" "engine_manager/dnnengine_manager.cc" "ge_local_engine/engine/host_cpu_engine.cc" "generator/ge_generator.cc" @@ -142,6 +144,7 @@ set(TRAIN_SRC_LIST "graph/passes/atomic_addr_clean_pass.cc" "graph/passes/mark_same_addr_pass.cc" "graph/passes/mark_graph_unknown_status_pass.cc" + "graph/passes/dynamic_single_op_reset_shape_pass.cc" "graph/passes/mark_agnostic_pass.cc" "graph/partition/dynamic_shape_partition.cc" "graph/partition/stage_partition.cc" @@ -201,6 +204,7 @@ set(TRAIN_SRC_LIST "host_kernels/sub_kernel.cc" "host_kernels/transdata_kernel.cc" "host_kernels/unpack_kernel.cc" + "host_kernels/reformat_kernel.cc" "graph/passes/folding_pass.cc" "graph/passes/get_original_format_pass.cc" "graph/passes/guarantee_const_pass.cc" @@ -331,7 +335,6 @@ set(TRAIN_SRC_LIST "hybrid/hybrid_davinci_model.cc" "executor/ge_executor.cc" "client/ge_api.cc" - "client/ge_prof.cc" "analyzer/analyzer.cc" "ir_build/ge_ir_build.cc" "ir_build/atc_ir_common.cc" @@ -432,6 +435,7 @@ set(INFER_SRC_LIST "graph/passes/atomic_addr_clean_pass.cc" "graph/passes/mark_same_addr_pass.cc" "graph/passes/mark_graph_unknown_status_pass.cc" + "graph/passes/dynamic_single_op_reset_shape_pass.cc" "graph/passes/mark_agnostic_pass.cc" "graph/common/omg_util.cc" "graph/common/bcast.cc" @@ -487,6 +491,7 @@ set(INFER_SRC_LIST "host_kernels/slice_d_kernel.cc" "host_kernels/dynamic_stitch_kernel.cc" "host_kernels/identity_kernel.cc" + "host_kernels/reformat_kernel.cc" "graph/passes/stop_gradient_pass.cc" "graph/passes/prevent_gradient_pass.cc" "graph/passes/identity_pass.cc" @@ -602,7 +607,7 @@ set(INFER_SRC_LIST if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) ############ libge_runner.so ############ -add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) +add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS} $) target_compile_definitions(ge_runner PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 @@ -647,7 +652,6 @@ target_link_libraries(ge_runner $ ge_memory adump_server - msprofiler static_mmpa -Wl,--no-as-needed graph @@ -656,7 +660,6 @@ target_link_libraries(ge_runner register c_sec slog - msprof runtime resource error_manager @@ -781,7 +784,6 @@ target_link_libraries(opensrc_ascendcl PRIVATE c_sec runtime slog - msprof ascend_hal_stub -Wl,--as-needed -lrt @@ -797,12 +799,10 @@ set_target_properties(opensrc_ascendcl PROPERTIES add_custom_command( OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc - ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_prof.cc COMMAND echo "Generating stub files." && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/stub/gen_stubapi.py ${GE_CODE_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} && mv ge_ir_build.cc stub_ge_ir_build.cc && mv ge_api.cc stub_ge_api.cc - && mv ge_prof.cc stub_ge_prof.cc && echo "Generating stub files end." #WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} #DEPENDS stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} @@ -811,7 +811,6 @@ add_custom_command( add_custom_target(ge_stub DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_ir_build.cc ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_api.cc - ${CMAKE_CURRENT_BINARY_DIR}/stub_ge_prof.cc ) ################################################################## @@ -853,7 +852,6 @@ target_include_directories(atc_stub_ge_compiler PRIVATE ############ stub/libge_runner.so ############ add_library(fwk_stub_ge_runner SHARED stub_ge_api.cc - stub_ge_prof.cc stub_ge_ir_build.cc ) diff --git a/ge/analyzer/analyzer.cc b/ge/analyzer/analyzer.cc old mode 100755 new mode 100644 diff --git a/ge/analyzer/analyzer.h b/ge/analyzer/analyzer.h old mode 100755 new mode 100644 diff --git a/ge/client/ge_api.cc b/ge/client/ge_api.cc index 9ecc3016..66958310 100644 --- a/ge/client/ge_api.cc +++ b/ge/client/ge_api.cc @@ -134,7 +134,7 @@ Status GEInitialize(const std::map &options) { Status GEInitialize(const std::map &options) { std::map str_options; - for (auto & option : options) { + for (auto &option : options) { if (option.first.GetString() == nullptr || option.second.GetString() == nullptr) { GELOGE(FAILED, "GEInitialize options is nullptr."); return FAILED; diff --git a/ge/client/ge_prof.cc b/ge/client/ge_prof.cc deleted file mode 100644 index ede38430..00000000 --- a/ge/client/ge_prof.cc +++ /dev/null @@ -1,369 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ge/ge_prof.h" -#include "ge/ge_api.h" -#include "init/gelib.h" -#include "common/debug/log.h" -#include "framework/common/debug/ge_log.h" -#include "common/profiling/profiling_manager.h" -#include "graph/load/graph_loader.h" -#include "toolchain/prof_acl_api.h" - -using std::map; -using std::string; -using std::vector; - -namespace { -const uint32_t kMaxDeviceNum = 64; -const uint32_t kDeviceListIndex = 3; -const std::string kProfilingInit = "prof_init"; -const std::string kProfilingFinalize = "prof_finalize"; -const std::string kProfilingStart = "prof_start"; -const std::string kProfilingStop = "prof_stop"; -const std::string kDeviceNums = "devNums"; -const std::string kDeviceIdList = "devIdList"; -const std::string kAicoreMetrics = "aicoreMetrics"; - -const std::map kProfAicoreMetricsToString = { - {ge::kAicoreArithmaticThroughput, "AICORE_ARITHMATIC_THROUGHPUT"}, - {ge::kAicorePipeline, "AICORE_PIPELINE"}, - {ge::kAicoreSynchronization, "AICORE_SYNCHRONIZATION"}, - {ge::kAicoreMemory, "AICORE_MEMORY"}, - {ge::kAicoreInternalMemory, "AICORE_INTERNAL_MEMORY"}, - {ge::kAicoreStall, "AICORE_STALL"}}; -} // namespace - -static bool g_graph_prof_init_ = false; -static std::mutex g_prof_mutex_; - -namespace ge { -struct aclgrphProfConfig { - ProfConfig config; -}; - -Status aclgrphProfInit(const char *profiler_path, uint32_t length) { - GELOGT(TRACE_INIT, "Graph prof init start"); - - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized."); - return FAILED; - } - - std::lock_guard lock(g_prof_mutex_); - if (g_graph_prof_init_) { - GELOGW("Multi graph profiling initializations."); - return GE_PROF_MULTI_INIT; - } - - Status ret = CheckPath(profiler_path, length); - if (ret != SUCCESS) { - GELOGE(ret, "Profiling config path is invalid."); - return ret; - } - // if command mode is set, just return - if (ProfilingManager::Instance().ProfilingOn()) { - GELOGW("Graph prof init failed, cause profiling command pattern is running."); - return GE_PROF_MODE_CONFLICT; - } - - ret = ProfInit(profiler_path); - if (ret != SUCCESS) { - GELOGE(ret, "ProfInit init fail"); - return ret; - } - - GraphLoader graph_loader; - Command command; - command.cmd_params.clear(); - command.cmd_type = kProfilingInit; - command.module_index = PROF_MODEL_LOAD; - ret = graph_loader.CommandHandle(command); - if (ret != SUCCESS) { - GELOGE(ret, "Handle profiling command %s failed, config = %s", kProfilingInit.c_str(), profiler_path); - return ret; - } - if (!g_graph_prof_init_) { - g_graph_prof_init_ = true; - GELOGI("Profiling init successfully."); - } - - GELOGI("Successfully execute GraphProfInit."); - return SUCCESS; -} - -Status aclgrphProfFinalize() { - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized."); - return FAILED; - } - std::lock_guard lock(g_prof_mutex_); - // if command mode is set, just return - if (ProfilingManager::Instance().ProfilingOn()) { - GELOGW("Graph prof finalize failed, cause profiling command pattern is running."); - return GE_PROF_MODE_CONFLICT; - } - - if (!g_graph_prof_init_) { - GELOGE(GE_PROF_NOT_INIT, "Graph not profiling initialize."); - return GE_PROF_NOT_INIT; - } - GraphLoader graph_loader; - Command command; - command.cmd_params.clear(); - command.cmd_type = kProfilingFinalize; - Status ret = graph_loader.CommandHandle(command); - if (ret != SUCCESS) { - GELOGE(ret, "Handle profiling command %s failed.", kProfilingFinalize.c_str()); - return ret; - } - - ret = ProfFinalize(); - if (ret != SUCCESS) { - GELOGE(ret, "Finalize profiling failed, result = %d", ret); - } - - if (ret == SUCCESS) { - g_graph_prof_init_ = false; - GELOGI("Successfully execute GraphProfFinalize."); - } - return ret; -} - -bool TransProfConfigToParam(const aclgrphProfConfig *profiler_config, vector &prof_config_params) { - prof_config_params.clear(); - prof_config_params.emplace_back(kDeviceNums); - prof_config_params.emplace_back(std::to_string(profiler_config->config.devNums)); - prof_config_params.emplace_back(kDeviceIdList); - std::string devID = ""; - if (profiler_config->config.devNums == 0) { - GELOGW("The device num is invalid."); - return false; - } - for (uint32_t i = 0; i < profiler_config->config.devNums; i++) { - devID.append(std::to_string(profiler_config->config.devIdList[i])); - if (i != profiler_config->config.devNums - 1) { - devID.append(","); - } - } - - prof_config_params.push_back(devID); - prof_config_params.push_back(kAicoreMetrics); - auto iter = - kProfAicoreMetricsToString.find(static_cast(profiler_config->config.aicoreMetrics)); - if (iter == kProfAicoreMetricsToString.end()) { - GELOGW("The prof aicore metrics is invalid."); - return false; - } - prof_config_params.push_back(iter->second); - return true; -} - -bool isProfConfigValid(const uint32_t *deviceid_list, uint32_t device_nums) { - if (deviceid_list == nullptr) { - GELOGE(PARAM_INVALID, "deviceIdList is nullptr"); - return false; - } - if (device_nums == 0 || device_nums > kMaxDeviceNum) { - GELOGE(PARAM_INVALID, "The device nums is invalid."); - return false; - } - - // real device num - int32_t dev_count = 0; - rtError_t rt_err = rtGetDeviceCount(&dev_count); - if (rt_err != RT_ERROR_NONE) { - GELOGE(INTERNAL_ERROR, "Get the Device count fail."); - return false; - } - - if (device_nums > static_cast(dev_count)) { - GELOGE(PARAM_INVALID, "Device num(%u) is not in range 1 ~ %d.", device_nums, dev_count); - return false; - } - - std::unordered_set record; - for (size_t i = 0; i < device_nums; ++i) { - uint32_t dev_id = deviceid_list[i]; - if (dev_id >= static_cast(dev_count)) { - GELOGE(PARAM_INVALID, "Device id %u is not in range 0 ~ %d(exclude %d)", dev_id, dev_count, dev_count); - return false; - } - if (record.count(dev_id) > 0) { - GELOGE(PARAM_INVALID, "Device id %u is duplicatedly set", dev_id); - return false; - } - record.insert(dev_id); - } - return true; -} - -aclgrphProfConfig *aclgrphProfCreateConfig(uint32_t *deviceid_list, uint32_t device_nums, - ProfilingAicoreMetrics aicore_metrics, ProfAicoreEvents *aicore_events, - uint64_t data_type_config) { - if (!isProfConfigValid(deviceid_list, device_nums)) { - return nullptr; - } - aclgrphProfConfig *config = new (std::nothrow) aclgrphProfConfig(); - if (config == nullptr) { - GELOGE(INTERNAL_ERROR, "new aclgrphProfConfig fail"); - return nullptr; - } - config->config.devNums = device_nums; - if (memcpy_s(config->config.devIdList, sizeof(config->config.devIdList), deviceid_list, - device_nums * sizeof(uint32_t)) != EOK) { - GELOGE(INTERNAL_ERROR, "copy devID failed. size = %u", device_nums); - delete config; - return nullptr; - } - - config->config.aicoreMetrics = static_cast(aicore_metrics); - config->config.dataTypeConfig = data_type_config; - GELOGI("Successfully create prof config."); - return config; -} - -Status aclgrphProfDestroyConfig(aclgrphProfConfig *profiler_config) { - if (profiler_config == nullptr) { - GELOGE(PARAM_INVALID, "destroy profilerConfig failed, profilerConfig must not be nullptr"); - return PARAM_INVALID; - } - - delete profiler_config; - GELOGI("Successfully destroy prof config."); - return SUCCESS; -} - -Status aclgrphProfStart(aclgrphProfConfig *profiler_config) { - if (profiler_config == nullptr) { - GELOGE(PARAM_INVALID, "aclgrphProfConfig is invalid."); - return FAILED; - } - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized."); - return FAILED; - } - - std::lock_guard lock(g_prof_mutex_); - // if command mode is set, just return - if (ProfilingManager::Instance().ProfilingOn()) { - GELOGW("Graph prof finalize failed, cause profiling command pattern is running."); - return GE_PROF_MODE_CONFLICT; - } - if (!g_graph_prof_init_) { - GELOGE(GE_PROF_NOT_INIT, "Graph not profiling initialize."); - return GE_PROF_NOT_INIT; - } - - Status ret = ProfStartProfiling(&profiler_config->config); - if (ret != SUCCESS) { - GELOGE(ret, "Start profiling failed, prof result = %d", ret); - return FAILED; - } - - std::vector prof_params; - if (!TransProfConfigToParam(profiler_config, prof_params)) { - GELOGE(PARAM_INVALID, "Transfer profilerConfig to string vector failed"); - return PARAM_INVALID; - } - - GraphLoader graph_loader; - Command command; - command.cmd_params.clear(); - command.cmd_type = kProfilingStart; - command.cmd_params = prof_params; - command.module_index = profiler_config->config.dataTypeConfig; - GELOGI("Profiling will start, device nums:%s , deviceID:[%s], data type config: 0x%llx", prof_params[0].c_str(), - prof_params[kDeviceListIndex].c_str(), command.module_index); - ret = graph_loader.CommandHandle(command); - if (ret != SUCCESS) { - GELOGE(ret, "Handle profiling command failed"); - return FAILED; - } - - GELOGI("Successfully execute GraphProfStartProfiling."); - - return SUCCESS; -} - -Status aclgrphProfStop(aclgrphProfConfig *profiler_config) { - if (profiler_config == nullptr) { - GELOGE(PARAM_INVALID, "aclgrphProfConfig is invalid."); - return FAILED; - } - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge client is not initialized."); - return FAILED; - } - - std::lock_guard lock(g_prof_mutex_); - // if command mode is set, just return - if (ProfilingManager::Instance().ProfilingOn()) { - GELOGW("Graph prof finalize failed, cause profiling command pattern is running."); - return GE_PROF_MODE_CONFLICT; - } - if (!g_graph_prof_init_) { - GELOGE(GE_PROF_NOT_INIT, "Graph not profiling initialize."); - return GE_PROF_NOT_INIT; - } - - for (uint32_t i = 0; i < profiler_config->config.devNums; i++) { - uint64_t data_type_config; - Status status = ProfGetDataTypeConfig(profiler_config->config.devIdList[i], data_type_config); - if (status != SUCCESS) { - GELOGE(status, "Prof get data type config failed, prof result = %d", status); - return status; - } - if (data_type_config != profiler_config->config.dataTypeConfig) { - GELOGE(FAILED, "data type config verify failed"); - return FAILED; - } - } - - std::vector prof_params; - if (!TransProfConfigToParam(profiler_config, prof_params)) { - GELOGE(PARAM_INVALID, "Transfer profilerConfig to string vector failed"); - return PARAM_INVALID; - } - - GraphLoader graph_loader; - Command command; - command.cmd_params.clear(); - command.cmd_type = kProfilingStop; - command.cmd_params = prof_params; - command.module_index = profiler_config->config.dataTypeConfig; - GELOGI("Profiling will stop, device nums:%s , deviceID:[%s], data type config: 0x%llx", prof_params[0].c_str(), - prof_params[kDeviceListIndex].c_str(), command.module_index); - Status ret = graph_loader.CommandHandle(command); - if (ret != SUCCESS) { - GELOGE(ret, "Handle profiling command failed"); - return FAILED; - } - - ret = ProfStopProfiling(&profiler_config->config); - if (ret != SUCCESS) { - GELOGE(ret, "Stop profiling failed, prof result = %d", ret); - return ret; - } - - GELOGI("Successfully execute GraphProfStopProfiling."); - return SUCCESS; -} -} // namespace ge diff --git a/ge/client/module.mk b/ge/client/module.mk index 6ac69d31..e9d35418 100644 --- a/ge/client/module.mk +++ b/ge/client/module.mk @@ -4,7 +4,6 @@ LOCAL_PATH := $(call my-dir) COMMON_LOCAL_SRC_FILES := \ proto/ge_api.proto \ ge_api.cc \ - ge_prof.cc \ COMMON_LOCAL_C_INCLUDES := \ @@ -69,9 +68,9 @@ LOCAL_SHARED_LIBRARIES := \ libgraph \ libregister \ libge_compiler \ - libge_common \ - libmsprof + libge_common +LOCAL_STATIC_LIBRARIES += libmsprofiler_fwk \ LOCAL_LDFLAGS := -lrt -ldl @@ -104,8 +103,10 @@ LOCAL_SHARED_LIBRARIES := \ libregister \ libruntime \ libge_compiler \ - libge_common \ - libmsprof + libge_common + + +LOCAL_STATIC_LIBRARIES += libmsprofiler_fwk \ LOCAL_LDFLAGS := -lrt -ldl diff --git a/ge/client/proto/om.proto b/ge/client/proto/om.proto old mode 100755 new mode 100644 diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index aa546c0d..ac230f6d 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -24,6 +24,7 @@ set(SRC_LIST "helper/om_file_helper.cc" "helper/model_helper.cc" "../model/ge_model.cc" + "../model/ge_root_model.cc" "auth/file_saver.cc" "fp16_t.cc" "math/fp16_math.cc" @@ -129,6 +130,7 @@ target_compile_definitions(ge_common_static PRIVATE google=ascend_private $,OS_TYPE=WIN,OS_TYPE=0> $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> + LOG_CPP ) target_compile_options(ge_common_static PRIVATE @@ -183,6 +185,7 @@ target_compile_options(ge_common PRIVATE -fvisibility=hidden -O2 -Werror + -Wno-deprecated-declarations ) target_include_directories(ge_common PRIVATE diff --git a/ge/common/auth/file_saver.cc b/ge/common/auth/file_saver.cc old mode 100755 new mode 100644 index 7b41397a..e708653a --- a/ge/common/auth/file_saver.cc +++ b/ge/common/auth/file_saver.cc @@ -54,8 +54,8 @@ Status FileSaver::OpenFile(int32_t &fd, const std::string &file_path) { Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID); mmSsize_t write_count; - uint32_t size_2g = ((uint32_t) 0x1 << 31); - uint32_t size_1g = ((uint32_t) 0x1 << 30); + uint32_t size_2g = 2147483648; // 0x1 << 31 + uint32_t size_1g = 1073741824; // 0x1 << 30 // Write data if (size > size_2g) { auto seek = reinterpret_cast(const_cast(data)); @@ -258,6 +258,65 @@ FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, Mod return SUCCESS; } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status +FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, + vector &model_partition_tables, + const vector> &all_partition_datas) { + file_header.is_encrypt = ModelEncryptType::UNENCRYPTED; + + const Status ret = SaveWithFileHeader(file_path, file_header, model_partition_tables, all_partition_datas); + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "save file failed, file_path:%s, file header len:%u.", + file_path.c_str(), file_header.length); + return SUCCESS; +} + +Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, + vector &model_partition_tables, + const vector> &all_partition_datas) { + + GE_CHK_BOOL_EXEC(model_partition_tables.size() == all_partition_datas.size(), + return PARAM_INVALID, + "model table size %zu does not match partition size %zu", + model_partition_tables.size(), all_partition_datas.size()) + for (size_t index = 0; index < model_partition_tables.size(); ++index) { + auto &cur_partiton_data = all_partition_datas[index]; + auto &cur_model_partition_table = *model_partition_tables[index]; + GE_CHK_BOOL_RET_STATUS(!cur_partiton_data.empty() && cur_model_partition_table.num != 0 + && cur_model_partition_table.num == cur_partiton_data.size(), FAILED, + "Invalid param:partition data size is (%u), model_partition_table.num is (%zu).", + cur_model_partition_table.num, cur_partiton_data.size()); + } + + // Open file + int32_t fd = 0; + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(OpenFile(fd, file_path) != SUCCESS, return FAILED); + Status ret = SUCCESS; + do { + // Write file header + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + WriteData(static_cast(&file_header), sizeof(ModelFileHeader), fd) != SUCCESS, ret = FAILED; + break); + for (size_t index = 0; index < model_partition_tables.size(); ++index) { + // Write model partition table + auto &cur_tabel = *model_partition_tables[index]; + uint32_t table_size = static_cast(SIZE_OF_MODEL_PARTITION_TABLE(cur_tabel)); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + WriteData(static_cast(&cur_tabel), table_size, fd) != SUCCESS, ret = FAILED; break); + // Write partition data + auto &cur_partition_datas = all_partition_datas[index]; + for (const auto &partition_data : cur_partition_datas) { + GELOGI("GC:size[%zu]", partition_data.size); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + WriteData(static_cast(partition_data.data), partition_data.size, fd) != SUCCESS, ret = FAILED; + break); + } + } + } while (0); + // Close file + GE_CHK_BOOL_RET_STATUS(mmClose(fd) == EN_OK, FAILED, "Close file failed."); + return ret; +} + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(const string &file_path, const void *data, int len) { if (data == nullptr || len <= 0) { diff --git a/ge/common/auth/file_saver.h b/ge/common/auth/file_saver.h index 79e2126e..97fbaae5 100644 --- a/ge/common/auth/file_saver.h +++ b/ge/common/auth/file_saver.h @@ -74,6 +74,10 @@ class FileSaver { ModelPartitionTable &model_partition_table, const std::vector &partition_datas); + static Status SaveToFile(const string &file_path, ModelFileHeader &file_header, + vector &model_partition_tables, + const vector> &all_partition_datas); + static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header, ModelPartitionTable &model_partition_table, const std::vector &partitionDatas, @@ -108,6 +112,9 @@ class FileSaver { static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, ModelPartitionTable &model_partition_table, const std::vector &partition_datas); + static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, + vector &model_partition_tables, + const vector> &all_partition_datas); }; } // namespace ge #endif // GE_COMMON_AUTH_FILE_SAVER_H_ diff --git a/ge/common/base64.h b/ge/common/base64.h index fb6c1870..a537e585 100644 --- a/ge/common/base64.h +++ b/ge/common/base64.h @@ -25,32 +25,38 @@ namespace ge { namespace { -const char* kBase64Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; +const char *kBase64Chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789+/"; const char kEqualSymbol = '='; const size_t kBase64CharsNum = 64; const size_t kThreeByteOneGroup = 3; const size_t kFourByteOneGroup = 4; -} +const size_t kThreeByteOneGroupIndex0 = 0; +const size_t kThreeByteOneGroupIndex1 = 1; +const size_t kThreeByteOneGroupIndex2 = 2; +const size_t kFourByteOneGroupIndex0 = 0; +const size_t kFourByteOneGroupIndex1 = 1; +const size_t kFourByteOneGroupIndex2 = 2; +const size_t kFourByteOneGroupIndex3 = 3; +} // namespace namespace base64 { -static inline bool IsBase64Char(const char &c) { - return (isalnum(c) || (c == '+') || (c == '/')); -} +static inline bool IsBase64Char(const char &c) { return (isalnum(c) || (c == '+') || (c == '/')); } static std::string EncodeToBase64(const std::string &raw_data) { size_t encode_length = raw_data.size() / kThreeByteOneGroup * kFourByteOneGroup; encode_length += raw_data.size() % kThreeByteOneGroup == 0 ? 0 : kFourByteOneGroup; - size_t raw_data_index = 0 ; + size_t raw_data_index = 0; size_t encode_data_index = 0; std::string encode_data; encode_data.resize(encode_length); for (; raw_data_index + kThreeByteOneGroup <= raw_data.size(); raw_data_index += kThreeByteOneGroup) { auto char_1 = static_cast(raw_data[raw_data_index]); - auto char_2 = static_cast(raw_data[raw_data_index + 1]); - auto char_3 = static_cast(raw_data[raw_data_index + 2]); + auto char_2 = static_cast(raw_data[raw_data_index + kThreeByteOneGroupIndex1]); + auto char_3 = static_cast(raw_data[raw_data_index + kThreeByteOneGroupIndex2]); encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)]; encode_data[encode_data_index++] = kBase64Chars[((char_2 << 2u) & 0x3c) | (char_3 >> 6u)]; @@ -80,8 +86,7 @@ static std::string EncodeToBase64(const std::string &raw_data) { #pragma GCC diagnostic ignored "-Wunused-function" static Status DecodeFromBase64(const std::string &base64_data, std::string &decode_data) { if (base64_data.size() % kFourByteOneGroup != 0) { - GELOGE(PARAM_INVALID, "base64 data size must can be divided by 4, but given data size is %zu", - base64_data.size()); + GELOGE(PARAM_INVALID, "base64 data size must can be divided by 4, but given data size is %zu", base64_data.size()); return PARAM_INVALID; } decode_data.clear(); @@ -92,10 +97,10 @@ static Status DecodeFromBase64(const std::string &base64_data, std::string &deco return static_cast(std::distance(kBase64Chars, char_pos)) & 0xff; }; - for (std::size_t input_data_index = 0; input_data_index < base64_data_len; input_data_index += 4) { + for (std::size_t input_data_index = 0; input_data_index < base64_data_len; input_data_index += kFourByteOneGroup) { for (size_t i = 0; i < kFourByteOneGroup; ++i) { if (base64_data[input_data_index + i] == kEqualSymbol && - input_data_index >= base64_data_len - 4 && i > 1) { + input_data_index >= base64_data_len - kFourByteOneGroup && i > 1) { byte_4[i] = kBase64CharsNum; } else if (IsBase64Char(base64_data[input_data_index + i])) { byte_4[i] = FindCharInBase64Chars(base64_data[input_data_index + i]); @@ -104,19 +109,23 @@ static Status DecodeFromBase64(const std::string &base64_data, std::string &deco return PARAM_INVALID; } } - decode_data += static_cast((byte_4[0] << 2u) + ((byte_4[1] & 0x30) >> 4u)); - if (byte_4[2] >= kBase64CharsNum){ + decode_data += + static_cast((byte_4[kFourByteOneGroupIndex0] << 2u) + ((byte_4[kFourByteOneGroupIndex1] & 0x30) >> 4u)); + if (byte_4[kFourByteOneGroupIndex2] >= kBase64CharsNum) { break; - } else if (byte_4[3] >= kBase64CharsNum) { - decode_data += static_cast(((byte_4[1] & 0x0f) << 4u) + ((byte_4[2] & 0x3c) >> 2u)); + } else if (byte_4[kFourByteOneGroupIndex3] >= kBase64CharsNum) { + decode_data += static_cast(((byte_4[kFourByteOneGroupIndex1] & 0x0f) << 4u) + + ((byte_4[kFourByteOneGroupIndex2] & 0x3c) >> 2u)); break; } - decode_data += static_cast(((byte_4[1] & 0x0f) << 4u) + ((byte_4[2] & 0x3c) >> 2u)); - decode_data += static_cast(((byte_4[2] & 0x03) << 6u) + byte_4[3]); + decode_data += static_cast(((byte_4[kFourByteOneGroupIndex1] & 0x0f) << 4u) + + ((byte_4[kFourByteOneGroupIndex2] & 0x3c) >> 2u)); + decode_data += + static_cast(((byte_4[kFourByteOneGroupIndex2] & 0x03) << 6u) + byte_4[kFourByteOneGroupIndex3]); } return SUCCESS; } #pragma GCC diagnostic pop -} +} // namespace base64 } // namespace ge #endif // GE_COMMON_BASE64_H_ \ No newline at end of file diff --git a/ge/common/context/ctx.cc b/ge/common/context/ctx.cc old mode 100755 new mode 100644 diff --git a/ge/common/cust_aicpu_kernel_store.cc b/ge/common/cust_aicpu_kernel_store.cc old mode 100755 new mode 100644 diff --git a/ge/common/cust_aicpu_kernel_store.h b/ge/common/cust_aicpu_kernel_store.h old mode 100755 new mode 100644 diff --git a/ge/common/debug/memory_dumper.cc b/ge/common/debug/memory_dumper.cc index 872fe1da..527f0bb2 100644 --- a/ge/common/debug/memory_dumper.cc +++ b/ge/common/debug/memory_dumper.cc @@ -139,7 +139,8 @@ int MemoryDumper::OpenFile(const char *filename) { GE_IF_BOOL_EXEC( -1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos); string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH, return kInvalidFd, "Prefix path is too long!"); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH, + return kInvalidFd, "Prefix path is too long!"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmRealPath(prefix_path.c_str(), tmp_path, MMPA_MAX_PATH) != EN_OK, return kInvalidFd, "Dir %s does not exit.", prefix_path.c_str()); real_path = std::string(tmp_path) + last_path;) diff --git a/ge/common/debug/memory_dumper.h b/ge/common/debug/memory_dumper.h old mode 100755 new mode 100644 diff --git a/ge/common/dump/dump_op.cc b/ge/common/dump/dump_op.cc old mode 100755 new mode 100644 diff --git a/ge/common/dump/dump_op.h b/ge/common/dump/dump_op.h old mode 100755 new mode 100644 diff --git a/ge/common/fmk_error_codes.cc b/ge/common/fmk_error_codes.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/datatype_transfer.h b/ge/common/formats/format_transfers/datatype_transfer.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h b/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h b/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc old mode 100755 new mode 100644 index ed1c6941..cb528453 --- a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc @@ -23,12 +23,30 @@ #include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" +#include "framework/common/types.h" #include "graph/utils/type_utils.h" namespace ge { namespace formats { namespace { const int kDimSize4D = 4; + +const size_t kSingleDim = 1; + +const size_t kNdDimIndexN = 0; +const size_t kNdDimIndexH = 1; +const size_t kNdDimIndexW = 2; + +const size_t kDimDValueBNdFNz = 2; // dim d-value between Nd and FractalZz + +const size_t kNdDimCountBackwardsW = 1; +const size_t kNdDimCountBackwardsWH = 2; + +const size_t kFNzDimCountBackwardsW0 = 1; +const size_t kFNzDimCountBackwardsW0H0 = 2; +const size_t kFNzDimCountBackwardsW0H0H1 = 3; +const size_t kFNzDimCountBackwardsW0H0H1W1 = 4; + bool IsDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0; } using ShapeVector = std::vector; @@ -60,14 +78,14 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap auto w0 = GetCubeSizeByDataType(data_type); int64_t h0 = kCubeSize; switch (src_shape.size()) { - case 1: - dst_shape.push_back(Ceil(src_shape[0], w0)); - dst_shape.push_back(1); + case kSingleDim: + dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0)); + dst_shape.push_back(DIM_DEFAULT_VALUE); dst_shape.push_back(h0); dst_shape.push_back(w0); - hw_shape.push_back(1); - hw_shape.push_back(1); - hw_shape.push_back(src_shape[0]); + hw_shape.push_back(DIM_DEFAULT_VALUE); + hw_shape.push_back(DIM_DEFAULT_VALUE); + hw_shape.push_back(src_shape[kNdDimIndexN]); if (!IsShapeValid(dst_shape)) { GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -76,17 +94,17 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap default: auto size = src_shape.size(); int64_t times = 1; - for (size_t i = 0; i != size - 2; i++) { + for (size_t i = 0; i != size - kDimDValueBNdFNz; i++) { dst_shape.push_back(src_shape[i]); times *= src_shape[i]; } - dst_shape.push_back(Ceil(src_shape[size - 1], w0)); - dst_shape.push_back(Ceil(src_shape[size - 2], h0)); + dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0)); + dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0)); dst_shape.push_back(h0); dst_shape.push_back(w0); hw_shape.push_back(times); - hw_shape.push_back(src_shape[size - 2]); - hw_shape.push_back(src_shape[size - 1]); + hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); + hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); if (!IsShapeValid(dst_shape)) { GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -128,16 +146,16 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con } // src&dst_shape can be written as times*H*W & times*W1*H1*H0*W0, respectively. dst_shape_size >= kDimNum4D - auto times = hw_shape.at(0); - auto h = hw_shape.at(1); - auto w = hw_shape.at(2); + auto times = hw_shape.at(kNdDimIndexN); + auto h = hw_shape.at(kNdDimIndexH); + auto w = hw_shape.at(kNdDimIndexW); auto hw = h * w; auto shape_size = args.dst_shape.size(); - auto w1 = args.dst_shape[shape_size - 4]; - auto h1 = args.dst_shape[shape_size - 3]; - auto h0 = args.dst_shape[shape_size - 2]; - auto w0 = args.dst_shape[shape_size - 1]; + auto w1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1]; + auto h1 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0H1]; + auto h0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0H0]; + auto w0 = args.dst_shape[shape_size - kFNzDimCountBackwardsW0]; auto h1h0 = h1 * h0; auto h1h0w0 = h1h0 * w0; auto w1h1h0w0 = w1 * h1h0w0; @@ -198,16 +216,16 @@ Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, con return OUT_OF_MEMORY; } - auto times = dst_hw_shape.at(0); - auto h = dst_hw_shape.at(1); - auto w = dst_hw_shape.at(2); + auto times = dst_hw_shape.at(kNdDimIndexN); + auto h = dst_hw_shape.at(kNdDimIndexH); + auto w = dst_hw_shape.at(kNdDimIndexW); auto hw = h * w; auto shape_size = args.src_shape.size(); - auto w1 = args.src_shape[shape_size - 4]; - auto h1 = args.src_shape[shape_size - 3]; - auto h0 = args.src_shape[shape_size - 2]; - auto w0 = args.src_shape[shape_size - 1]; + auto w1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1W1]; + auto h1 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0H1]; + auto h0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0H0]; + auto w0 = args.src_shape[shape_size - kFNzDimCountBackwardsW0]; auto h1h0 = h1 * h0; auto h1h0w0 = h1h0 * w0; auto w1h1h0w0 = w1 * h1h0w0; diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_nz.h b/ge/common/formats/format_transfers/format_transfer_fractal_nz.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_z.h b/ge/common/formats/format_transfers/format_transfer_fractal_z.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc old mode 100755 new mode 100644 index d890e681..88603d5c --- a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc @@ -23,12 +23,29 @@ #include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" +#include "framework/common/types.h" #include "graph/utils/type_utils.h" namespace ge { namespace formats { namespace { const int kDimSize4D = 4; + +const size_t kSingleDim = 1; + +const size_t kNdDimIndexN = 0; +const size_t kNdDimIndexH = 1; +const size_t kNdDimIndexW = 2; + +const size_t kDimDValueBNdFZz = 2; // dim d-value between Nd and FractalZz + +const size_t kNdDimCountBackwardsW = 1; +const size_t kNdDimCountBackwardsWH = 2; + +const size_t kFZzDimCountBackwardsW0 = 1; +const size_t kFZzDimCountBackwardsW0H0 = 2; +const size_t kFZzDimCountBackwardsW0H0W1 = 3; +const size_t kFZzDimCountBackwardsW0H0W1H1 = 4; bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; } using ShapeVector = std::vector; @@ -40,8 +57,8 @@ bool CheckShape(Format format, const ShapeVector &shape) { case FORMAT_NHWC: return CheckShapeValid(shape, kDimSize4D); default: - std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + - " and FORMAT_FRACTAL_ZZ is not supported."; + std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + + " and FORMAT_FRACTAL_ZZ is not supported."; GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); return false; } @@ -60,14 +77,14 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap auto w0 = GetCubeSizeByDataType(data_type); auto h0 = GetCubeSizeByDataType(data_type); switch (src_shape.size()) { - case 1: - dst_shape.push_back(1); - dst_shape.push_back(Ceil(src_shape[0], w0)); + case kSingleDim: + dst_shape.push_back(DIM_DEFAULT_VALUE); + dst_shape.push_back(Ceil(src_shape[kNdDimIndexN], w0)); dst_shape.push_back(h0); dst_shape.push_back(w0); - hw_shape.push_back(1); - hw_shape.push_back(1); - hw_shape.push_back(src_shape[0]); + hw_shape.push_back(DIM_DEFAULT_VALUE); + hw_shape.push_back(DIM_DEFAULT_VALUE); + hw_shape.push_back(src_shape[kNdDimIndexN]); if (!IsShapeValid(dst_shape)) { GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -76,17 +93,17 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap default: auto size = src_shape.size(); int64_t times = 1; - for (size_t i = 0; i != size - 2; i++) { + for (size_t i = 0; i != size - kDimDValueBNdFZz; i++) { dst_shape.push_back(src_shape[i]); times *= src_shape[i]; } - dst_shape.push_back(Ceil(src_shape[size - 2], h0)); - dst_shape.push_back(Ceil(src_shape[size - 1], w0)); + dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsWH], h0)); + dst_shape.push_back(Ceil(src_shape[size - kNdDimCountBackwardsW], w0)); dst_shape.push_back(h0); dst_shape.push_back(w0); hw_shape.push_back(times); - hw_shape.push_back(src_shape[size - 2]); - hw_shape.push_back(src_shape[size - 1]); + hw_shape.push_back(src_shape[size - kNdDimCountBackwardsWH]); + hw_shape.push_back(src_shape[size - kNdDimCountBackwardsW]); if (!IsShapeValid(dst_shape)) { GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -127,16 +144,16 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con return OUT_OF_MEMORY; } // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D - auto times = hw_shape.at(0); - auto h = hw_shape.at(1); - auto w = hw_shape.at(2); + auto times = hw_shape.at(kNdDimIndexN); + auto h = hw_shape.at(kNdDimIndexH); + auto w = hw_shape.at(kNdDimIndexW); auto hw = h * w; auto shape_size = args.dst_shape.size(); - auto h1 = args.dst_shape[shape_size - 4]; - auto w1 = args.dst_shape[shape_size - 3]; - auto h0 = args.dst_shape[shape_size - 2]; - auto w0 = args.dst_shape[shape_size - 1]; + auto h1 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0W1H1]; + auto w1 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0W1]; + auto h0 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0H0]; + auto w0 = args.dst_shape[shape_size - kFZzDimCountBackwardsW0]; auto h0w0 = h0 * w0; auto w1h0w0 = w1 * h0w0; auto h1w1h0w0 = h1 * w1h0w0; @@ -155,8 +172,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con auto src_offset = (src_h_head + w1_idx * w0) * size; auto dst_offset = (h0_head + w1_idx * h0w0) * size; auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) - ? dst_size - dst_offset - : static_cast(SECUREC_MEM_MAX_LEN); + ? dst_size - dst_offset + : static_cast(SECUREC_MEM_MAX_LEN); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size * w0)); if (ret != EOK) { @@ -171,8 +188,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con auto src_offset = (src_h_head + src_w_idx) * size; auto dst_offset = (w0_head + w0_idx) * size; auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) - ? dst_size - dst_offset - : static_cast(SECUREC_MEM_MAX_LEN); + ? dst_size - dst_offset + : static_cast(SECUREC_MEM_MAX_LEN); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { @@ -205,16 +222,16 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con } // The src&dst_shape can be written as times*H*W & times*H1*W1*H0*W0, respectively. dst_shape_size >= kDimNum4D - auto times = dst_hw_shape.at(0); - auto h = dst_hw_shape.at(1); - auto w = dst_hw_shape.at(2); + auto times = dst_hw_shape.at(kNdDimIndexN); + auto h = dst_hw_shape.at(kNdDimIndexH); + auto w = dst_hw_shape.at(kNdDimIndexW); auto hw = h * w; auto shape_size = args.src_shape.size(); - auto h1 = args.src_shape[shape_size - 4]; - auto w1 = args.src_shape[shape_size - 3]; - auto h0 = args.src_shape[shape_size - 2]; - auto w0 = args.src_shape[shape_size - 1]; + auto h1 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0W1H1]; + auto w1 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0W1]; + auto h0 = args.src_shape[shape_size - kFZzDimCountBackwardsW0H0]; + auto w0 = args.src_shape[shape_size - kFZzDimCountBackwardsW0]; auto h0w0 = h0 * w0; auto w1h0w0 = w1 * h0w0; auto h1w1h0w0 = h1 * w1h0w0; @@ -233,8 +250,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con auto src_offset = (h0_head + w1_idx * h0w0) * size; auto dst_offset = (dst_h_head + w1_idx * w0) * size; auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) - ? dst_size - dst_offset - : static_cast(SECUREC_MEM_MAX_LEN); + ? dst_size - dst_offset + : static_cast(SECUREC_MEM_MAX_LEN); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size * w0)); if (ret != EOK) { @@ -249,8 +266,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con auto dst_w_idx = w1_head + w0_idx; auto dst_offset = (dst_h_head + dst_w_idx) * size; auto protected_size = dst_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) - ? dst_size - dst_offset - : static_cast(SECUREC_MEM_MAX_LEN); + ? dst_size - dst_offset + : static_cast(SECUREC_MEM_MAX_LEN); auto ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), args.data + src_offset, static_cast(size)); if (ret != EOK) { diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_zz.h b/ge/common/formats/format_transfers/format_transfer_fractal_zz.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc b/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc b/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h b/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc b/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h b/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc b/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h b/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc b/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc index a66aeeb4..49b19f46 100644 --- a/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc +++ b/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc @@ -35,7 +35,6 @@ * Padding to (N, ceil(Z/16)*16) * Last Step: View the (N, ceil(Z/16)*16) as 4D (N/16, 16, C/16, 16) and transpose to (C/16, N/16, 16, 16) */ - namespace ge { namespace formats { namespace { diff --git a/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h b/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.cc b/ge/common/formats/format_transfers/format_transfer_transpose.cc old mode 100755 new mode 100644 index e623d9e7..9be74b1f --- a/ge/common/formats/format_transfers/format_transfer_transpose.cc +++ b/ge/common/formats/format_transfers/format_transfer_transpose.cc @@ -19,6 +19,7 @@ #include #include +#include "common/formats/utils/formats_definitions.h" #include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" @@ -29,21 +30,21 @@ namespace formats { namespace { std::map>> perm_args{ {FORMAT_NCHW, - {{FORMAT_NHWC, std::vector({0, 2, 3, 1})}, - {FORMAT_HWCN, std::vector({2, 3, 1, 0})}, - {FORMAT_CHWN, std::vector({1, 2, 3, 0})}}}, + {{FORMAT_NHWC, std::vector({kNchwN, kNchwH, kNchwW, kNchwC})}, + {FORMAT_HWCN, std::vector({kNchwH, kNchwW, kNchwC, kNchwN})}, + {FORMAT_CHWN, std::vector({kNchwC, kNchwH, kNchwW, kNchwN})}}}, {FORMAT_NHWC, - {{FORMAT_NCHW, std::vector({0, 3, 1, 2})}, - {FORMAT_CHWN, std::vector({3, 1, 2, 0})}, - {FORMAT_HWCN, std::vector({1, 2, 3, 0})}}}, + {{FORMAT_NCHW, std::vector({kNhwcN, kNhwcC, kNhwcH, kNhwcW})}, + {FORMAT_CHWN, std::vector({kNhwcC, kNhwcH, kNhwcW, kNhwcN})}, + {FORMAT_HWCN, std::vector({kNhwcH, kNhwcW, kNhwcC, kNhwcN})}}}, {FORMAT_HWCN, - {{FORMAT_NCHW, std::vector({3, 2, 0, 1})}, - {FORMAT_NHWC, std::vector({3, 0, 1, 2})}, - {FORMAT_CHWN, std::vector({2, 0, 1, 3})}}}, + {{FORMAT_NCHW, std::vector({kHwcnN, kHwcnC, kHwcnH, kHwcnW})}, + {FORMAT_NHWC, std::vector({kHwcnN, kHwcnH, kHwcnW, kHwcnC})}, + {FORMAT_CHWN, std::vector({kHwcnC, kHwcnH, kHwcnW, kHwcnN})}}}, {FORMAT_CHWN, - {{FORMAT_NCHW, std::vector({3, 0, 1, 2})}, - {FORMAT_NHWC, std::vector({3, 1, 2, 0})}, - {FORMAT_HWCN, std::vector({1, 2, 0, 3})}}}, + {{FORMAT_NCHW, std::vector({kChwnN, kChwnC, kChwnH, kChwnW})}, + {FORMAT_NHWC, std::vector({kChwnN, kChwnH, kChwnW, kChwnC})}, + {FORMAT_HWCN, std::vector({kChwnH, kChwnW, kChwnC, kChwnN})}}}, }; bool IsShapeArgValid(const std::vector &src_shape, const std::vector &perm_arg) { diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.h b/ge/common/formats/format_transfers/format_transfer_transpose.h old mode 100755 new mode 100644 diff --git a/ge/common/formats/formats.cc b/ge/common/formats/formats.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/utils/formats_definitions.h b/ge/common/formats/utils/formats_definitions.h old mode 100755 new mode 100644 index 7f873f1b..25f36d6a --- a/ge/common/formats/utils/formats_definitions.h +++ b/ge/common/formats/utils/formats_definitions.h @@ -23,6 +23,7 @@ static const int kCubeSize = 16; static const int kNiSize = 16; static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL; + enum NchwDimIndex { kNchwN, kNchwC, @@ -47,6 +48,14 @@ enum HwcnDimIndex { kHwcnDimsNum }; +enum ChwnDimIndex { + kChwnC, + kChwnH, + kChwnW, + kChwnN, + kChwnDimsNum +}; + enum Nc1hwc0DimIndex { kNc1hwc0N, kNc1hwc0C1, diff --git a/ge/common/formats/utils/formats_trans_utils.cc b/ge/common/formats/utils/formats_trans_utils.cc old mode 100755 new mode 100644 diff --git a/ge/common/formats/utils/formats_trans_utils.h b/ge/common/formats/utils/formats_trans_utils.h old mode 100755 new mode 100644 diff --git a/ge/common/fp16_t.cc b/ge/common/fp16_t.cc old mode 100755 new mode 100644 diff --git a/ge/common/fp16_t.h b/ge/common/fp16_t.h old mode 100755 new mode 100644 diff --git a/ge/common/ge/datatype_util.cc b/ge/common/ge/datatype_util.cc old mode 100755 new mode 100644 diff --git a/ge/common/ge/plugin_manager.cc b/ge/common/ge/plugin_manager.cc index 7bb1310c..75a36d99 100644 --- a/ge/common/ge/plugin_manager.cc +++ b/ge/common/ge/plugin_manager.cc @@ -123,7 +123,10 @@ Status PluginManager::LoadSo(const string &path, const vector &func_chec if (handle == nullptr) { const char *error = mmDlerror(); GE_IF_BOOL_EXEC(error == nullptr, error = ""); - GELOGE(GE_PLGMGR_PATH_INVALID, "Failed to dlopen %s!", error); + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"mmDlopen", "shared library path is " + FmtToStr(file_path_dlopen) + ". Errormessage" + FmtToStr(error)}); + GELOGE(GE_PLGMGR_PATH_INVALID, "Failed to dlopen the shared library path[%s]. Errormessage[%s]!", + file_path_dlopen.c_str(), error); continue; } @@ -132,6 +135,9 @@ Status PluginManager::LoadSo(const string &path, const vector &func_chec for (const auto &func_name : func_check_list) { auto real_fn = (void (*)())mmDlsym(handle, const_cast(func_name.c_str())); if (real_fn == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"mmDlsym", FmtToStr(func_name) + " is skipped since function" + + FmtToStr(func_name) + " is not existed!"}); GELOGE(GE_PLGMGR_PATH_INVALID, "%s is skipped since function %s is not existed!", func_name.c_str(), func_name.c_str()); is_valid = false; diff --git a/ge/common/ge/plugin_manager.h b/ge/common/ge/plugin_manager.h old mode 100755 new mode 100644 diff --git a/ge/common/ge/tbe_plugin_manager.cc b/ge/common/ge/tbe_plugin_manager.cc old mode 100755 new mode 100644 index b91f1204..44199c32 --- a/ge/common/ge/tbe_plugin_manager.cc +++ b/ge/common/ge/tbe_plugin_manager.cc @@ -37,6 +37,8 @@ #include "graph/utils/type_utils.h" namespace ge { +const int kBaseInt = 10; + std::map TBEPluginManager::options_ = {}; // Get Singleton Instance @@ -155,7 +157,7 @@ void TBEPluginManager::GetCustomOpPath(std::string &customop_path) { domi::FrameworkType type = domi::TENSORFLOW; auto it = options_.find(FRAMEWORK_TYPE); if (it != options_.end()) { - type = static_cast(std::strtol(it->second.c_str(), nullptr, 10)); + type = static_cast(std::strtol(it->second.c_str(), nullptr, kBaseInt)); } fmk_type = ge::TypeUtils::FmkTypeToSerialString(type); GELOGI("Framework type is %s.", fmk_type.c_str()); diff --git a/ge/common/ge/tbe_plugin_manager.h b/ge/common/ge/tbe_plugin_manager.h old mode 100755 new mode 100644 diff --git a/ge/common/ge_common.mk b/ge/common/ge_common.mk old mode 100755 new mode 100644 index 3fffd203..e28090ad --- a/ge/common/ge_common.mk +++ b/ge/common/ge_common.mk @@ -7,6 +7,7 @@ GE_COMMON_LOCAL_SRC_FILES := \ helper/om_file_helper.cc \ helper/model_helper.cc \ ../model/ge_model.cc \ + ../model/ge_root_model.cc \ auth/file_saver.cc \ fp16_t.cc \ math/fp16_math.cc \ diff --git a/ge/common/ge_format_util.cc b/ge/common/ge_format_util.cc old mode 100755 new mode 100644 diff --git a/ge/common/helper/model_cache_helper.cc b/ge/common/helper/model_cache_helper.cc old mode 100755 new mode 100644 diff --git a/ge/common/helper/model_cache_helper.h b/ge/common/helper/model_cache_helper.h old mode 100755 new mode 100644 diff --git a/ge/common/helper/model_helper.cc b/ge/common/helper/model_helper.cc index 6f201461..aacef88c 100644 --- a/ge/common/helper/model_helper.cc +++ b/ge/common/helper/model_helper.cc @@ -32,6 +32,7 @@ using domi::ModelTaskDef; namespace { const int64_t kOriginalOmPartitionNum = 1; +const uint32_t kStatiOmFileModelNum = 1; } @@ -39,7 +40,7 @@ namespace ge { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } Status ModelHelper::SaveModelPartition(std::shared_ptr &om_file_save_helper, ModelPartitionType type, - const uint8_t *data, size_t size) { + const uint8_t *data, size_t size, size_t model_index) { if (size < 1 || size > UINT32_MAX) { GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu invalid", size); if (size > UINT32_MAX) { @@ -68,25 +69,16 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr &om_fil partition_model.data = const_cast(data); partition_model.size = static_cast(size); partition_model.type = type; - if (om_file_save_helper->AddPartition(partition_model) != SUCCESS) { + if (om_file_save_helper->AddPartition(partition_model, model_index) != SUCCESS) { GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu", size); return PARAM_INVALID; } return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model, - const SaveParam &save_param, - const std::string &output_file, - ModelBufferData& model) { - if (output_file.empty()) { - GELOGE(FAILED, "GraphBuilder SaveModel received invalid file name prefix"); - return FAILED; - } - GE_IF_BOOL_EXEC(ge_model == nullptr, GELOGE(FAILED, "Ge_model is nullptr"); return FAILED); - std::shared_ptr om_file_save_helper = ge::MakeShared(); - GE_CHECK_NOTNULL(om_file_save_helper); +Status ModelHelper::SaveModelDef(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, ge::Buffer &model_buffer, size_t model_index) { ModelPtr model_tmp = ge::MakeShared(ge_model->GetName(), ge_model->GetPlatformVersion()); if (model_tmp == nullptr) { GELOGE(FAILED, "Create Model %s Ptr failed", ge_model->GetName().c_str()); @@ -96,16 +88,21 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod model_tmp->SetVersion(ge_model->GetVersion()); model_tmp->SetAttr(ge_model->MutableAttrMap()); - ge::Buffer model_buffer; + (void)model_tmp->Save(model_buffer); GELOGD("MODEL_DEF size is %zu", model_buffer.GetSize()); if (model_buffer.GetSize() > 0) { if (SaveModelPartition(om_file_save_helper, ModelPartitionType::MODEL_DEF, model_buffer.GetData(), - model_buffer.GetSize()) != SUCCESS) { + model_buffer.GetSize(), model_index) != SUCCESS) { GELOGE(PARAM_INVALID, "Add model graph partition failed"); return PARAM_INVALID; } } + return SUCCESS; +} + +Status ModelHelper::SaveModelWeights(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, size_t model_index) { auto ge_model_weight = ge_model->GetWeight(); GELOGD("WEIGHTS_DATA size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); // weight is not necessary @@ -113,31 +110,43 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, ModelPartitionType::WEIGHTS_DATA, ge_model_weight.GetData(), - ge_model_weight.GetSize()), "Add weight partition failed"); + ge_model_weight.GetSize(), model_index), "Add weight partition failed"); } + return SUCCESS; +} +Status ModelHelper::SaveModelTbeKernel(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, size_t model_index) { TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); GELOGD("TBE_KERNELS size is %zu", tbe_kernel_store.DataSize()); if (tbe_kernel_store.DataSize() > 0) { - GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, - ModelPartitionType::TBE_KERNELS, - tbe_kernel_store.Data(), - tbe_kernel_store.DataSize()), "Add tbe kernel partition failed"); + GE_CHK_STATUS_RET( + SaveModelPartition(om_file_save_helper, ModelPartitionType::TBE_KERNELS, + ge_model->GetTBEKernelStore().Data(), ge_model->GetTBEKernelStore().DataSize(), + model_index), "Add tbe kernel partition failed"); } - // no need to check value, DATA->NetOutput (void)tbe_kernel_store.Load(tbe_kernel_store.Data(), tbe_kernel_store.DataSize()); + return SUCCESS; +} + +Status ModelHelper::SaveModelCustAICPU(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, size_t model_index) { CustAICPUKernelStore cust_aicpu_kernel_store = ge_model->GetCustAICPUKernelStore(); GELOGD("cust aicpu kernels size is %zu", cust_aicpu_kernel_store.DataSize()); if (cust_aicpu_kernel_store.DataSize() > 0) { GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, ModelPartitionType::CUST_AICPU_KERNELS, - cust_aicpu_kernel_store.Data(), - cust_aicpu_kernel_store.DataSize()), + ge_model->GetCustAICPUKernelStore().Data(), + cust_aicpu_kernel_store.DataSize(), model_index), "Add cust aicpu kernel partition failed"); } + return SUCCESS; +} +Status ModelHelper::SaveModelTaskDef(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, ge::Buffer &task_buffer, size_t model_index) { std::shared_ptr model_task_def = ge_model->GetModelTaskDefPtr(); if (model_task_def == nullptr) { GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create model task def ptr failed"); @@ -146,9 +155,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod size_t partition_task_size = model_task_def->ByteSizeLong(); GE_IF_BOOL_EXEC(partition_task_size == 0 || partition_task_size > INT_MAX, GELOGE(FAILED, "Model_def's byte size (%zu) is invalid!", partition_task_size); - return FAILED); + return FAILED); - ge::Buffer task_buffer(partition_task_size); + task_buffer = ge::Buffer(partition_task_size); if (task_buffer.GetSize() == 0) { GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc model task def buffer failed"); return ACL_ERROR_GE_MEMORY_ALLOCATION; @@ -159,21 +168,28 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod GELOGD("TASK_INFO size is %zu", partition_task_size); if (SaveModelPartition(om_file_save_helper, ModelPartitionType::TASK_INFO, task_buffer.GetData(), - partition_task_size) != SUCCESS) { + partition_task_size, model_index) != SUCCESS) { GELOGE(PARAM_INVALID, "Add model task def partition failed"); return PARAM_INVALID; } + return SUCCESS; +} + +Status ModelHelper::SaveModelHeader(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, size_t model_num) { // Save target/version to model_header ModelFileHeader &model_header = om_file_save_helper->GetModelFileHeader(); model_header.platform_type = ge_model->GetPlatformType(); model_header.om_ir_version = ge_model->GetVersion(); + model_header.model_num = model_num; std::string platform_version = ge_model->GetPlatformVersion(); errno_t err; err = memcpy_s(model_header.platform_version, PLATFORM_VERSION_LEN, platform_version.c_str(), platform_version.size() + 1); if (err != EOK) { - GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "ModelHelper SaveModel failed while allocating memory for platform_version."); + GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, + "ModelHelper SaveModel failed while allocating memory for platform_version."); return ACL_ERROR_GE_MEMORY_ALLOCATION; } string version = reinterpret_cast(model_header.platform_version); @@ -188,8 +204,142 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod } string model_name = reinterpret_cast(model_header.name); GELOGD("Model name save:%s", model_name.c_str()); + return SUCCESS; +} + +Status ModelHelper::SaveAllModelPartiton(std::shared_ptr& om_file_save_helper, + const GeModelPtr &ge_model, ge::Buffer &model_buffer, + ge::Buffer &task_buffer, size_t model_index) { + if (SaveModelDef(om_file_save_helper, ge_model, model_buffer, model_index) != SUCCESS) { + GELOGE(FAILED, "save model def failed"); + return FAILED; + } + + if (SaveModelWeights(om_file_save_helper, ge_model, model_index) != SUCCESS) { + GELOGE(FAILED, "save model weights failed"); + return FAILED; + } + + if (SaveModelTbeKernel(om_file_save_helper, ge_model, model_index) != SUCCESS) { + GELOGE(FAILED, "save model tbe kernel failed"); + return FAILED; + } + + if (SaveModelCustAICPU(om_file_save_helper, ge_model, model_index) != SUCCESS) { + GELOGE(FAILED, "save model cust ai cpu failed"); + return FAILED; + } + + + if (SaveModelTaskDef(om_file_save_helper, ge_model, task_buffer, model_index) != SUCCESS) { + GELOGE(FAILED, "save task def failed"); + return FAILED; + } + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmModel(const GeModelPtr &ge_model, + const SaveParam &save_param, + const std::string &output_file, + ModelBufferData& model) { + if (output_file.empty()) { + GELOGE(FAILED, "GraphBuilder SaveModel received invalid file name prefix"); + return FAILED; + } - Status ret = om_file_save_helper->SaveModel(save_param, output_file.c_str(), model, is_offline_); + GE_IF_BOOL_EXEC(ge_model == nullptr, GELOGE(FAILED, "Ge_model is nullptr"); return FAILED); + std::shared_ptr om_file_save_helper = ge::MakeShared(); + GE_CHECK_NOTNULL(om_file_save_helper); + ge::Buffer model_buffer; + ge::Buffer task_buffer; + + auto ret = SaveAllModelPartiton(om_file_save_helper, ge_model, model_buffer, task_buffer); + if (ret != SUCCESS) { + GELOGE(ret, "save all model partition failed"); + return ret; + } + + ret = SaveModelHeader(om_file_save_helper, ge_model); + if (ret != SUCCESS) { + GELOGE(ret, "save model header failed"); + return ret; + } + + ret = om_file_save_helper->SaveModel(save_param, output_file.c_str(), model, is_offline_); + if (ret != SUCCESS) { + GELOGE(FAILED, "OmFileSaveHelper SaveModel return fail."); + return ret; + } + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmRootModel( + const GeRootModelPtr &ge_root_model, + const SaveParam &save_param, + const std::string &output_file, + ModelBufferData& model, + bool is_unknown_shape) { + + GE_CHECK_NOTNULL(ge_root_model); + GE_IF_BOOL_EXEC(ge_root_model == nullptr, GELOGE(FAILED, "Ge_root_model is nullptr"); return FAILED); + + auto &name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); + GE_IF_BOOL_EXEC(name_to_ge_model.empty(), GELOGE(FAILED, "Ge_root_model has no sub model"); return FAILED); + GE_IF_BOOL_EXEC(output_file.empty(), + GELOGE(FAILED, "GraphBuilder SaveModel received invalid file name prefix"); + return FAILED); + + if (!is_unknown_shape) { + auto &model_root = name_to_ge_model.begin()->second; + return SaveToOmModel(model_root, save_param, output_file, model); + } + + std::shared_ptr om_file_save_helper = ge::MakeShared(); + GE_CHECK_NOTNULL(om_file_save_helper); + + auto &first_ge_model = name_to_ge_model.at(ge_root_model->GetRootGraph()->GetName()); + + // ge root model must be the first to be loaded + vector model_names{ge_root_model->GetRootGraph()->GetName()}; + for (auto &item : name_to_ge_model) { + if (item.first != model_names.front()) { + model_names.emplace_back(item.first); + } + } + + vector model_buffers(model_names.size()); + vector task_buffers(model_names.size()); + + size_t cur_index = 0; + + if (model_names.size() > 1) { + GELOGD("only save first model MODEL_DEF"); + if (SaveModelDef(om_file_save_helper, first_ge_model, model_buffers[cur_index], cur_index) != SUCCESS) { + GELOGE(FAILED, "save model def failed"); + return FAILED; + } + ++cur_index; + } + + for (; cur_index < model_names.size(); ++cur_index) { + auto model_name = model_names[cur_index]; + GELOGD("cur model %s index is %zu", model_name.c_str(), cur_index); + const GeModelPtr &ge_model = name_to_ge_model.at(model_name); + auto ret = SaveAllModelPartiton(om_file_save_helper, ge_model, model_buffers[cur_index], + task_buffers[cur_index], cur_index); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Save model %s failed", model_name.c_str()); + return INTERNAL_ERROR; + } + } + + auto ret = SaveModelHeader(om_file_save_helper, first_ge_model, model_names.size()); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Save model %s header failed", first_ge_model->GetName().c_str()); + return INTERNAL_ERROR; + } + + ret = om_file_save_helper->SaveRootModel(save_param, output_file.c_str(), model, is_offline_); if (ret != SUCCESS) { GELOGE(FAILED, "OmFileSaveHelper SaveModel return fail."); return FAILED; @@ -288,7 +438,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c } file_header_ = reinterpret_cast(model_data.model_data); - OmFileLoadHelper om_load_helper; status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_); if (status != SUCCESS) { @@ -310,7 +459,61 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c GELOGE(status, "GenerateGeModel failed"); return status; } + GELOGD("in ModelHelper::LoadModel, is_assign_model_ is setted to true!"); + is_assign_model_ = true; + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) { + if (model_data.model_data == nullptr || model_data.model_len == 0) { + GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "Model_data is nullptr, or model_data_size is 0"); + return GE_EXEC_MODEL_DATA_SIZE_INVALID; + } + + if (is_assign_model_) { + GELOGE(GE_EXEC_LOAD_MODEL_REPEATED, "Model helper has already loaded!"); + return GE_EXEC_LOAD_MODEL_REPEATED; + } + if (ReleaseLocalModelData() != SUCCESS) { + GELOGE(INTERNAL_ERROR, "ReleaseLocalModelData failed."); + return INTERNAL_ERROR; + } + + Status status = ge::DavinciModelParser::ParseModelContent(model_data, model_addr_tmp_, model_len_tmp_); + if (status != SUCCESS) { + GELOGE(status, "Parse model content failed!"); + return status; + } + + file_header_ = reinterpret_cast(model_data.model_data); + + //model verison 1.0 file header does not have model_num member + is_unknown_shape_model_ = file_header_->version >= ge::MODEL_VERSION && + file_header_->model_num > kStatiOmFileModelNum; + GELOGD("cur om model is ge root model or no %d, model version %zu", is_unknown_shape_model_, file_header_->version); + + OmFileLoadHelper om_load_helper; + if (is_unknown_shape_model_) { + auto model_num = file_header_->model_num; + status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_, model_num); + } else { + status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_); + } + if (status != SUCCESS) { + GELOGE(status, "Om_load_helper init failed"); + model_addr_tmp_ = nullptr; + return status; + } + // Encrypt model need to del temp model/no encrypt model don't need to del model + model_addr_tmp_ = nullptr; + + status = GenerateGeRootModel(om_load_helper); + if (status != SUCCESS) { + GELOGE(status, "GenerateGeRootModel failed"); + return status; + } + GELOGD("in ModelHelper::LoadRootModel, is_assign_model_ is setted to true!"); is_assign_model_ = true; return SUCCESS; } @@ -341,6 +544,61 @@ Status ModelHelper::GenerateGeModel(OmFileLoadHelper &om_load_helper) { return SUCCESS; } +Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) { + GELOGD("Begin to generate ge root model"); + root_model_ = ge::MakeShared(); + GE_CHECK_NOTNULL(root_model_); + if (!is_unknown_shape_model_) { + if (GenerateGeModel(om_load_helper) != SUCCESS) { + GELOGE(FAILED, "GenerateGeModel failed"); + return FAILED; + } + GE_CHECK_NOTNULL(model_); + root_model_->SetRootGraph(GraphUtils::GetComputeGraph(model_->GetGraph())); + return SUCCESS; + } + + bool is_first_model = true; + for (size_t mode_index = 0; mode_index < file_header_->model_num; ++mode_index) { + GeModelPtr cur_model = ge::MakeShared(); + Status ret = LoadModelData(om_load_helper, cur_model, mode_index); + if (ret != SUCCESS) { + return GE_EXEC_LOAD_MODEL_PARTITION_FAILED; + } + + if (is_first_model) { + is_first_model = false; + root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph())); + root_model_->SetModelId(cur_model->GetModelId()); + model_ = cur_model; + continue; + } + + ret = LoadWeights(om_load_helper, cur_model, mode_index); + if (ret != SUCCESS) { + return GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED; + } + + ret = LoadTBEKernelStore(om_load_helper, cur_model, mode_index); + if (ret != SUCCESS) { + return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED; + } + + ret = LoadCustAICPUKernelStore(om_load_helper, cur_model, mode_index); + if (ret != SUCCESS) { + return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED; + } + + ret = LoadTask(om_load_helper, cur_model, mode_index); + if (ret != SUCCESS) { + return GE_EXEC_LOAD_TASK_PARTITION_FAILED; + } + root_model_->SetSubgraphInstanceNameToModel(cur_model->GetName(), cur_model); + } + + return SUCCESS; +} + Status ModelHelper::LoadModelData(OmFileLoadHelper &om_load_helper) { ModelPartition partition_model_def; // no need to check value, DATA->NetOutput @@ -366,6 +624,28 @@ void ModelHelper::SetModelToGeModel(ge::Model &model) { model_->SetAttr(model.MutableAttrMap()); } +Status ModelHelper::LoadModelData(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) { + ModelPartition partition_model_def; + // no need to check value, DATA->NetOutput + om_load_helper.GetModelPartition(ModelPartitionType::MODEL_DEF, partition_model_def, mode_index); + GELOGD("Model_def partition addr:%p,size:%u", partition_model_def.data, partition_model_def.size); + + ge::Model model; + if (ge::Model::Load(partition_model_def.data, partition_model_def.size, model) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Load model failed."); + return INTERNAL_ERROR; + } + + cur_model->SetGraph(model.GetGraph()); + cur_model->SetName(model.GetName()); + cur_model->SetVersion(model.GetVersion()); + cur_model->SetPlatformVersion(model.GetPlatformVersion()); + cur_model->SetAttr(model.MutableAttrMap()); + + return SUCCESS; +} + + Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) { ModelPartition partition; if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition) != SUCCESS) { @@ -379,6 +659,19 @@ Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper) { return SUCCESS; } +Status ModelHelper::LoadWeights(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) { + ModelPartition partition; + if (om_load_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition, mode_index) != SUCCESS) { + GELOGE(FAILED, "Get weight model partition failed."); + return FAILED; + } + ge::Buffer weight = ge::Buffer::CopyFrom(partition.data, partition.size); + cur_model->SetWeight(weight); + + GELOGD("GetWeight size:%u", partition.size); + return SUCCESS; +} + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper) { ModelPartition task_partition; if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { @@ -398,6 +691,27 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(Om return SUCCESS; } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadTask(OmFileLoadHelper &om_load_helper, + GeModelPtr &cur_model, + size_t mode_index) { + ModelPartition task_partition; + if (om_load_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition, mode_index) != SUCCESS) { + GELOGE(FAILED, "Get task model partition failed."); + return FAILED; + } + std::shared_ptr task = ge::MakeShared(); + GE_CHECK_NOTNULL(task); + if (task_partition.size != 0) { + if (!ReadProtoFromArray(task_partition.data, task_partition.size, task.get())) { + GELOGE(INTERNAL_ERROR, "ReadProtoFromArray failed."); + return INTERNAL_ERROR; + } + GELOGD("TASK_INFO op_size:%zu, stream_num:%u", task->op().size(), task->stream_num()); + } + cur_model->SetModelTaskDef(task); + return SUCCESS; +} + Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper) { // Load tbe kernels ModelPartition partition_kernel_def; @@ -414,6 +728,23 @@ Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper) { return SUCCESS; } +Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index) { + // Load tbe kernels + ModelPartition partition_kernel_def; + TBEKernelStore kernel_store; + if (om_load_helper.GetModelPartition(ModelPartitionType::TBE_KERNELS, partition_kernel_def, mode_index) == + SUCCESS) { + GELOGD("Kernels partition size:%u", partition_kernel_def.size); + if (kernel_store.Load(partition_kernel_def.data, partition_kernel_def.size)) { + GELOGD("Load tbe kernels success"); + } else { + GELOGW("Load tbe kernels failed"); + } + } + cur_model->SetTBEKernelStore(kernel_store); + return SUCCESS; +} + Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) { // Load cust aicpu kernels ModelPartition partition_kernel_def; @@ -421,19 +752,39 @@ Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) { if (om_load_helper.GetModelPartition(ModelPartitionType::CUST_AICPU_KERNELS, partition_kernel_def) == SUCCESS) { GELOGD("Kernels partition size:%u", partition_kernel_def.size); if (kernel_store.Load(partition_kernel_def.data, partition_kernel_def.size)) { - GELOGI("Load cust aicpu kernels success"); + GELOGD("Load cust aicpu kernels success"); + } else { + GELOGW("Load cust aicpu kernels failed"); } } model_->SetCustAICPUKernelStore(kernel_store); return SUCCESS; } +Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper, + GeModelPtr &cur_model, size_t mode_index) { + // Load cust aicpu kernels + ModelPartition partition_kernel_def; + CustAICPUKernelStore kernel_store; + if (om_load_helper.GetModelPartition(ModelPartitionType::CUST_AICPU_KERNELS, partition_kernel_def, mode_index) + == SUCCESS) { + GELOGD("Kernels partition size:%u", partition_kernel_def.size); + if (kernel_store.Load(partition_kernel_def.data, partition_kernel_def.size)) { + GELOGD("Load cust aicpu kernels success"); + } else { + GELOGW("Load cust aicpu kernels failed"); + } + } + cur_model->SetCustAICPUKernelStore(kernel_store); + return SUCCESS; +} + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeModel() { if (model_ != nullptr) { return model_; } - GELOGI("Model has not been loaded!"); + GELOGD("Model has not been loaded!"); std::shared_ptr out_model = ge::MakeShared(); if (out_model == nullptr) { return nullptr; @@ -441,6 +792,20 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo return out_model; } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeRootModelPtr ModelHelper::GetGeRootModel() { + if (root_model_ != nullptr) { + return root_model_; + } + + GELOGD("Model has not been loaded!"); + std::shared_ptr out_model = ge::MakeShared(); + if (out_model == nullptr) { + return nullptr; + } + return out_model; +} + + Status ModelHelper::ReleaseLocalModelData() noexcept { Status result = SUCCESS; if (model_addr_tmp_ != nullptr) { diff --git a/ge/common/helper/om_file_helper.cc b/ge/common/helper/om_file_helper.cc index ce88cd08..d1c52b13 100644 --- a/ge/common/helper/om_file_helper.cc +++ b/ge/common/helper/om_file_helper.cc @@ -52,6 +52,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(u return SUCCESS; } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data, + uint32_t model_data_size, + uint32_t model_num) { + Status status = LoadModelPartitionTable(model_data, model_data_size, model_num); + if (status != SUCCESS) { + return status; + } + is_inited_ = true; + return SUCCESS; +} + // Use both FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, ModelPartition &partition) { @@ -79,6 +90,37 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetMod return SUCCESS; } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, + ModelPartition &partition, + size_t model_index) { + if (!is_inited_) { + GELOGE(PARAM_INVALID, "OmFileLoadHelper has not been initialized!"); + return PARAM_INVALID; + } + if (model_index >= model_contexts_.size()) { + GELOGE(PARAM_INVALID, "cur index : %zu, model_contexts size:%zu", model_index, model_contexts_.size()); + return PARAM_INVALID; + } + auto &cur_ctx = model_contexts_[model_index]; + bool found = false; + for (ModelPartition &part : cur_ctx.partition_datas_) { + if (part.type == type) { + partition = part; + found = true; + break; + } + } + + if (!found) { + if (type != ModelPartitionType::TBE_KERNELS && type != ModelPartitionType::WEIGHTS_DATA && + type != ModelPartitionType::CUST_AICPU_KERNELS) { + GELOGE(FAILED, "GetModelPartition:type:%d is not in partition_datas!", static_cast(type)); + return FAILED; + } + } + return SUCCESS; +} + Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const { // Parameter validity check if (model.model_data == nullptr) { @@ -138,7 +180,8 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint context_.partition_datas_.push_back(partition); if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) { - GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.", + GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, + "The partition size %zu is greater than the model data size %u.", partition.size + mem_offset, model_data_size); return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID; } @@ -148,6 +191,61 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint return SUCCESS; } +Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, uint32_t model_data_size, uint32_t model_num) { + if (model_data == nullptr) { + GELOGE(PARAM_INVALID, "Param model_data must not be null!"); + return PARAM_INVALID; + } + + uint32_t cur_offset = 0; + for (uint32_t index = 0; index < model_num; ++index) { + // Init partition table + auto partition_table = reinterpret_cast(model_data + cur_offset); + size_t partition_table_size = SIZE_OF_MODEL_PARTITION_TABLE(*partition_table); + cur_offset += partition_table_size; + GELOGD("Cur model index %zu: ModelPartitionTable num :%u, " + "ModelFileHeader length :%zu, ModelPartitionTable length :%zu", + index, partition_table->num, sizeof(ModelFileHeader), partition_table_size); + if (model_data_size <= cur_offset) { + GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "invalid model data, partition_table->num:%u, model data size %u", + partition_table->num, model_data_size); + return GE_EXEC_MODEL_DATA_SIZE_INVALID; + } + + for (uint32_t i = 0; i < partition_table->num; i++) { + ModelPartition partition; + partition.size = partition_table->partition[i].mem_size; + partition.data = model_data + cur_offset; + partition.type = partition_table->partition[i].type; + if (index >= model_contexts_.size()) { + if (index != model_contexts_.size()) { + GELOGE(FAILED, "cur index is %zu make model_contexts_ overflow", index); + return FAILED; + } + + OmFileContext tmp_ctx; + tmp_ctx.partition_datas_.push_back(partition); + model_contexts_.push_back(tmp_ctx); + } else { + model_contexts_[index].partition_datas_.push_back(partition); + } + + if (partition.size > model_data_size || cur_offset > model_data_size - partition.size) { + GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.", + partition.size + cur_offset, model_data_size); + return GE_EXEC_MODEL_DATA_SIZE_INVALID; + } + cur_offset += partition.size; + GELOGD("Partition, type:%d, size:%u, model_index:%zu", static_cast(partition.type), partition.size, index); + } + } + if (cur_offset != model_data_size) { + GELOGE(FAILED, "do not get the complete model, read end offset:%zu, all size:%zu", cur_offset, model_data_size); + return FAILED; + } + return SUCCESS; +} + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::vector &OmFileSaveHelper::GetModelPartitions() const { return context_.partition_datas_; @@ -172,6 +270,28 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSave return partition_table; } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSaveHelper::GetPartitionTable( + size_t cur_ctx_index) { + auto &cur_ctx = model_contexts_[cur_ctx_index]; + auto partition_size = static_cast(cur_ctx.partition_datas_.size()); + // Build ModelPartitionTable, flex array + cur_ctx.partition_table_.clear(); + cur_ctx.partition_table_.resize(sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * partition_size, 0); + + auto partition_table = reinterpret_cast(cur_ctx.partition_table_.data()); + partition_table->num = partition_size; + + uint32_t mem_offset = 0; + for (uint32_t i = 0; i < partition_size; i++) { + ModelPartition partition = cur_ctx.partition_datas_[i]; + partition_table->partition[i] = {partition.type, mem_offset, partition.size}; + mem_offset += partition.size; + GELOGD("Partition, type:%d, size:%u", static_cast(partition.type), partition.size); + } + return partition_table; +} + + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPartition(ModelPartition &partition) { if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) { GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size); @@ -182,6 +302,27 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPar return SUCCESS; } +Status OmFileSaveHelper::AddPartition(ModelPartition &partition, size_t cur_index) { + if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) { + GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size); + return FAILED; + } + if (cur_index >= model_contexts_.size()) { + if (cur_index != model_contexts_.size()) { + GELOGE(FAILED, "cur index is %zu make model_contexts_ overflow", cur_index); + return FAILED; + } + OmFileContext tmp_ctx; + tmp_ctx.model_data_len_ += partition.size; + tmp_ctx.partition_datas_.push_back(partition); + model_contexts_.push_back(tmp_ctx); + } else { + model_contexts_[cur_index].model_data_len_ += partition.size; + model_contexts_[cur_index].partition_datas_.push_back(partition); + } + return SUCCESS; +} + Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model, bool is_offline) { (void)save_param.cert_file; @@ -198,6 +339,10 @@ Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *outp Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferData &model, bool is_offline) { #if !defined(NONSUPPORT_SAVE_TO_FILE) + if (context_.partition_datas_.empty()) { + GE_CHK_BOOL_EXEC(!model_contexts_.empty(), return FAILED, "mode contexts empty"); + context_ = model_contexts_.front(); + } uint32_t model_data_len = context_.model_data_len_; if (model_data_len == 0) { GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0"); @@ -231,4 +376,53 @@ Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferDat return SUCCESS; #endif } + +Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char *output_file, + ModelBufferData &model, bool is_offline) { + (void)save_param.cert_file; + (void)save_param.ek_file; + (void)save_param.encode_mode; + (void)save_param.hw_key_file; + (void)save_param.pri_key_file; + +#if !defined(NONSUPPORT_SAVE_TO_FILE) + vector model_partition_tabels; + vector> all_model_partitions; + for (size_t ctx_index = 0; ctx_index < model_contexts_.size(); ++ctx_index) { + auto &cur_ctx = model_contexts_[ctx_index]; + uint32_t cur_model_data_len = cur_ctx.model_data_len_; + if (cur_model_data_len == 0) { + GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0"); + return domi::PARAM_INVALID; + } + + auto tmp_table = GetPartitionTable(ctx_index); + if (tmp_table == nullptr) { + GELOGE(ge::GE_GRAPH_SAVE_FAILED, "SaveModelToFile execute failed: partition_table is NULL."); + return ge::GE_GRAPH_SAVE_FAILED; + } + uint32_t size_of_table = SIZE_OF_MODEL_PARTITION_TABLE(*tmp_table); + FMK_UINT32_ADDCHECK(size_of_table, cur_model_data_len) + FMK_UINT32_ADDCHECK(size_of_table + cur_model_data_len, model_header_.length) + model_header_.length += size_of_table + cur_model_data_len; + model_partition_tabels.push_back(tmp_table); + all_model_partitions.push_back(cur_ctx.partition_datas_); + GELOGD("sizeof(ModelPartitionTable):%u, cur_model_data_len:%u, cur_context_index:%zu", + size_of_table, cur_model_data_len, ctx_index); + } + Status ret; + if (is_offline) { + ret = FileSaver::SaveToFile(output_file, model_header_, model_partition_tabels, all_model_partitions); + } else { + GELOGW("do not support save ge root model to buff now"); + return FAILED; + } + if (ret == SUCCESS) { + GELOGD("Save model success without encrypt."); + } + return ret; +#else + return SUCCESS; +#endif +} } // namespace ge diff --git a/ge/common/kernel_store.cc b/ge/common/kernel_store.cc old mode 100755 new mode 100644 diff --git a/ge/common/kernel_store.h b/ge/common/kernel_store.h old mode 100755 new mode 100644 diff --git a/ge/common/math/fp16_math.cc b/ge/common/math/fp16_math.cc old mode 100755 new mode 100644 diff --git a/ge/common/math/fp16_math.h b/ge/common/math/fp16_math.h old mode 100755 new mode 100644 diff --git a/ge/common/math/math_util.h b/ge/common/math/math_util.h old mode 100755 new mode 100644 diff --git a/ge/common/math_util.h b/ge/common/math_util.h old mode 100755 new mode 100644 diff --git a/ge/common/model_parser/base.h b/ge/common/model_parser/base.h old mode 100755 new mode 100644 diff --git a/ge/common/model_saver.cc b/ge/common/model_saver.cc old mode 100755 new mode 100644 diff --git a/ge/common/module.mk b/ge/common/module.mk old mode 100755 new mode 100644 diff --git a/ge/common/op/ge_op_utils.cc b/ge/common/op/ge_op_utils.cc index 579190d6..fc2990b6 100644 --- a/ge/common/op/ge_op_utils.cc +++ b/ge/common/op/ge_op_utils.cc @@ -357,7 +357,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::TransDataHWCK2KCH const char *w_data = (const char *)input; int64_t count = h * w * c * k; - GE_IF_BOOL_EXEC(count <= 0, GELOGW("Count value must be greater than 0, but count = %ld", count); return ); + GE_IF_BOOL_EXEC(count <= 0, GELOGW("Count value must be greater than 0, but count = %ld", count); return); float *buf = new (std::nothrow) float[count](); GE_RT_VOID_CHECK_NOTNULL(buf); float *src_buff = nullptr; diff --git a/ge/common/profiling/ge_profiling.cc b/ge/common/profiling/ge_profiling.cc new file mode 100644 index 00000000..640f77a1 --- /dev/null +++ b/ge/common/profiling/ge_profiling.cc @@ -0,0 +1,199 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/profiling/ge_profiling.h" +#include "runtime/base.h" +#include "common/profiling/profiling_manager.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "graph/load/graph_loader.h" +#include "init/gelib.h" +#include "framework/common/ge_inner_error_codes.h" + +namespace { +const uint32_t kDeviceListIndex = 3; +const std::string kDeviceNums = "devNums"; +const std::string kDeviceIdList = "devIdList"; +const std::string kProfilingInit = "prof_init"; +const std::string kProfilingFinalize = "prof_finalize"; +const std::string kProfilingStart = "prof_start"; +const std::string kProfilingStop = "prof_stop"; +const std::string kProfModelSubscribe = "prof_model_subscribe"; +const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe"; +const std::string kRtSetDeviceRegName = "profiling"; + +const std::map kProfCommandTypeMap = { + {kProfCommandhandleInit, kProfilingInit}, + {kProfCommandhandleStart, kProfilingStart}, + {kProfCommandhandleStop, kProfilingStop}, + {kProfCommandhandleFinalize, kProfilingFinalize}, + {kProfCommandhandleModelSubscribe, kProfModelSubscribe}, + {kProfCommandhandleModelUnsubscribe, kProfModelUnsubscribe}}; +} // namespace + +bool TransProfConfigToParam(const ProfCommandHandleData &profCommand, vector &prof_config_params) { + prof_config_params.clear(); + prof_config_params.emplace_back(kDeviceNums); + prof_config_params.emplace_back(std::to_string(profCommand.devNums)); + prof_config_params.emplace_back(kDeviceIdList); + std::string devID = ""; + if (profCommand.devNums == 0) { + GELOGW("The device num is invalid."); + return false; + } + for (uint32_t i = 0; i < profCommand.devNums; i++) { + devID.append(std::to_string(profCommand.devIdList[i])); + if (i != profCommand.devNums - 1) { + devID.append(","); + } + } + + prof_config_params.push_back(devID); + return true; +} + +bool isProfConfigValid(const uint32_t *deviceid_list, uint32_t device_nums) { + if (deviceid_list == nullptr) { + GELOGE(ge::PARAM_INVALID, "deviceIdList is nullptr"); + return false; + } + if (device_nums == 0 || device_nums > MAX_DEV_NUM) { + GELOGE(ge::PARAM_INVALID, "The device nums: %u is invalid.", device_nums); + return false; + } + + // real device num + int32_t dev_count = 0; + rtError_t rt_err = rtGetDeviceCount(&dev_count); + if (rt_err != RT_ERROR_NONE) { + GELOGE(ge::INTERNAL_ERROR, "Get the Device count fail."); + return false; + } + + if (device_nums > static_cast(dev_count)) { + GELOGE(ge::PARAM_INVALID, "Device num(%u) is not in range 1 ~ %d.", device_nums, dev_count); + return false; + } + + std::unordered_set record; + for (size_t i = 0; i < device_nums; ++i) { + uint32_t dev_id = deviceid_list[i]; + if (dev_id >= static_cast(dev_count)) { + GELOGE(ge::PARAM_INVALID, "Device id %u is not in range 0 ~ %d(exclude %d)", dev_id, dev_count, dev_count); + return false; + } + if (record.count(dev_id) > 0) { + GELOGE(ge::PARAM_INVALID, "Device id %u is duplicatedly set", dev_id); + return false; + } + record.insert(dev_id); + } + return true; +} + +ge::Status RegProfCtrlCallback(MsprofCtrlCallback func) { + if (func == nullptr) { + GELOGE(ge::PARAM_INVALID, "Msprof ctrl callback is nullptr."); + return ge::PARAM_INVALID; + } + if (ge::ProfilingManager::Instance().GetMsprofCallback().msprofCtrlCallback != nullptr) { + GELOGW("Msprof ctrl callback is exist, just ignore it."); + } else { + GELOGI("GE register Msprof ctrl callback."); + ge::ProfilingManager::Instance().SetMsprofCtrlCallback(func); + } + return ge::SUCCESS; +} + +ge::Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func) { + if (func == nullptr) { + GELOGE(ge::PARAM_INVALID, "MsprofSetDeviceCallback callback is nullptr."); + return ge::PARAM_INVALID; + } + // Pass MsprofSetDeviceCallback to runtime + GELOGI("GE pass setdevice callback to runtime."); + ge::Status rt_ret = rtRegDeviceStateCallback(kRtSetDeviceRegName.c_str(), static_cast(func)); + if (rt_ret != ge::SUCCESS) { + GELOGE(rt_ret, "Pass MsprofSetDeviceCallback to runtime failed!"); + return rt_ret; + } + return ge::SUCCESS; +} + +ge::Status RegProfReporterCallback(MsprofReporterCallback func) { + if (func == nullptr) { + GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr."); + return ge::PARAM_INVALID; + } + if (ge::ProfilingManager::Instance().GetMsprofCallback().msprofReporterCallback != nullptr) { + GELOGW("Msprof reporter callback is exist, just ignore it."); + } else { + GELOGI("GE register Msprof reporter callback."); + ge::ProfilingManager::Instance().SetMsprofReporterCallback(func); + // Pass MsprofReporterCallback to runtime + ge::Status rt_ret = rtSetMsprofReporterCallback(func); + if (rt_ret != ge::SUCCESS) { + GELOGE(rt_ret, "Pass MsprofReporterCallback to runtime failed!!"); + return rt_ret; + } + // Pass MsprofReporterCallback to hccl + } + return ge::SUCCESS; +} + +ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len) { + if (type != kProfCommandhandleFinalize) { + GE_CHECK_NOTNULL(data); + } + ProfCommandHandleData *prof_config_param = (ProfCommandHandleData *)data; + auto iter = kProfCommandTypeMap.find(type); + if (iter == kProfCommandTypeMap.end()) { + GELOGW("The prof comand type is invalid."); + return ge::PARAM_INVALID; + } + std::vector prof_params; + if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) { + if (!isProfConfigValid(prof_config_param->devIdList, prof_config_param->devNums)) { + return ge::FAILED; + } + + if (!TransProfConfigToParam(*prof_config_param, prof_params)) { + GELOGE(ge::PARAM_INVALID, "Transfer profilerConfig to string vector failed"); + return ge::PARAM_INVALID; + } + } + ge::GraphLoader graph_loader; + ge::Command command; + command.cmd_params.clear(); + command.cmd_type = iter->second; + command.cmd_params = prof_params; + if (type != kProfCommandhandleFinalize) { + command.module_index = prof_config_param->profSwitch; + } + GELOGI("GE commandhandle execute, Command Type: %d, data type config: 0x%llx", type, command.module_index); + if (type == kProfCommandhandleStart || type == kProfCommandhandleStop) { + GELOGI("Profiling device nums:%s , deviceID:[%s]", prof_params[0].c_str(), prof_params[kDeviceListIndex].c_str()); + } + ge::Status ret = graph_loader.CommandHandle(command); + if (ret != ge::SUCCESS) { + GELOGE(ret, "Handle profiling command failed"); + return ge::FAILED; + } + + GELOGI("Successfully execute profiling command type: %d, command 0x%llx.", type, command.module_index); + return ge::SUCCESS; +} + diff --git a/ge/common/profiling/ge_runner_profiling.cc b/ge/common/profiling/ge_runner_profiling.cc new file mode 100644 index 00000000..067aafe3 --- /dev/null +++ b/ge/common/profiling/ge_runner_profiling.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/profiling/ge_runner_profiling.h" +#include "init/gelib.h" + +bool IsInitialize() { + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || instance_ptr->InitFlag() == false) { + return false; + } + return true; +} diff --git a/ge/common/profiling/profiling_manager.cc b/ge/common/profiling/profiling_manager.cc index 2f0f061f..456cb0a4 100644 --- a/ge/common/profiling/profiling_manager.cc +++ b/ge/common/profiling/profiling_manager.cc @@ -24,16 +24,9 @@ #include "graph/load/new_model_manager/davinci_model.h" namespace { -const char *const kJobID = "jobID"; -const char *const kDeviceID = "deviceID"; -const char *const kStartCfg = "startCfg"; -const char *const kFeatures = "features"; -const char *const kConf = "conf"; -const char *const kEvents = "events"; -const char *const kAiCoreEvents = "ai_core_events"; -const char *const kName = "name"; -const char *const kTraceID = "traceId"; -const char *const kProfDir = "resultPath"; +const char *const kTrainingTrace = "training_trace"; +const char *const kFpPoint = "fp_point"; +const char *const kBpPoint = "bp_point"; const size_t kReportMaxLen = 2048; const int32_t kMaxDeviceNum = 256; const std::string kConfigNumsdev = "devNums"; @@ -45,7 +38,13 @@ const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe"; } // namespace namespace ge { -ProfilingManager::ProfilingManager() : subscribe_count_(0) {} +ProfilingManager::ProfilingManager() : is_load_profiling_(false), + is_execute_profiling_(false), + is_training_trace_(false), + subscribe_count_(0) { + prof_cb_.msprofCtrlCallback = nullptr; + prof_cb_.msprofReporterCallback = nullptr; +} ProfilingManager::~ProfilingManager() {} @@ -58,44 +57,29 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In #ifdef DAVINCI_SUPPORT_PROFILING vector().swap(device_id_); subscribe_count_ = 0; - job_id_ = options.job_id; - - GELOGI("ProfilingManager::Init job_id:%s", job_id_.c_str()); - + GELOGI("ProfilingManager::Init job_id:%s", options.job_id.c_str()); - - Status ret; - if (!recv_profiling_config_.empty()) { - GELOGI("Profiling json config from acl:%s", recv_profiling_config_.c_str()); - ret = InitFromAclCfg(recv_profiling_config_); - } else { - ret = InitFromOptions(options); - if (ret == SUCCESS && is_load_profiling_) { - device_id_.push_back(options.device_id); - } - } + struct MsprofGeOptions prof_conf = {{ 0 }}; + Status ret = InitFromOptions(options, prof_conf); if (ret != SUCCESS) { GELOGE(ret, "Failed to init profiling."); return ret; } - if (is_load_profiling_) { - // register Framework to profiling - int result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_); - if (result != 0) { - GELOGE(FAILED, "Register profiling engine failed."); - return FAILED; + if (is_execute_profiling_) { + if (prof_cb_.msprofCtrlCallback == nullptr) { + GELOGE(ge::PARAM_INVALID, "MsprofCtrlCallback callback is nullptr."); + return ge::PARAM_INVALID; } - // profiling startup first time - GELOGI("Begin to init profiling, device num %zu", device_id_.size()); - for (size_t i = 0; i < device_id_.size(); ++i) { - ret = StartProfiling(0, device_id_[i]); - if (ret != SUCCESS) { - GELOGW("Profiling start failed on device %d.", device_id_[i]); - continue; - } - GELOGI("Profiling init succ on device %d.", device_id_[i]); + int32_t cb_ret = prof_cb_.msprofCtrlCallback( + static_cast(MsprofCtrlCallbackType::MSPROF_CTRL_INIT_GE_OPTIONS), + static_cast(&prof_conf), sizeof(MsprofGeOptions)); + if (cb_ret != 0) { + GELOGE(FAILED, "Call msprofCtrlCallback failed, type:%u, return:%d", + static_cast(MsprofCtrlCallbackType::MSPROF_CTRL_INIT_GE_OPTIONS), cb_ret); + return FAILED; } + GELOGI("Profiling init success"); } else { GELOGI("The profiling is off, skip the initialization"); } @@ -103,288 +87,116 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::InitFromAclCfg( - const std::string &config) { +ge::Status ProfilingManager::InitFromOptions(const Options &options, MsprofGeOptions &prof_conf) { #ifdef DAVINCI_SUPPORT_PROFILING - try { - is_load_profiling_ = false; - is_execute_profiling_ = false; - profiling_opts_.clear(); - op_trace_conf_.clear(); - Json start_prof_conf = Json::parse(config); - Json &prof_conf = start_prof_conf[kStartCfg][0]; - job_id_ = prof_conf[kJobID]; - auto iter = prof_conf.find(kProfDir); - if (iter != prof_conf.end()) { - prof_dir_ = prof_conf[kProfDir]; - } - Json &device_id = prof_conf[kDeviceID]; - if (device_id.size() != 0) { - vector().swap(device_id_); - bool is_all = false; - for (size_t i = 0; i < device_id.size(); i++) { - std::string device_id_str = device_id[i].get(); - if (device_id_str == "all") { - is_all = true; - break; - } - device_id_.push_back(std::stoi(device_id_str)); - } - if (is_all) { - int32_t count = 0; - rtError_t rt_err = rtGetDeviceCount(&count); - if (rt_err != RT_ERROR_NONE) { - GELOGE(FAILED, "Call rtGetDeviceCount to get device failed."); - } - - vector().swap(device_id_); - for (int32_t i = 0; i < count; ++i) { - device_id_.push_back(i); - } - } + // enable profiling by env + char env_profiling_mode[MMPA_MAX_PATH] = { 0x00 }; + is_load_profiling_ = false; // Change in ProfInit + is_execute_profiling_ = false; + + if (options.profiling_mode == "1" && !options.profiling_options.empty()) { + // enable profiling by ge option + if (memcpy_s(prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX, options.profiling_options.c_str(), + options.profiling_options.size()) != EOK) { + GELOGE(INTERNAL_ERROR, "copy profiling_options failed."); + return INTERNAL_ERROR; } - - Json &features = prof_conf[kFeatures]; - if (ParseFeaturesFromAclCfg(features) != SUCCESS) { - GELOGE(FAILED, "Parse feature from acl cfg failed."); - return FAILED; + is_execute_profiling_ = true; + GELOGI("The profiling in options is %s, %s. origin option: %s", options.profiling_mode.c_str(), + prof_conf.options, options.profiling_options.c_str()); + } else { + (void)mmGetEnv("PROFILING_MODE", env_profiling_mode, MMPA_MAX_PATH); + (void)mmGetEnv("PROFILING_OPTIONS", prof_conf.options, MSPROF_OPTIONS_DEF_LEN_MAX); + // The env is invalid + if ((strcmp("true", env_profiling_mode) != 0) || (strcmp(prof_conf.options, "\0") == 0)) { + return SUCCESS; } - is_load_profiling_ = true; + // enable profiling by env is_execute_profiling_ = true; - } catch (...) { - GELOGE(FAILED, "Json conf is not invalid !"); + GELOGI("The profiling in env is %s, %s", env_profiling_mode, prof_conf.options); + } + + if (!is_execute_profiling_) { + return SUCCESS; + } + + // Parse json str for bp fp + Status ret = ParseOptions(prof_conf.options); + if (ret != ge::SUCCESS) { + GELOGE(ge::PARAM_INVALID, "Parse training trace param failed."); return ge::PARAM_INVALID; } + + if (memcpy_s(prof_conf.jobId, sizeof(prof_conf.jobId), options.job_id.c_str(), + sizeof(options.job_id.c_str())) != EOK) { + GELOGE(INTERNAL_ERROR, "copy job_id failed."); + return INTERNAL_ERROR; + } #endif return ge::SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::ParseFeaturesFromAclCfg( - const Json &features) { -#ifdef DAVINCI_SUPPORT_PROFILING +ge::Status ProfilingManager::ParseOptions(const std::string &options) { + if (options.empty()) { + GELOGE(ge::PARAM_INVALID, "Profiling options is empty."); + return ge::PARAM_INVALID; + } try { - for (size_t i = 0; i < features.size(); ++i) { - const Json &feature = features[i]; - if ((feature.find(kName) == feature.end()) || feature[kName].is_null()) { - continue; - } - const std::string &name = feature[kName]; - if (name == "op_trace") { - const Json &conf = feature[kConf]; - const Json &events = conf[0][kEvents]; - const std::string &ai_core_events = events[0][kAiCoreEvents]; - GELOGI("Op trace config from acl ai_core_events:%s", ai_core_events.c_str()); - is_op_trace_ = true; - ProfMgrConf prof_mgr_conf; - int result = ProfMgrGetConf(ai_core_events, &prof_mgr_conf); - if (result != 0) { - GELOGE(FAILED, "ProfMgrGetConf failed."); - return FAILED; - } - op_trace_conf_ = prof_mgr_conf.conf; - op_trace_iter_num_ = static_cast(op_trace_conf_.size()); - GELOGI("Op trace profiling iter num %d,", op_trace_iter_num_); - } else if (name == "task_trace") { - is_op_trace_ = false; - if (feature.find(kConf) != feature.end()) { - const Json &conf = feature[kConf]; - std::stringstream task_trace_conf; - task_trace_conf << conf; - task_trace_conf_ = task_trace_conf.str(); - } - GELOGI("Task trace config from acl"); - } else if (name == "system_trace") { - is_op_trace_ = false; - const Json &conf = feature[kConf]; - std::stringstream system_trace_conf; - system_trace_conf << conf; - system_trace_conf_ = system_trace_conf.str(); - GELOGI("System trace config from acl"); - } - profiling_opts_.push_back(name); + Json prof_options = Json::parse(options); + const std::string training_trace = prof_options[kTrainingTrace]; + if (training_trace.empty()) { + GELOGI("Training trace will not take effect."); + return ge::SUCCESS; + } + GELOGI("GE profiling training trace:%s", training_trace.c_str()); + if (training_trace != "on") { + GELOGE(ge::PARAM_INVALID, "Training trace param:%s is invalid.", training_trace.c_str()); + return ge::PARAM_INVALID; + } + fp_point_ = prof_options[kFpPoint]; + bp_point_ = prof_options[kBpPoint]; + if (!fp_point_.empty() && !bp_point_.empty()) { + GELOGI("Training trace bp fp is set, bp_point:%s, fp_point:%s.", bp_point_.c_str(), fp_point_.c_str()); } } catch (...) { - GELOGE(ge::PARAM_INVALID, "Json conf feature is not invalid !"); + GELOGE(FAILED, "Json prof_conf options is invalid."); return ge::PARAM_INVALID; } -#endif return ge::SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::InitFromOptions(const Options &options) { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProfiling() { #ifdef DAVINCI_SUPPORT_PROFILING - // enable profiling support two ways: env and front end - char profiling_mode_temp[MMPA_MAX_PATH] = { 0x00 }; - char prof_options_temp[MMPA_MAX_PATH] = { 0x00 }; - (void)mmGetEnv("PROFILING_MODE", profiling_mode_temp, MMPA_MAX_PATH); - (void)mmGetEnv("PROFILING_OPTIONS", prof_options_temp, MMPA_MAX_PATH ); - const char *profiling_mode = profiling_mode_temp; - const char *prof_options = prof_options_temp; - if ((profiling_mode == nullptr) || (strcmp("true", profiling_mode) != 0) || (prof_options == nullptr)) { - is_load_profiling_ = false; - is_execute_profiling_ = false; - } else { - std::string prof_options_str = std::string(prof_options); - profiling_opts_ = StringUtils::Split(prof_options_str, ':'); - is_load_profiling_ = true; - is_execute_profiling_ = true; - GELOGI("The profiling in env is %s, %s", profiling_mode, prof_options); - } - if (!is_load_profiling_) { - const std::string enable_profiling = "1"; - if (options.profiling_mode != enable_profiling || options.profiling_options.empty()) { - is_load_profiling_ = false; - is_execute_profiling_ = false; - return SUCCESS; - } else { - profiling_opts_ = StringUtils::Split(options.profiling_options, ':'); - is_load_profiling_ = true; - is_execute_profiling_ = true; - GELOGI("The profiling in options is %s, %s", options.profiling_mode.c_str(), options.profiling_options.c_str()); - } - } - // features:'training_trace', 'task_trace' or 'op_trace' etc - if (!profiling_opts_.empty()) { - if (profiling_opts_[0] == "op_trace") { - is_op_trace_ = true; - // op trace get conf - ProfMgrConf prof_mgr_conf; - int result = ProfMgrGetConf("", &prof_mgr_conf); - if (result != 0) { - GELOGE(FAILED, "ProfMgrGetConf failed."); - return FAILED; - } - op_trace_conf_ = prof_mgr_conf.conf; - op_trace_iter_num_ = static_cast(op_trace_conf_.size()); - GELOGI("op trace profiling iter num %d,", op_trace_iter_num_); - } else { - is_op_trace_ = false; - op_trace_iter_num_ = 1; + uint64_t module = GetProfilingModule(); + // The following if case will not be executed in normal case, inc case of ProfStopProfiling is abnormal + int32_t device_num = static_cast(device_id_.size()); + if (device_num != 0) { + auto device_id_ptr = std::unique_ptr(new (std::nothrow) uint32_t[device_num]); + if (device_id_ptr == nullptr) { + GELOGE(FAILED, "Stop profiling: device id ptr is null."); + return; } - } -#endif - return ge::SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::StartProfiling(int32_t iter_num, - int32_t device_id) { -#ifdef DAVINCI_SUPPORT_PROFILING - if (!profiling_opts_.empty()) { - GELOGI("Start profiling index is %d", iter_num); - // current one docker only use one device - Json p_device; - - try { - // profiling need physical_device_id - p_device[kDeviceID] = std::to_string(device_id); - p_device[kJobID] = job_id_; - p_device[kTraceID] = std::to_string(GetContext().TraceId()); - if (!prof_dir_.empty()) { - p_device[kProfDir] = prof_dir_; - GELOGI("Prof dir: %s.", prof_dir_.c_str()); - } - - Json features; - if (is_op_trace_) { - Json f; - f[kName] = "op_trace"; - Json conf; - if (op_trace_conf_.size() <= static_cast(iter_num)) { - GELOGE(FAILED, "Op trace iter num is invalid!"); - return FAILED; - } - Json events; - events[0] = nlohmann::json::parse(op_trace_conf_[iter_num]); - conf[0][kEvents] = events; - f[kConf] = conf; - features[0] = f; - if (iter_num == 0) { - is_load_ = true; - } - } else { - for (std::vector::size_type i = 0; i < profiling_opts_.size(); i++) { - Json f; - if (profiling_opts_[i] == "system_trace") { - f[kConf] = nlohmann::json::parse(system_trace_conf_); - } else if (profiling_opts_[i] == "task_trace") { - if (!task_trace_conf_.empty()) { - f[kConf] = nlohmann::json::parse(task_trace_conf_); - } - } - f[kName] = profiling_opts_[i]; - features[i] = f; - } - is_load_ = true; - } - p_device[kFeatures] = features; - // only one device, but sProfMgrStartUp API require for device list - Json devices; - devices[0] = p_device; - - Json start_cfg; - start_cfg[kStartCfg] = devices; - - // convert json to string - std::stringstream ss; - ss << start_cfg; - send_profiling_config_ = ss.str(); - GELOGI("Profiling config %s\n", send_profiling_config_.c_str()); - } catch (...) { - GELOGE(FAILED, "Op trace json conf is not invalid !"); - return FAILED; + for (int32_t i = 0; i < device_num; i++) { + device_id_ptr[i] = static_cast(device_id_[i]); } - - // runtime startup for profiling - uint64_t module = GetProfilingModule(); - int32_t device_num = 1; - uint32_t device_id_rt = static_cast(device_id); - GE_CHK_RT_RET(rtProfilerStart(module, device_num, &device_id_rt)); - - // call profiling startup API - ProfMgrCfg prof_cfg = {send_profiling_config_}; - void *prof_handle = ProfMgrStartUp(&prof_cfg); - if (prof_handle == nullptr) { - GELOGW("ProfMgrStartUp failed on device %d ", device_id); - return FAILED; + rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get()); + if (rt_ret != RT_ERROR_NONE) { + GELOGW("Call rtProfilerStop failed, ret:%d", rt_ret); } - GELOGD("StartProfiling, prof_handle: %p", prof_handle); - prof_handle_vec_.push_back(prof_handle); } -#endif - return SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProfiling() { -#ifdef DAVINCI_SUPPORT_PROFILING - Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - if (reporter != nullptr) { - int ret = reporter->Flush(); - GELOGI("Report data end, ret is %d", ret); + + // stop profiling + if (prof_cb_.msprofCtrlCallback == nullptr) { + GELOGE(ge::PARAM_INVALID, "MsprofCtrlCallback callback is nullptr."); + return; } - uint64_t module = GetProfilingModule(); - int32_t device_num = static_cast(device_id_.size()); - auto device_id_ptr = std::unique_ptr(new (std::nothrow) uint32_t[device_num]); - if (device_id_ptr == nullptr) { - GELOGE(FAILED, "Stop profiling: device id ptr is null."); + int32_t cb_ret = prof_cb_.msprofCtrlCallback(static_cast(MsprofCtrlCallbackType::MSPROF_CTRL_FINALIZE), + nullptr, 0); + if (cb_ret != 0) { + GELOGW("call msprofCtrlCallback failed, type:%u, return:%d", + static_cast(MsprofCtrlCallbackType::MSPROF_CTRL_FINALIZE), cb_ret); return; } - for (int32_t i = 0; i < device_num; i++) { - device_id_ptr[i] = static_cast(device_id_[i]); - } - rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get()); - if (rt_ret != RT_ERROR_NONE) { - GELOGW("Call rtProfilerStop failed, ret:%d", rt_ret); - } - - for (size_t i = 0; i < prof_handle_vec_.size(); ++i) { - int result = ProfMgrStop(prof_handle_vec_[i]); - if (result != 0) { - GELOGW("ProfMgr stop return fail:%d, handle:%p", result, prof_handle_vec_[i]); - } - } - vector().swap(prof_handle_vec_); - is_load_ = false; - recv_profiling_config_ = ""; GELOGI("Stop Profiling success."); #endif } @@ -392,12 +204,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( uint32_t model_id, const std::vector &task_desc_info, const int32_t &device_id) { #ifdef DAVINCI_SUPPORT_PROFILING - Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - if (reporter == nullptr) { - GELOGI("Profiling report is nullptr!"); - return; - } - std::string data; for (const auto &task : task_desc_info) { std::string model_name = task.model_name; @@ -412,7 +218,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin .append(std::to_string(stream_id)).append(" ") .append(std::to_string(model_id)).append("\n")); - Msprof::Engine::ReporterData reporter_data{}; + ReporterData reporter_data{}; reporter_data.deviceId = device_id; reporter_data.data = (unsigned char *)data.c_str(); reporter_data.dataLen = data.size(); @@ -422,9 +228,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin return; } - ret = reporter->Report(&reporter_data); - if (ret != SUCCESS) { - GELOGE(ret, "Reporter data of task_desc_info fail!"); + int32_t cb_ret = CallMsprofReport(reporter_data); + if (cb_ret != 0) { + GELOGE(cb_ret, "Reporter data of task_desc_info failed, ret:%d", cb_ret); return; } } @@ -436,9 +242,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo( uint32_t model_id, const std::vector &compute_graph_desc_info, const int32_t &device_id) { #ifdef DAVINCI_SUPPORT_PROFILING - Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return;); - std::string data; for (const auto &graph : compute_graph_desc_info) { data.append("model_name:") @@ -493,64 +296,52 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin } data.append(" model_id:").append(std::to_string(model_id)); - data.append("\n"); - Msprof::Engine::ReporterData reporter_data{}; - Report(device_id, data, *reporter, reporter_data); - + GraphDescReport(device_id, data); data.clear(); } #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( - const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, - Msprof::Engine::ReporterData &reporter_data) { +void ProfilingManager::GraphDescReport(const int32_t &device_id, const string &data) { #ifdef DAVINCI_SUPPORT_PROFILING + ReporterData reporter_data{}; + int ret = -1; + int32_t cb_ret = -1; size_t index = data.size() / kReportMaxLen; if (index >= 1) { reporter_data.deviceId = device_id; - int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); + ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); for (size_t i = 0; i < index; ++i) { reporter_data.data = (unsigned char *)data.c_str() + kReportMaxLen * i; reporter_data.dataLen = kReportMaxLen; - ret = reporter.Report(&reporter_data); - GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;); + cb_ret = CallMsprofReport(reporter_data); + GE_IF_BOOL_EXEC(cb_ret != 0, GELOGE(cb_ret, "Reporter data of graph_desc_info failed, ret:%d", cb_ret); return;); } reporter_data.dataLen = data.size() - kReportMaxLen * index; if (reporter_data.dataLen != 0) { reporter_data.data = (unsigned char *)data.c_str() + kReportMaxLen * index; - ret = reporter.Report(&reporter_data); - GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;); + cb_ret = CallMsprofReport(reporter_data); + GE_IF_BOOL_EXEC(cb_ret != 0, GELOGE(cb_ret, "Reporter data of graph_desc_info failed, ret:%d", cb_ret); return;); } } else { reporter_data.deviceId = device_id; reporter_data.data = (unsigned char *)data.c_str(); reporter_data.dataLen = data.size(); - int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); + ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); - ret = reporter.Report(&reporter_data); - GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;); - } -#endif -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUnInit(const std::string &module) const { -#ifdef DAVINCI_SUPPORT_PROFILING - int ret = Msprof::Engine::UnInit(module); - if (ret != SUCCESS) { - GELOGE(ret, "profiling plugin uninit failed, ret:%d", ret); + cb_ret = CallMsprofReport(reporter_data); + GE_IF_BOOL_EXEC(cb_ret != 0, GELOGE(cb_ret, "Reporter data of graph_desc_info failed, ret:%d", cb_ret); return;); } #endif } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( uint32_t model_id, const std::vector &task_desc_info, - const std::vector &compute_graph_desc_info, - bool check_device) { + const std::vector &compute_graph_desc_info) { #ifdef DAVINCI_SUPPORT_PROFILING int32_t logic_device_id = 0; rtError_t rt_ret = rtGetDevice(&logic_device_id); @@ -559,13 +350,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportPr return; } GELOGD("current logic_device_id:%d", logic_device_id); - if (check_device) { - auto ret = std::find(device_id_.begin(), device_id_.end(), logic_device_id); - if (ret == device_id_.end()) { - GELOGE(FAILED, "get valid phy_device_id failed, profiling report failed."); - return; - } - } GELOGD("start ProfilingTaskDescInfo."); ProfilingTaskDescInfo(model_id, task_desc_info, logic_device_id); GELOGD("start ProfilingGraphDescInfo."); @@ -574,11 +358,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportPr #endif } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::SetProfilingConfig( - const std::string &profiling_cfg) { - recv_profiling_config_ = profiling_cfg; -} - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetProfilingModule() { uint64_t module = PROF_MODEL_EXECUTE_MASK | PROF_RUNTIME_API_MASK | @@ -594,9 +373,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetP return module; } -void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type, - uint32_t device_id, - uint64_t module) { +void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module) { #ifdef DAVINCI_SUPPORT_PROFILING if (prof_type == kProfModelSubscribe) { if (subs_dev_module_.find(device_id) != subs_dev_module_.end()) { @@ -608,9 +385,13 @@ void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type, subs_dev_module_[device_id] = dev_info; } } else if (prof_type == kProfModelUnsubscribe) { - if (subs_dev_module_.find(device_id) != subs_dev_module_.end()) { - if (subs_dev_module_[device_id].subscribe_count > 0) { - subs_dev_module_[device_id].subscribe_count--; + auto iter = subs_dev_module_.find(device_id); + if (iter != subs_dev_module_.end()) { + if (iter->second.subscribe_count > 0) { + iter->second.subscribe_count--; + } + if (iter->second.subscribe_count == 0) { + subs_dev_module_.erase(iter); } } } else { @@ -626,10 +407,11 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo uint64_t model_load_mask = module & PROF_MODEL_LOAD_MASK; if ((subscribe_count_ == 0) && (model_load_mask == PROF_MODEL_LOAD_MASK)) { // register framework to profiling - int32_t result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_); - if (result != SUCCESS) { - GELOGE(FAILED, "Register profiling engine failed."); - return FAILED; + // register Framework to profiling + int32_t cb_ret = PluginInit(); + if (cb_ret != 0) { + GELOGE(cb_ret, "profiling plugin init failed, ret:%d", cb_ret); + return cb_ret; } GELOGI("Prof subscribe: model load profiling on."); } @@ -647,7 +429,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo UpdateSubscribeDeviceModuleMap(kProfModelSubscribe, device[0], module); // Report profiling data - Status p_ret = davinci_model->ReportProfilingData(false); + Status p_ret = davinci_model->ReportProfilingData(); if (p_ret != SUCCESS) { GELOGE(p_ret, "Report profiling data failed."); return p_ret; @@ -672,6 +454,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo auto iter = subs_dev_module_.find(device[0]); if (iter != subs_dev_module_.end()) { if (subs_dev_module_[device[0]].subscribe_count == 1) { + // The same device_id, only stop at last time rtError_t rt_ret = rtProfilerStop(subs_dev_module_[device[0]].module, dev_num, device); if (rt_ret != RT_ERROR_NONE) { GELOGE(FAILED, "Runtime profiler stop failed."); @@ -679,15 +462,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfMo } } UpdateSubscribeDeviceModuleMap(kProfModelUnsubscribe, device[0], subs_dev_module_[device[0]].module); + } else { + GELOGE(FAILED, "The device_id:%u has not been subscribed, do not need to cancel.", device[0]); + return FAILED; } subscribe_count_--; if (subscribe_count_ == 0) { - int32_t ret = Msprof::Engine::UnInit(GE_PROFILING_MODULE); - if (ret != SUCCESS) { - GELOGE(ret, "Profiling plugin uninit failed, ret:%d", ret); - return ret; - } + // profiling plugin uninit at last subscription + PluginUnInit(); } #endif return SUCCESS; @@ -700,11 +483,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfIn if (model_load_mask == PROF_MODEL_LOAD_MASK) { // register Framework to profiling - int32_t result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_); - if (result != SUCCESS) { - GELOGE(FAILED, "Register profiling engine failed."); - return FAILED; + int32_t cb_ret = PluginInit(); + if (cb_ret != 0) { + GELOGE(cb_ret, "profiling plugin init failed, ret:%d", cb_ret); + return cb_ret; } + int32_t device_num = -1; rtError_t rt_ret = rtProfilerStart(model_load_mask, device_num, nullptr); if (rt_ret != RT_ERROR_NONE) { @@ -719,7 +503,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfIn if (training_trace_mask == PROF_TRAINING_TRACE_MASK) { is_training_trace_ = true; } - is_acl_api_mode_ = true; GELOGI("Prof init success."); #endif return SUCCESS; @@ -730,19 +513,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfFi std::lock_guard lock(mutex_); is_load_profiling_ = false; is_training_trace_ = false; - is_acl_api_mode_ = false; + is_execute_profiling_ = false; + + // profiling plugin uninit + PluginUnInit(); - int32_t ret = Msprof::Engine::UnInit(GE_PROFILING_MODULE); - if (ret != SUCCESS) { - GELOGE(ret, "Profiling plugin uninit failed, ret:%d", ret); - } int32_t dev_num = -1; rtError_t rt_ret = rtProfilerStop(PROF_MODEL_LOAD_MASK, dev_num, nullptr); if (rt_ret != RT_ERROR_NONE) { GELOGE(FAILED, "Runtime profiler stop failed."); return FAILED; } - for (auto device_id_module : device_id_module_map_) { if (device_id_module.second != 0) { uint32_t device_id = static_cast(device_id_module.first); @@ -792,6 +573,7 @@ Status ProfilingManager::ProfParseDeviceId(const std::map return FAILED; } catch (std::out_of_range &) { GELOGE(FAILED, "Device num: %s is out of range.", iter->second.c_str()); + return FAILED; } catch (...) { GELOGE(FAILED, "Device num: %s cannot change to int.", iter->second.c_str()); return FAILED; @@ -859,7 +642,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt for (int32_t i = 0; i < device_num; i++) { device_id_ptr[i] = static_cast(device_list[i]); } - GELOGD("Runtime config param: 0x%llx, device num: %d.", module, device_num); + GELOGI("Runtime config param: 0x%llx, device num: %d.", module, device_num); rtError_t rt_ret = rtProfilerStart(module, device_num, device_id_ptr.get()); if (rt_ret != RT_ERROR_NONE) { @@ -878,7 +661,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt GELOGW("Prof start: load model module is invalid."); } UpdateDeviceIdModuleMap(kProfStart, module, device_list); - GELOGD("Prof start profiling success."); + GELOGI("Prof start profiling success."); #endif return SUCCESS; } @@ -901,7 +684,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt for (int32_t i = 0; i < device_num; i++) { device_id_ptr[i] = static_cast(device_list[i]); } - GELOGD("Prof stop: runtime config param: 0x%llx, device num: %d", module, device_num); + GELOGI("Prof stop: runtime config param: 0x%llx, device num: %d", module, device_num); rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get()); if (rt_ret != RT_ERROR_NONE) { GELOGE(FAILED, "Prof stop: runtime profiler config proc failed."); @@ -921,7 +704,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfSt GELOGW("Prof stop: load model module is invalid."); } UpdateDeviceIdModuleMap(kProfStop, module, device_list); - GELOGD("Prof stop profiling success."); + GELOGI("Prof stop profiling success."); #endif return SUCCESS; } @@ -963,47 +746,90 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ProfilingManager::Profilin if (rt_ret != RT_ERROR_NONE) { GELOGE(rt_ret, "Runtime get logic_device_id failed, current logic_device_id:%d", logic_device_id); } - GELOGD("Current logic_device_id:%d", logic_device_id); + GELOGI("Current logic_device_id:%d", logic_device_id); bool execute_model_prof_on = false; auto iter = std::find(device_id_.begin(), device_id_.end(), logic_device_id); if (iter != device_id_.end()) { execute_model_prof_on = true; } - GELOGD("Flag is_execute_profiling: %d, execute_model_prof_on: %d", is_execute_profiling_, execute_model_prof_on); - return is_execute_profiling_ || execute_model_prof_on; + GELOGI("Flag is_execute_profiling: %d, execute_model_prof_on: %d", is_execute_profiling_, execute_model_prof_on); + return execute_model_prof_on; } -/** - * @brief Profiling PluginImpl - */ -// PluginImpl static variable init -Msprof::Engine::Reporter *PluginImpl::reporter_ = nullptr; - -PluginImpl::PluginImpl(const std::string &module) : module_(module) { GELOGI("Create PluginImpl\n"); } - -int PluginImpl::Init(const Msprof::Engine::Reporter *reporter) { - GELOGI("PluginImpl init"); - reporter_ = const_cast(reporter); - return 0; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::PluginInit() const { + if (prof_cb_.msprofReporterCallback == nullptr) { + GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr."); + return ge::PARAM_INVALID; + } + return prof_cb_.msprofReporterCallback( + static_cast(MsprofReporterModuleId::MSPROF_MODULE_FRAMEWORK), + static_cast(MsprofReporterCallbackType::MSPROF_REPORTER_INIT), + nullptr, 0); } -int PluginImpl::UnInit() { - GELOGI("PluginImpl Uninit"); - reporter_ = nullptr; - return 0; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUnInit() const { +#ifdef DAVINCI_SUPPORT_PROFILING + if (prof_cb_.msprofReporterCallback == nullptr) { + GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr."); + return; + } + int32_t cb_ret = prof_cb_.msprofReporterCallback( + static_cast(MsprofReporterModuleId::MSPROF_MODULE_FRAMEWORK), + static_cast(MsprofReporterCallbackType::MSPROF_REPORTER_UNINIT), + nullptr, 0); + if (cb_ret != 0) { + GELOGW("profiling plugin uninit failed, ret:%d", cb_ret); + } +#endif } -Msprof::Engine::PluginIntf *ProfilingEngineImpl::CreatePlugin() { - GELOGI(" Create Plugin"); - return new (std::nothrow) PluginImpl(GE_PROFILING_MODULE); +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::CallMsprofReport( + ReporterData &reporter_data) const { + if (prof_cb_.msprofReporterCallback == nullptr) { + GELOGE(ge::PARAM_INVALID, "MsprofReporterCallback callback is nullptr."); + return ge::PARAM_INVALID; + } + return prof_cb_.msprofReporterCallback( + static_cast(MsprofReporterModuleId::MSPROF_MODULE_FRAMEWORK), + static_cast(MsprofReporterCallbackType::MSPROF_REPORTER_REPORT), + static_cast(&reporter_data), sizeof(ReporterData)); } -int ProfilingEngineImpl::ReleasePlugin(Msprof::Engine::PluginIntf *plugin) { - if (plugin != nullptr) { - delete plugin; - plugin = nullptr; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpPoint( + std::string &fp_point, std::string &bp_point) { + // Env or options mode, fp_point_/bp_point_ have initiliazed on profiling init + if (!fp_point_.empty() && !bp_point_.empty()) { + fp_point = fp_point_; + bp_point = bp_point_; + GELOGI("Bp Fp have been initialized in env or options. bp_point: %s, fp_point: %s", bp_point.c_str(), fp_point.c_str()); + return; + } + // ProfApi mode and training trace is set + try { + char env_profiling_options[MSPROF_OPTIONS_DEF_LEN_MAX] = { 0x00 }; + INT32 ret = mmGetEnv("PROFILING_OPTIONS", env_profiling_options, MSPROF_OPTIONS_DEF_LEN_MAX); + if (ret != EN_OK) { + GELOGI("PROFILING_OPTIONS env is not exist."); + return; + } + GELOGI("Parse env PROFILING_OPTIONS:%s.", env_profiling_options); + Json prof_options = Json::parse(env_profiling_options); + + fp_point_ = prof_options[kFpPoint]; + bp_point_ = prof_options[kBpPoint]; + + fp_point = fp_point_; + bp_point = bp_point_; + if (!fp_point_.empty() && !bp_point_.empty()) { + GELOGI("Training trace bp fp is set, bp_point:%s, fp_point:%s.", bp_point_.c_str(), fp_point_.c_str()); + } + } catch (...) { + GELOGE(FAILED, "Json prof options is invalid."); + return; } - return 0; + return; } + + } // namespace ge diff --git a/ge/common/profiling/profiling_manager.h b/ge/common/profiling/profiling_manager.h old mode 100755 new mode 100644 index 66cefc32..5fa4fac4 --- a/ge/common/profiling/profiling_manager.h +++ b/ge/common/profiling/profiling_manager.h @@ -26,9 +26,7 @@ #include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_types.h" #include "external/register/register_types.h" -#include "toolchain/prof_engine.h" -#include "toolchain/prof_mgr_core.h" -#include "toolchain/prof_acl_api.h" +#include "toolchain/prof_callback.h" using std::map; using std::string; @@ -37,35 +35,33 @@ using Json = nlohmann::json; namespace { const std::string GE_PROFILING_MODULE = "Framework"; + // DataTypeConfig MASK + #define PROF_ACL_API_MASK 0x0001 + #define PROF_TASK_TIME_MASK 0x0002 + #define PROF_AICORE_METRICS_MASK 0x0004 + #define PROF_AICPU_TRACE_MASK 0x0008 + #define PROF_MODEL_EXECUTE_MASK 0x0010 + #define PROF_RUNTIME_API_MASK 0x0020 + #define PROF_RUNTIME_TRACE_MASK 0x0040 + #define PROF_SCHEDULE_TIMELINE_MASK 0x0080 + #define PROF_SCHEDULE_TRACE_MASK 0x0100 + #define PROF_AIVECTORCORE_METRICS_MASK 0x0200 + #define PROF_SUBTASK_TIME_MASK 0x0400 + #define PROF_TRAINING_TRACE_MASK 0x0800 + #define PROF_HCCL_TRACE_MASK 0x1000 + #define PROF_DATA_PROCESS_MASK 0x2000 + #define PROF_MODEL_LOAD_MASK 0x8000000000000000 + } // namespace namespace ge { struct DeviceSubsInfo { uint64_t module; uint32_t subscribe_count; }; -// register Plugin -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PluginImpl : public Msprof::Engine::PluginIntf { - public: - explicit PluginImpl(const std::string &module); - ~PluginImpl() {} - - int Init(const Msprof::Engine::Reporter *reporter); - int UnInit(); - static Msprof::Engine::Reporter *GetPluginReporter() { return reporter_; } - private: - static Msprof::Engine::Reporter *reporter_; - std::string module_; -}; - -// register Engine -class ProfilingEngineImpl : public Msprof::Engine::EngineIntf { - public: - ProfilingEngineImpl() {} - ~ProfilingEngineImpl() {} - - Msprof::Engine::PluginIntf *CreatePlugin(); - int ReleasePlugin(Msprof::Engine::PluginIntf *plugin); +struct MsprofCallback { + MsprofCtrlCallback msprofCtrlCallback; + MsprofReporterCallback msprofReporterCallback; }; class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { @@ -73,68 +69,54 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { ProfilingManager(); virtual ~ProfilingManager(); static ProfilingManager &Instance(); - ge::Status Init(const Options &options); - ge::Status InitFromOptions(const Options &options); - ge::Status InitFromAclCfg(const std::string &config); - ge::Status StartProfiling(int32_t iter, int32_t device_id); - void UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module); - ge::Status ProfModelSubscribe(uint64_t module, void *model); - ge::Status ProfModelUnsubscribe(void *model); - ge::Status ProfInit(uint64_t module); - ge::Status ProfFinalize(); - ge::Status ProfStartProfiling(uint64_t module, const std::map &config_para); - ge::Status ProfStopProfiling(uint64_t module, const std::map &config_para); + Status Init(const Options &options); + Status ProfInit(uint64_t module); + Status ProfFinalize(); + Status ProfStartProfiling(uint64_t module, const std::map &config_para); + Status ProfStopProfiling(uint64_t module, const std::map &config_para); + Status ProfModelSubscribe(uint64_t module, void *model); + Status ProfModelUnsubscribe(void *model); void StopProfiling(); - bool ProfilingOpTraceOn() const { return is_op_trace_; } - bool ProfilingLoadFlag() const { return is_load_; } bool ProfilingTrainingTraceOn() const { return is_training_trace_; } bool ProfilingModelLoadOn() const { return is_load_profiling_; } bool ProfilingModelExecuteOn() const; - bool ProfilingOn() const { return is_load_profiling_ && is_execute_profiling_; } // only used by command pattern - bool IsAclApiMode() const { return is_acl_api_mode_; } - int32_t GetOpTraceIterNum() const { return op_trace_iter_num_; } + bool ProfilingOn() const { return is_load_profiling_ && is_execute_profiling_; } // is_execute_profiling_ only used by ge option and env void ReportProfilingData(uint32_t model_id, const std::vector &task_desc_info, - const std::vector &compute_graph_desc_info, - bool check_device); - void Report(const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, - Msprof::Engine::ReporterData &reporter_data); + const std::vector &compute_graph_desc_info); void ProfilingTaskDescInfo(uint32_t model_id, const std::vector &task_desc_info, const int32_t &device_id); void ProfilingGraphDescInfo(uint32_t model_id, const std::vector &compute_graph_desc_info, const int32_t &device_id); - void SetProfilingConfig(const string &profiling_cfg); - vector GetProfilingDeviceId() const { return device_id_; } - void PluginUnInit(const std::string &module) const; + Status PluginInit() const; + void PluginUnInit() const; + Status CallMsprofReport(ReporterData &reporter_data) const; + struct MsprofCallback &GetMsprofCallback() { return prof_cb_; } + void SetMsprofCtrlCallback(MsprofCtrlCallback func) { prof_cb_.msprofCtrlCallback = func; } + void SetMsprofReporterCallback(MsprofReporterCallback func) { prof_cb_.msprofReporterCallback = func; } + void GetFpBpPoint(std::string &fp_point, std::string &bp_point); private: - ge::Status ParseFeaturesFromAclCfg(const Json &feature); - ge::Status ProfParseParam(const std::map &config_para, int32_t &device_num, - vector &device_list); - ge::Status ProfParseDeviceId(const std::map &config_para, + Status InitFromOptions(const Options &options, MsprofGeOptions &prof_conf); + Status ParseOptions(const std::string &options); + Status ProfParseParam(const std::map &config_para, int32_t &device_num, + vector &device_list); + Status ProfParseDeviceId(const std::map &config_para, vector &device_list); uint64_t GetProfilingModule(); + void GraphDescReport(const int32_t &device_id, const string &data); void UpdateDeviceIdModuleMap(string prof_type, uint64_t module, const vector &device_list); - bool is_load_profiling_ = false; - bool is_execute_profiling_ = false; - bool is_op_trace_ = false; - bool is_load_ = false; - bool is_training_trace_ = false; - bool is_acl_api_mode_ = false; - int32_t op_trace_iter_num_ = 0; - string job_id_; - string prof_dir_; + void UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module); + + bool is_load_profiling_; + bool is_execute_profiling_; + bool is_training_trace_; vector device_id_; - vector op_trace_conf_; - vector profiling_opts_; - vector prof_handle_vec_; - string recv_profiling_config_; - string send_profiling_config_; - string system_trace_conf_; - string task_trace_conf_; - const ProfilingEngineImpl engine_; map device_id_module_map_; // key: device_id, value: profiling on module map subs_dev_module_; // key: device_id, value: profiling on module uint32_t subscribe_count_; std::mutex mutex_; + MsprofCallback prof_cb_; + std::string fp_point_; + std::string bp_point_; }; } // namespace ge #endif // GE_COMMON_PROFILING_PROFILING_MANAGER_H_ diff --git a/ge/common/singleton.h b/ge/common/singleton.h old mode 100755 new mode 100644 diff --git a/ge/common/tbe_kernel_store.cc b/ge/common/tbe_kernel_store.cc old mode 100755 new mode 100644 diff --git a/ge/common/tbe_kernel_store.h b/ge/common/tbe_kernel_store.h old mode 100755 new mode 100644 diff --git a/ge/common/thread_pool.h b/ge/common/thread_pool.h old mode 100755 new mode 100644 diff --git a/ge/common/types.cc b/ge/common/types.cc index 54dc769f..1cc70347 100644 --- a/ge/common/types.cc +++ b/ge/common/types.cc @@ -801,7 +801,7 @@ const uint32_t XRGB_CHN_NUM = 4; /// const bool DEFAULT_GLOBAL_POOLING = false; -const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0/// +const uint32_t MODEL_VERSION = 0x20000000; ///< Model version 2.0/// // Eltwise's input size const int ELTWISE_MIN_INPUT_SIZE = 2; diff --git a/ge/common/util.cc b/ge/common/util.cc index 480be3c1..0a343a83 100644 --- a/ge/common/util.cc +++ b/ge/common/util.cc @@ -51,14 +51,15 @@ namespace { * If such an exception is encountered during operation, * the proto file can be divided into several small files or the limit value can be increased. */ -const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. -const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M +const int kFileSizeOutLimitedOrOpenFailed = -1; +const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. +const int kWarningThreshold = 1073741824; // 536870912 * 2 536870912 represent 512M /// The maximum length of the file. -const uint32_t kMaxFileSizeLimit = UINT32_MAX; // 4G for now +const uint32_t kMaxFileSizeLimit = UINT32_MAX; // 4G for now const int kMaxBuffSize = 256; const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; -constexpr uint32_t kMaxConfigFileByte = 10 * 1024 * 1024; +constexpr uint32_t kMaxConfigFileByte = 10485760; // 10 * 1024 * 1024 } // namespace namespace ge { @@ -76,7 +77,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co std::string real_path = RealPath(file); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "pb file path '%s' not valid", file); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == kFileSizeOutLimitedOrOpenFailed, return false, + "file size not valid."); std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); if (!fs.is_open()) { @@ -118,20 +120,20 @@ long GetFileLength(const std::string &input_file) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); unsigned long long file_length = 0; GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, - ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)}); - return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno)); + mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)}); + return kFileSizeOutLimitedOrOpenFailed, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno)); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); return -1, "File[%s] size is 0, not valid.", input_file.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit, - ErrorManager::GetInstance().ATCReportErrMessage( - "E19016", {"filepath", "filesize", "maxlen"}, - {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); - return -1, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length, - kMaxFileSizeLimit); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + file_length > kMaxFileSizeLimit, ErrorManager::GetInstance().ATCReportErrMessage( + "E19016", {"filepath", "filesize", "maxlen"}, + {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); + return kFileSizeOutLimitedOrOpenFailed, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length, + kMaxFileSizeLimit); return static_cast(file_length); } @@ -187,7 +189,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co std::streamsize size = file.tellg(); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size <= 0), file.close(); return false, "file length <= 0, not valid."); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > static_cast(kMaxFileSizeLimit), file.close(); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > static_cast(kMaxFileSizeLimit), file.close(); return false, "file size %ld is out of limit: %d.", size, kMaxFileSizeLimit); file.seekg(0, std::ios::beg); // [no need to check value] @@ -210,8 +212,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); auto dir_path_len = directory_path.length(); if (dir_path_len >= MMPA_MAX_PATH) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19002", {"filepath", "size"}, {directory_path, std::to_string(MMPA_MAX_PATH)}); + ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, + {directory_path, std::to_string(MMPA_MAX_PATH)}); GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), MMPA_MAX_PATH); return -1; } @@ -224,8 +226,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: if (ret != 0) { if (errno != EEXIST) { ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); - GELOGW("Can not create directory %s. Make sure the directory exists and writable.", - directory_path.c_str()); + GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); return ret; } } @@ -265,7 +266,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch std::string real_path = RealPath(file); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage( - "E19000", {"path", "errmsg"}, {file, strerror(errno)}); + "E19000", {"path", "errmsg"}, {file, strerror(errno)}); return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno)); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); @@ -301,13 +302,13 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const cha google::protobuf::io::IstreamInputStream input(&fs); bool ret = google::protobuf::TextFormat::Parse(&input, message); GE_IF_BOOL_EXEC( - !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file.")); + !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file.")); return ret; } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { - mmTimeval tv {}; + mmTimeval tv{}; int ret = mmGetTimeOfDay(&tv, nullptr); GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret); auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds @@ -315,7 +316,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t GetCurrentSecondTimestap() { - mmTimeval tv {}; + mmTimeval tv{}; int ret = mmGetTimeOfDay(&tv, nullptr); GE_LOGE_IF(ret != EN_OK, "Func gettimeofday may failed: ret=%d", ret); auto total_use_time = tv.tv_sec; // seconds @@ -350,8 +351,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int6 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= MMPA_MAX_PATH, - ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(MMPA_MAX_PATH)}); - return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH); + ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, + {path, std::to_string(MMPA_MAX_PATH)}); + return "", "Path[%s] len is too long, it must be less than %d", path, MMPA_MAX_PATH); // Nullptr is returned when the path does not exist or there is no permission // Return absolute path when path is accessible @@ -385,16 +387,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) #ifdef __GNUC__ - std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; + std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; #else - std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$"; + std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$"; #endif GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - !ValidateStr(real_path, mode), - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {atc_param, real_path, kPathValidReason}); - return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); + !ValidateStr(real_path, mode), + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {atc_param, real_path, kPathValidReason}); + return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); // The absolute path points to a file that is not readable if (mmAccess2(real_path.c_str(), M_R_OK) != EN_OK) { @@ -416,24 +418,25 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const } GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path.c_str()) >= MMPA_MAX_PATH, - ErrorManager::GetInstance().ATCReportErrMessage( - "E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)}); - return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(), MMPA_MAX_PATH); + ErrorManager::GetInstance().ATCReportErrMessage( + "E19002", {"filepath", "size"}, {file_path, std::to_string(MMPA_MAX_PATH)}); + return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(), + MMPA_MAX_PATH); // A regular matching expression to verify the validity of the input file path // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) #ifdef __GNUC__ - std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; + std::string mode = "^[\u4e00-\u9fa5A-Za-z0-9./_-]+$"; #else - std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$"; + std::string mode = "^[a-zA-Z]:([\\\\/][^\\s\\\\/:*?<>\"|][^\\\\/:*?<>\"|]*)*([/\\\\][^\\s\\\\/:*?<>\"|])?$"; #endif GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - !ValidateStr(file_path, mode), - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {atc_param, file_path, kPathValidReason}); - return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason); + !ValidateStr(file_path, mode), + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {atc_param, file_path, kPathValidReason}); + return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason); std::string real_path = RealPath(file_path.c_str()); // Can get absolute path (file exists) diff --git a/ge/engine_manager/dnnengine_manager.h b/ge/engine_manager/dnnengine_manager.h old mode 100755 new mode 100644 diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt index d7dfdc84..d59afd03 100644 --- a/ge/executor/CMakeLists.txt +++ b/ge/executor/CMakeLists.txt @@ -17,6 +17,7 @@ set(SRC_LIST "../common/dump/dump_properties.cc" "../common/dump/dump_manager.cc" "../common/dump/dump_op.cc" + "../common/profiling/ge_profiling.cc" "../graph/load/graph_loader.cc" "../graph/execute/graph_execute.cc" "../omm/csa_interact.cc" @@ -172,6 +173,7 @@ target_compile_definitions(ge_executor PRIVATE google=ascend_private $,OS_TYPE=WIN,OS_TYPE=0> $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> + LOG_CPP ) target_include_directories(ge_executor PRIVATE @@ -244,7 +246,6 @@ target_link_libraries(ge_executor_shared PRIVATE mmpa graph register - msprof error_manager ascend_hal_stub ascend_protobuf diff --git a/ge/executor/ge_executor.cc b/ge/executor/ge_executor.cc old mode 100755 new mode 100644 index d03a8d7b..57ab7800 --- a/ge/executor/ge_executor.cc +++ b/ge/executor/ge_executor.cc @@ -283,7 +283,8 @@ Status GeExecutor::Initialize() { // Start profiling Options profiling_options; profiling_options.device_id = 0; - profiling_options.job_id = ""; + // job id need to be set, the value is meaningless; + profiling_options.job_id = "1"; ProfilingManager::Instance().Init(profiling_options); isInit_ = true; @@ -303,7 +304,7 @@ Status GeExecutor::Finalize() { // Stop profiling if (ProfilingManager::Instance().ProfilingOn()) { ProfilingManager::Instance().StopProfiling(); - ProfilingManager::Instance().PluginUnInit(GE_PROFILING_MODULE); + ProfilingManager::Instance().PluginUnInit(); } GELOGI("Uninit GeExecutor over."); @@ -638,7 +639,8 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { return ACL_ERROR_GE_INTERNAL_ERROR; } - std::shared_ptr hybrid_davinci_model = ModelManager::GetInstance()->GetHybridModel(model_id); + std::shared_ptr hybrid_davinci_model = + ModelManager::GetInstance()->GetHybridModel(model_id); if (hybrid_davinci_model != nullptr) { uint64_t session_id = hybrid_davinci_model->GetSessionId(); VarManagerPool::Instance().RemoveVarManager(session_id); diff --git a/ge/executor/module.mk b/ge/executor/module.mk index 9566ca64..34c2a37e 100644 --- a/ge/executor/module.mk +++ b/ge/executor/module.mk @@ -8,6 +8,7 @@ local_ge_executor_src_files := \ ../common/dump/dump_op.cc \ ../common/ge/plugin_manager.cc \ ../common/ge/op_tiling_manager.cc \ + ../common/profiling/ge_profiling.cc \ ../graph/load/graph_loader.cc \ ../graph/execute/graph_execute.cc \ ../omm/csa_interact.cc \ @@ -177,7 +178,6 @@ local_ge_executor_shared_library := \ libmmpa \ libgraph \ libregister \ - libmsprof \ liberror_manager \ local_ge_executor_ldflags := -lrt -ldl \ @@ -234,7 +234,6 @@ LOCAL_SHARED_LIBRARIES := \ libmmpa \ libgraph \ libregister \ - libmsprof \ liberror_manager \ stub/libascend_hal \ @@ -272,7 +271,6 @@ LOCAL_SHARED_LIBRARIES := \ libruntime \ libslog \ libmmpa \ - libmsprof \ LOCAL_LDFLAGS += $(local_ge_executor_ldflags) @@ -304,7 +302,6 @@ LOCAL_SHARED_LIBRARIES := \ libruntime \ libslog \ libmmpa \ - libmsprof \ ifeq ($(device_os),android) LOCAL_LDFLAGS += -ldl diff --git a/ge/ge_inference.mk b/ge/ge_inference.mk old mode 100755 new mode 100644 index 0987f148..bfb612ea --- a/ge/ge_inference.mk +++ b/ge/ge_inference.mk @@ -109,6 +109,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/atomic_addr_clean_pass.cc \ graph/passes/mark_same_addr_pass.cc \ graph/passes/mark_graph_unknown_status_pass.cc \ + graph/passes/dynamic_single_op_reset_shape_pass.cc \ graph/passes/mark_agnostic_pass.cc \ graph/common/omg_util.cc \ graph/common/bcast.cc \ @@ -164,6 +165,7 @@ OMG_HOST_SRC_FILES := \ host_kernels/slice_d_kernel.cc \ host_kernels/dynamic_stitch_kernel.cc \ host_kernels/identity_kernel.cc \ + host_kernels/reformat_kernel.cc \ graph/passes/stop_gradient_pass.cc \ graph/passes/prevent_gradient_pass.cc \ graph/passes/identity_pass.cc \ diff --git a/ge/ge_local_engine/CMakeLists.txt b/ge/ge_local_engine/CMakeLists.txt index 615a968f..8f5c9777 100755 --- a/ge/ge_local_engine/CMakeLists.txt +++ b/ge/ge_local_engine/CMakeLists.txt @@ -203,6 +203,7 @@ target_compile_options(ge_local_opskernel_builder_static PRIVATE target_compile_definitions(ge_local_opskernel_builder_static PRIVATE google=ascend_private + LOG_CPP ) target_include_directories(ge_local_opskernel_builder_static PRIVATE diff --git a/ge/ge_local_engine/engine/ge_local_engine.cc b/ge/ge_local_engine/engine/ge_local_engine.cc old mode 100755 new mode 100644 diff --git a/ge/ge_local_engine/engine/host_cpu_engine.cc b/ge/ge_local_engine/engine/host_cpu_engine.cc old mode 100755 new mode 100644 index f1e152f4..c836d4d6 --- a/ge/ge_local_engine/engine/host_cpu_engine.cc +++ b/ge/ge_local_engine/engine/host_cpu_engine.cc @@ -14,7 +14,6 @@ * limitations under the License. */ #include "host_cpu_engine.h" -#include #include "graph/common/omg_util.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_adapter.h" @@ -96,8 +95,8 @@ Status GetDataNumber(const GeTensorDesc &out_desc, uint64_t &data_num) { void HostCpuEngine::CloseSo() { for (auto handle : lib_handles_) { - if (dlclose(handle) != 0) { - GELOGW("failed to close handle, message: %s", dlerror()); + if (mmDlclose(handle) != 0) { + GELOGW("failed to close handle, message: %s", mmDlerror()); } } lib_handles_.clear(); @@ -323,13 +322,13 @@ Status HostCpuEngine::LoadLibs(std::vector &lib_paths) { Status HostCpuEngine::LoadLib(const std::string &lib_path) { GELOGI("To invoke dlopen on lib: %s", lib_path.c_str()); - auto handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_GLOBAL); + auto handle = mmDlopen(lib_path.c_str(), MMPA_RTLD_NOW | MMPA_RTLD_GLOBAL); if (handle == nullptr) { - GELOGE(INTERNAL_ERROR, "Failed to invoke dlopen. path = %s, error = %s", lib_path.c_str(), dlerror()); + GELOGE(INTERNAL_ERROR, "Failed to invoke dlopen. path = %s, error = %s", lib_path.c_str(), mmDlerror()); return INTERNAL_ERROR; } - auto initialize = (Status (*)(const HostCpuContext &))dlsym(handle, "Initialize"); + auto initialize = (Status (*)(const HostCpuContext &))mmDlsym(handle, "Initialize"); if (initialize != nullptr) { GELOGI("Invoke function Initialize in lib: %s", lib_path.c_str()); if (initialize(HostCpuContext()) != SUCCESS) { diff --git a/ge/ge_local_engine/module.mk b/ge/ge_local_engine/module.mk old mode 100755 new mode 100644 diff --git a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc old mode 100755 new mode 100644 diff --git a/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h b/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h old mode 100755 new mode 100644 diff --git a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc old mode 100755 new mode 100644 diff --git a/ge/ge_local_engine/ops_kernel_store/op/no_op.cc b/ge/ge_local_engine/ops_kernel_store/op/no_op.cc old mode 100755 new mode 100644 diff --git a/ge/ge_runner.mk b/ge/ge_runner.mk index a2679ed1..25718e9b 100644 --- a/ge/ge_runner.mk +++ b/ge/ge_runner.mk @@ -29,6 +29,8 @@ LIBGE_LOCAL_SRC_FILES := \ common/dump/dump_manager.cc \ common/dump/dump_properties.cc \ common/dump/dump_op.cc \ + common/profiling/ge_profiling.cc \ + common/profiling/ge_runner_profiling.cc \ engine_manager/dnnengine_manager.cc \ ge_local_engine/engine/host_cpu_engine.cc \ generator/ge_generator.cc \ @@ -111,6 +113,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/atomic_addr_clean_pass.cc \ graph/passes/mark_same_addr_pass.cc \ graph/passes/mark_graph_unknown_status_pass.cc \ + graph/passes/dynamic_single_op_reset_shape_pass.cc \ graph/passes/mark_agnostic_pass.cc \ graph/partition/dynamic_shape_partition.cc \ graph/partition/stage_partition.cc \ @@ -170,6 +173,7 @@ LIBGE_LOCAL_SRC_FILES := \ host_kernels/sub_kernel.cc \ host_kernels/transdata_kernel.cc \ host_kernels/unpack_kernel.cc \ + host_kernels/reformat_kernel.cc \ graph/passes/folding_pass.cc \ graph/passes/get_original_format_pass.cc \ graph/passes/guarantee_const_pass.cc \ @@ -306,7 +310,6 @@ LIBGE_LOCAL_SRC_FILES := \ LIBCLIENT_LOCAL_SRC_FILES := \ proto/ge_api.proto \ client/ge_api.cc \ - client/ge_prof.cc \ RUNNER_LOCAL_C_INCLUDES := \ $(LOCAL_PATH) ./ \ @@ -371,7 +374,7 @@ LOCAL_SRC_FILES += $(LIBCLIENT_LOCAL_SRC_FILES) LOCAL_STATIC_LIBRARIES := libge_memory \ libadump_server \ - libmsprofiler \ + libmsprofiler_fwk \ libmmpa \ LOCAL_SHARED_LIBRARIES := \ @@ -381,7 +384,6 @@ LOCAL_SHARED_LIBRARIES := \ libgraph \ libregister \ libge_common \ - libmsprof \ liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -408,7 +410,6 @@ endif LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc \ - ../../out/ge/lib64/stub/ge_prof.cc \ ../../out/ge/lib64/stub/ge_ir_build.cc \ LOCAL_SHARED_LIBRARIES := @@ -464,7 +465,6 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libslog \ libmmpa \ - libmsprof \ LOCAL_LDFLAGS := -lrt -ldl @@ -497,7 +497,6 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libslog \ libmmpa \ - libmsprof \ LOCAL_LDFLAGS := -lrt -ldl diff --git a/ge/ge_runtime/module.mk b/ge/ge_runtime/module.mk old mode 100755 new mode 100644 diff --git a/ge/ge_runtime/runtime_model.cc b/ge/ge_runtime/runtime_model.cc index fb0f3e85..8baa5b05 100644 --- a/ge/ge_runtime/runtime_model.cc +++ b/ge/ge_runtime/runtime_model.cc @@ -28,6 +28,7 @@ namespace ge { namespace model_runner { +const int kOffsetUnit = 8; RuntimeModel::~RuntimeModel() { GELOGI("RuntimeModel destructor start"); @@ -495,7 +496,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr &davinci_model return false; } uint64_t *buff = reinterpret_cast(const_cast(constant->weight_data.data())); - int64_t offset = elem_num * 8; + int64_t offset = elem_num * kOffsetUnit; uintptr_t hbm_raw_data_base_addr = reinterpret_cast(constant->output_addrs[0]) + offset; for (int64_t i = elem_num - 1; i >= 0; --i) { buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]); diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 16d63f6b..7c083d2b 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -47,6 +47,8 @@ const char *const kEngineNameDefault = "default"; const char *const kVectorEngine = "VectorEngine"; const char *const kAIcoreEngine = "AIcoreEngine"; const char *const kFileNameSuffix = "online"; +const size_t kDynamicDimSize = 1; +const int64_t kDynamicDimValue = -2; std::map engine_type_map{ {ge::ENGINE_SYS, kEngineNameDefault}, {ge::ENGINE_AICORE, kAIcoreEngine}, {ge::ENGINE_VECTOR, kVectorEngine}}; @@ -156,7 +158,12 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, GeTen } string op_type; - if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) { + bool is_const = false; + (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const); + if (is_const) { + GELOGD("Get input[%d] is const", index); + op_type = CONSTANTOP; + } else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) { op_type = DATA; } @@ -165,6 +172,18 @@ static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, GeTen if (data_op == nullptr) { return FAILED; } + if (is_const) { + ConstGeTensorPtr tensor_value; + if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) { + GELOGE(FAILED, "Get value failed, node name:%s.", tensor.GetName().c_str()); + return FAILED; + } + if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) { + GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS fail."); + return FAILED; + } + } + (void)AttrUtils::SetBool(data_op, "_is_single_op", true); GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add input desc fail."); @@ -231,6 +250,43 @@ static void GetOpsProtoPath(string &opsproto_path) { opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); } +static Status CheckShapeReset(const OpDescPtr &op_desc, bool &change_shape_flag) { + GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); + change_shape_flag = false; + for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) { + auto input_desc = op_desc->MutableInputDesc(static_cast(i)); + GE_CHECK_NOTNULL(input_desc); + // pass scalar input desc + auto dims = input_desc->GetShape().GetDims(); + if (dims.size() == kDynamicDimSize && dims[0] == kDynamicDimValue) { + change_shape_flag = true; + } + } + return SUCCESS; +} + +static void ResetTensorVecShape(const vector &inputs, vector &inputs_dynamic) { + for (auto input : inputs) { + auto input_desc = input.GetTensorDesc(); + GeShape shape_ori = input_desc.GetShape(); + + std::vector dynamic_shape_dims = {kDynamicDimValue}; + GeShape dynamic_shape(dynamic_shape_dims); + + ge::GeTensor inputTensor; + ge::GeTensorDesc desc(input_desc); + + bool is_const = false; + (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const); + if (!is_const && shape_ori.GetDims().size() > 0) { + desc.SetShape(dynamic_shape); + } + + inputTensor.SetTensorDesc(desc); + inputs_dynamic.push_back(inputTensor); + } +} + class GeGenerator::Impl { public: Impl(OmgContext &omg_context) : omg_context_(omg_context) {} @@ -240,6 +296,8 @@ class GeGenerator::Impl { Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model); + Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff); + Status SaveParams(GeModelPtr &ge_model, const string &type, const map &attrs, const vector &inputs, const vector &outputs); @@ -505,19 +563,7 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr GE_CHECK_NOTNULL(ge_root_model); GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); - ModelHelper model_helper; - string model_name = ""; - Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), model_name); - if (name_ret != SUCCESS) { - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); - GELOGE(FAILED, "Get model_name failed. Param --output is invalid"); - return PARAM_INVALID; - } - map name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); - GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; - GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model can not be null"); - ge_model->SetName(model_name); - ret = impl_->SaveModel(file_name_prefix, ge_model, model); + ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model); if (ret != SUCCESS) { GELOGE(ret, "Save model failed"); if (impl_->graph_manager_.Finalize() != SUCCESS) { @@ -567,6 +613,9 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, bool is_offline) { + if (!is_offline) { + (void)AttrUtils::SetBool(op_desc, ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, true); + } if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) { GELOGE(PARAM_INVALID, "input param is invalid when build single op!"); @@ -594,40 +643,11 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in // 2. Create ComputeGraph. string name = ge::CurrentTimeInStr() + "_" + model_file_name; - ge::ComputeGraphPtr compute_graph = MakeShared(name); - GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR); - - // 3. Add Node to ComputeGraph. - NodePtr op_node = compute_graph->AddNode(op_desc); - GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); - - // 4. Create InputData node. - int32_t arg_index = 0; - if (inputs.empty()) { - for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { - GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); - if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { - continue; - } - GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); - arg_index++; - } - } else { - for (const auto &in_desc : inputs) { - GeTensorDesc input_desc = in_desc.GetTensorDesc(); - GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true)); - arg_index++; - } + Graph graph; + if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) { + GELOGE(GRAPH_FAILED, "make graph fail."); + return GRAPH_FAILED; } - - // 5. Create Output node. - if (!outputs.empty()) { - GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs)); - } - - // dump ComputeGraph. - compute_graph->Dump(); - Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); GELOGI("ATC parser success in single op build."); GeRootModelPtr ge_root_model = nullptr; @@ -644,7 +664,18 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in } GeModelPtr &ge_model = name_to_ge_model.begin()->second; GELOGD("The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str()); - GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs)); + + bool dynamic_flag = false; + if (CheckShapeReset(op_desc, dynamic_flag) == SUCCESS && dynamic_flag) { + vector inputs_dynamic; + vector outputs_dynamic; + ResetTensorVecShape(inputs, inputs_dynamic); + ResetTensorVecShape(outputs, outputs_dynamic); + GE_CHK_STATUS_RET_NOLOG( + impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic)); + } else { + GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs)); + } GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff)); return SUCCESS; } @@ -683,6 +714,46 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector &inputs, + const vector &outputs, std::string graph_name, Graph &graph) { + ge::ComputeGraphPtr compute_graph = MakeShared(graph_name); + GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR); + + // 1. Add Node to ComputeGraph. + NodePtr op_node = compute_graph->AddNode(op_desc); + GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); + + // 2. Create InputData node. + int32_t arg_index = 0; + if (inputs.empty()) { + for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { + GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); + if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { + continue; + } + GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); + arg_index++; + } + } else { + for (const auto &in_desc : inputs) { + GeTensorDesc input_desc = in_desc.GetTensorDesc(); + GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true)); + arg_index++; + } + } + + // 3. Create Output node. + if (!outputs.empty()) { + GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs)); + } + + // dump ComputeGraph node. + compute_graph->Dump(); + graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + + return SUCCESS; +} + Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map &attrs, const vector &inputs, const vector &outputs) { GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID); @@ -712,6 +783,44 @@ Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr & return SUCCESS; } +Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model, + ModelBufferData &model_buff) { + bool is_unknown_shape = false; + auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape); + if (ret != SUCCESS) { + GELOGE(FAILED, "Check root model is unkonwn shape failed"); + return FAILED; + } + GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape); + GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(), return FAILED, + "ge root model has no sub model") + GeModelPtr model_root = nullptr; + if (is_unknown_shape) { + model_root = make_shared(); + model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph())); + ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root); + model_root->SetName(ge_root_model->GetRootGraph()->GetName()); + } else { + model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second; + } + // set atc version + if (!SetAtcVersionInfo(*(model_root.get()))) { + GELOGW("SetPackageVersionInfo of atc failed!"); + } + // set opp version + if (!SetOppVersionInfo(*(model_root.get()))) { + GELOGW("SetPackageVersionInfo of ops failed!"); + } + ModelHelper model_helper; + model_helper.SetSaveMode(is_offline_); + ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape); + if (ret != SUCCESS) { + GELOGE(ret, "Save to om model failed"); + return ret; + } + return SUCCESS; +} + Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector &inputs, GeRootModelPtr &ge_root_model) { static std::atomic atomic_graph_id(0); diff --git a/ge/graph/build/graph_builder.cc b/ge/graph/build/graph_builder.cc index bdb02b3a..87d2a206 100644 --- a/ge/graph/build/graph_builder.cc +++ b/ge/graph/build/graph_builder.cc @@ -349,7 +349,8 @@ static Status GenerateTaskForConstant(const std::shared_ptr &graph GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { - GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); + GELOGE(FAILED, "Insert memcpy between %s and %s failed.", + in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; } } @@ -475,7 +476,7 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr } Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { - // set input_desc.size = src_node.output_desc.size + // Set the size of input_desc to 'src_node.output_desc.size' if (node_ptr->GetType() == DATA) { bool is_unknown_shape = false; GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node_ptr, is_unknown_shape), @@ -498,7 +499,7 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { GE_IF_BOOL_EXEC(src_op == nullptr, continue); auto node_op_desc = node_ptr->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); - // set dst_node.input_desc = src_node.output_desc + // Set the input_desc of dst_node to 'src_node.output_desc' auto output_desc = src_op->GetOutputDescPtr(peer_out_anchor->GetIdx()); int64_t size = 0; GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_desc, size) != SUCCESS, GELOGI("Get size failed!")); @@ -512,7 +513,6 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { auto input_desc = node_op_desc->MutableInputDesc(in_data_anchor->GetIdx()); GE_CHECK_NOTNULL(input_desc); (void) ge::TensorUtils::SetSize(*input_desc, size); - GE_CHK_STATUS_RET(node_op_desc->UpdateInputDesc(in_data_anchor->GetIdx(), *input_desc)); GELOGD("%s input desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", node_ptr->GetName().c_str(), input_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str()); diff --git a/ge/graph/build/memory/CMakeLists.txt b/ge/graph/build/memory/CMakeLists.txt index bdd869a9..e988b4ce 100644 --- a/ge/graph/build/memory/CMakeLists.txt +++ b/ge/graph/build/memory/CMakeLists.txt @@ -18,6 +18,7 @@ target_compile_options(ge_memory PRIVATE target_compile_definitions(ge_memory PRIVATE google=ascend_private + LOG_CPP ) target_link_libraries(ge_memory PRIVATE diff --git a/ge/graph/build/memory/binary_block_mem_assigner.cc b/ge/graph/build/memory/binary_block_mem_assigner.cc index ecd2488c..fff589f3 100644 --- a/ge/graph/build/memory/binary_block_mem_assigner.cc +++ b/ge/graph/build/memory/binary_block_mem_assigner.cc @@ -21,8 +21,8 @@ namespace { const uint32_t kRangeCeilInterval = 2; const uint32_t kLogBase = 2; -const int64_t kLargeBlockSize = 8 * 1024 * 1024; -const int64_t kLargeBlockRangeSize = 10; +const int64_t kLargeBlockSize = 8388608; // 8 * 1024 * 1024 +const int64_t kLargeBlockRangeSize = 2; } // namespace namespace ge { @@ -73,15 +73,17 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector &range_ceils) { GELOGE(FAILED, "dividend is 0!"); return FAILED; } + // Memory size is 512 aligned, so it is not necessary to take less than 512 + int64_t min_memory_size = (all_memory_size.back() > MEM_ALIGN_SIZE) ? MEM_ALIGN_SIZE : all_memory_size.front(); auto range_number = static_cast( - ceil(log(all_memory_size.back() / static_cast(all_memory_size.front())) / log(kLogBase))); + ceil(log(all_memory_size.back() / static_cast(min_memory_size)) / log(kLogBase))); range_number = (range_number == 0) ? 1 : range_number; GELOGD("Range number: %zu", range_number); vector> ranges(range_number); GE_CHK_BOOL_EXEC((range_number != 0), return PARAM_INVALID, "range_number can't be 0."); size_t range_number_limit = all_memory_size.size() / range_number; - int64_t range_ceil = all_memory_size[0]; + int64_t range_ceil = min_memory_size; for (size_t i = 1; i <= range_number; i++) { GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(static_cast(range_ceil), kRangeCeilInterval), GELOGE(FAILED, "Multiply result is out of range."); @@ -114,7 +116,7 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector &range_ceils) { range_ceils.push_back(range.back()); } } - GELOGD("Range ceils: %s", ToString(range_ceils).c_str()); + GELOGI("Range ceils: %s", ToString(range_ceils).c_str()); return SUCCESS; } diff --git a/ge/graph/build/memory/block_mem_assigner.cc b/ge/graph/build/memory/block_mem_assigner.cc old mode 100755 new mode 100644 index 00f47573..9dc0cf73 --- a/ge/graph/build/memory/block_mem_assigner.cc +++ b/ge/graph/build/memory/block_mem_assigner.cc @@ -65,6 +65,98 @@ void AlignMemOffset(size_t &mem_align_size) { mem_align_size = (mem_align_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; } +static bool CompareLifeTime(const NodeTypeIndex &left, const NodeTypeIndex &right) { + auto left_node_op_desc = left.node->GetOpDesc(); + auto right_node_op_desc = right.node->GetOpDesc(); + if ((left_node_op_desc != nullptr) && (right_node_op_desc != nullptr) + && (left_node_op_desc->GetId() < right_node_op_desc->GetId())) { + return true; + } + return false; +} + +void GetLifeList(const MemoryBlock &block, std::vector &life_list, bool child) { + for (auto &node : block.NodeTypeIndexList()) { + life_list.emplace_back(node); + } + + if (child) { + for (auto child_block : block.ChildBlockList()) { + if (child_block == nullptr) { + continue; + } + if (block.stream_id_ != child_block->stream_id_ || !block.same_stream_ || !child_block->same_stream_) { + life_list.clear(); + return; + } + GetLifeList(*child_block, life_list, child); + } + } +} + +bool CrossLifeTime(const NodeTypeIndex &left, const NodeTypeIndex &right) { + if ((left.node == nullptr) || (right.node == nullptr)) { + return true; + } + auto left_node_op_desc = left.node->GetOpDesc(); + auto right_node_op_desc = right.node->GetOpDesc(); + if ((left_node_op_desc != nullptr) && (right_node_op_desc != nullptr)) { + if (left_node_op_desc->GetId() < right_node_op_desc->GetId()) { + if (left.life_time_end >= static_cast(right_node_op_desc->GetId())) { + return true; + } + } else if (left_node_op_desc->GetId() == right_node_op_desc->GetId()) { + return true; + } else { + if (right.life_time_end >= static_cast(left_node_op_desc->GetId())) { + return true; + } + } + } + return false; +} + +/// +/// When child block's life time are not cross with parent block, they can be reused(only same stream). +/// |-----------------------------parent block---------------------| +/// |------child block1--------------||------child block2------| +/// |--child block1-1-| +/// +bool CanIntervalLifeReuse(MemoryBlock &parent_block, MemoryBlock &child_block) { + // judge by interval life time, only same stream can be judged by interval life time + if (parent_block.stream_id_ != child_block.stream_id_ || !parent_block.same_stream_ || !child_block.same_stream_ + || parent_block.NodeTypeIndexList().empty() || child_block.NodeTypeIndexList().empty()) { + return false; + } + + // quick judge by front and back node + if (CrossLifeTime(parent_block.NodeTypeIndexList().front(), child_block.NodeTypeIndexList().front())) { + return false; + } + if (CrossLifeTime(parent_block.NodeTypeIndexList().back(), child_block.NodeTypeIndexList().back())) { + return false; + } + + std::vector life_list; + GetLifeList(parent_block, life_list, false); + GetLifeList(child_block, life_list, true); + if (life_list.empty()) { + return false; + } + std::sort(life_list.begin(), life_list.end(), CompareLifeTime); + size_t pre_life_end = 0; + for (auto &node : life_list) { + auto node_op_desc = node.node->GetOpDesc(); + if (node_op_desc != nullptr && pre_life_end >= static_cast(node_op_desc->GetId())) { + // life time cross + return false; + } + pre_life_end = node.life_time_end; + } + GELOGI("Block size[%zu, %zu] life time are not cross.", parent_block.Size(), child_block.Size()); + return true; +} + void MemoryBlock::SetHeadOffset(size_t offset) { head_offset_ = offset; size_t child_offset = head_offset_; @@ -125,20 +217,12 @@ size_t MemoryBlock::AlignSize() const { return align_block_size; } -bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { - if (node_type_index_list_.empty()) { +bool MemoryBlock::IsSameBatchLabel() { + // only same batch label can reuse + if (batch_label_.empty() || node_type_index_list_.empty()) { return false; } - auto node_op_desc = node_type_index_list_[0].node->GetOpDesc(); - if (node_op_desc == nullptr) { - return false; - } - // not all op has ATTR_NAME_BATCH_LABEL, no need check return value, only check out parameter - (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, first_batch_label); - if (first_batch_label.empty()) { - return false; - } bool all_same_label = true; for (size_t index = 1; index < node_type_index_list_.size(); ++index) { if (node_type_index_list_[index].node == nullptr) { @@ -147,8 +231,9 @@ bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { std::string batch_label; auto index_op_desc = node_type_index_list_[index].node->GetOpDesc(); GE_IF_BOOL_EXEC(index_op_desc == nullptr, continue); + // not all op has ATTR_NAME_BATCH_LABEL, no need check return value, only check out parameter (void)ge::AttrUtils::GetStr(index_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); - if (first_batch_label != batch_label) { + if (batch_label_ != batch_label) { all_same_label = false; break; } @@ -197,7 +282,7 @@ void MemoryBlock::AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLi } void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) { - if (CanNotLifeReuse(this) || CanNotLifeReuse(block)) { + if (CanNotLifeReuse(this) || CanNotLifeReuse(block) || (batch_label_ != block->batch_label_)) { return; } if (block->continuous_block_) { @@ -207,16 +292,27 @@ void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_ MemoryBlock *parent = nullptr; MemoryBlock *child = nullptr; // merge small block to large block - if (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd()) { - if ((child_offset_ + block->AlignSize()) <= AlignSize()) { - parent = this; - child = block; - } else if ((block->child_offset_ + AlignSize()) <= block->AlignSize()) { - parent = block; - child = this; + // noalign size 802816 + 802816 = 1605632 can reuse + // after 32 align size 802848 + 802848 > 1605664 can't reuse + // after 512 align size 803328 + 803328 > 1606144 can't reuse + // so 803328 + 803328 = 1606144 + 512 can reuse + if ((child_offset_ + block->AlignSize()) <= (AlignSize() + MEM_ALIGN_SIZE)) { + parent = this; + child = block; + } else if ((block->child_offset_ + AlignSize()) <= (block->AlignSize() + MEM_ALIGN_SIZE)) { + parent = block; + child = this; + } + + if ((parent != nullptr) && (child != nullptr)) { + // Different streams must use stream dependency to judge the life cycle + // In case same stream if it has child block, can judge all the child block's life time in CanIntervalLifeReuse + bool can_block_life_reuse = (child->child_blocks_.empty() + && (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd())); + if (!can_block_life_reuse && !CanIntervalLifeReuse(*parent, *child)) { + return; } - } - if ((parent != nullptr) && (child != nullptr) && child->child_blocks_.empty()) { + parent->child_blocks_.emplace_back(child); parent->child_offset_ += child->AlignSize(); child->deleted_block_ = true; @@ -261,6 +357,7 @@ size_t MemoryBlock::GetDependLifeBegin(int64_t stream_id, DependStreamLife &tota void AddDependLife(const ge::NodePtr &org_node, const ge::NodePtr &node, int64_t stream_id, std::map &depend_stream_life, DependStreamLife &total_node_depend_stream_life) { GE_CHECK_NOTNULL_EXEC(node, return); + GE_CHECK_NOTNULL_EXEC(org_node, return); auto node_desc = node->GetOpDesc(); GE_CHECK_NOTNULL_EXEC(node_desc, return); auto node_id = node_desc->GetId(); @@ -415,12 +512,60 @@ BlockMemAssigner::~BlockMemAssigner() { } } +void GetMaxBatchAllMemorySize(std::map> &batch_all_memory_size, + std::map batch_total_size, vector &all_memory_size, + std::string &max_batch_label) { + // use max batch all memory size for reuse range + int64_t max_batch_size = 0; + for (const auto &it : batch_total_size) { + GELOGI("Batch[%s] total memory size[%ld]", it.first.c_str(), it.second); + // no batch label + if (it.first.empty()) { + continue; + } + if (it.second > max_batch_size) { + max_batch_size = it.second; + max_batch_label = it.first; + } + } + GELOGI("Max batch[%s] total memory size[%ld]", max_batch_label.c_str(), max_batch_size); + + for (const auto &it : batch_all_memory_size) { + if (it.first.empty() || (it.first == max_batch_label)) { + all_memory_size.insert(all_memory_size.end(), it.second.begin(), it.second.end()); + } + } + // all_memory_size can't be empty + if (all_memory_size.empty()) { + all_memory_size.emplace_back(MEM_ALIGN_SIZE); + } + sort(all_memory_size.begin(), all_memory_size.end()); + GELOGD("All memory size: %s", ToString(all_memory_size).c_str()); + + for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) { + if (*iter == 0) { + iter = all_memory_size.erase(iter); + } else { + ++iter; + } + } +} + void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { vector temp; + std::map> batch_all_memory_size; + std::map batch_total_size; for (const NodePtr &n : compute_graph_->GetAllNodes()) { auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); + if (CheckIsZeroMemNodeType(node_op_desc->GetType())) { + continue; + } + + std::string batch_label; + (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); + if (node_op_desc->GetType() == ATOMICADDRCLEAN) { atomic_addr_clean_id_ = node_op_desc->GetId(); } @@ -434,9 +579,14 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { if (!reuse_input) { int64_t size = 0; GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(output_desc, size) != SUCCESS, GELOGI("Get size failed")); - if (anchor_to_symbol_.empty()) { - all_memory_size.emplace_back(size); + batch_all_memory_size[batch_label].emplace_back(size); + if (batch_total_size.find(batch_label) == batch_total_size.end()) { + batch_total_size[batch_label] = size; } else { + batch_total_size[batch_label] += size; + } + + if (!anchor_to_symbol_.empty()) { auto iter1 = anchor_to_symbol_.find(NodeIndexIO(n, out_anchor->GetIdx(), kOut).ToString()); if (iter1 == anchor_to_symbol_.end()) { continue; @@ -452,23 +602,11 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { } } temp.clear(); - GetNodeWorkSpaceSize(n, temp); - all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end()); - } - for (const auto &pair : symbol_size_) { - all_memory_size.emplace_back(pair.second); - } - sort(all_memory_size.begin(), all_memory_size.end()); - GELOGD("All memory size: %s", ToString(all_memory_size).c_str()); - - for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) { - if (*iter == 0) { - iter = all_memory_size.erase(iter); - } else { - ++iter; - } + GetNodeWorkSpaceSize(n, temp, batch_total_size[batch_label]); + batch_all_memory_size[batch_label].insert(batch_all_memory_size[batch_label].end(), temp.begin(), temp.end()); } - + GELOGI("The last atomic_addr_clean node id: %ld", atomic_addr_clean_id_); + GetMaxBatchAllMemorySize(batch_all_memory_size, batch_total_size, all_memory_size, max_batch_label_); InitReuseFlag(); PrintSymbolMap(); } @@ -529,16 +667,6 @@ bool CanReuseBySize(const map &reusable_block_counts, const Me bool can_reuse = false; if (reusable_block.Size() == block_size) { can_reuse = true; - } else { - string key = std::to_string(reusable_block.Size()); - key += "_" + std::to_string(reusable_block.stream_id_); - key += "_" + std::to_string(reusable_block.memory_type_); - auto it = reusable_block_counts.find(key); - GE_IF_BOOL_EXEC((it != reusable_block_counts.end() && (it->second > kReuseMaxCount)) && - (reusable_block.Size() > block_size), - can_reuse = true; - GELOGD("Less size mem reuse, reuse block size:%zu, current block size:%zu", - reusable_block.Size(), block_size);); } return can_reuse; } @@ -860,34 +988,35 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null."); auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr); + std::string batch_label; + (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, batch_label); + if (batch_label.empty() || (batch_label == max_batch_label_)) { + size_t align_size = real_size; + AlignMemOffset(align_size); + theory_memory_size_ += align_size; + if (theory_memory_size_ > theory_min_memory_size_) { + theory_min_memory_size_ = theory_memory_size_; + } + } bool is_reuse_memory = false; - string ge_disable_reuse_mem_env = "0"; - (void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env); - if (ge_disable_reuse_mem_env != "1") { + if (ge_disable_reuse_mem_env_ != "1") { bool reuse_mem_flag = (mem_type == kOutput) ? IsPreReuse(n, out_index) : !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && !node_op_desc->HasAttr(kOpNoReuseMem) && reuse_mem_flag && is_op_reuse_mem; - auto stream_id = node_op_desc->GetStreamId(); - if (is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty()) { - for (auto it = reusable_blocks_[memory_type][stream_id].begin(); - it != reusable_blocks_[memory_type][stream_id].end(); ++it) { + bool do_reuse = is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty(); + if (do_reuse) { + auto stream_id = node_op_desc->GetStreamId(); + for (auto it = reusable_blocks_[memory_type][stream_id].rbegin(); + it != reusable_blocks_[memory_type][stream_id].rend(); ++it) { MemoryBlock *reusable_block = *it; if (!IsPostReuse(reusable_block)) { reusable_block->reuse_mem_ = false; GELOGI("Unreusable block."); continue; } - std::string batch_label; - if (reusable_block->IsSameLabel(batch_label)) { - std::string op_label; - (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, op_label); - if (batch_label != op_label) { - GELOGI("label diff, op name %s", node_op_desc->GetName().c_str()); - continue; - } - } + GE_IF_BOOL_EXEC(reusable_block->batch_label_ != batch_label, continue); // A node can reuse blocks of the same stream and preorder streams if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) { @@ -901,7 +1030,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, reusable_block->continuous_block_ = continuous; reusable_block->ref_count_++; ReduceReusableBlockCount(*reusable_block, reusable_block_counts_); - reusable_blocks_[memory_type][stream_id].erase(it); + reusable_blocks_[memory_type][stream_id].erase((++it).base()); return reusable_block; } } @@ -914,10 +1043,11 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, // Data and netoutput need zero copy block block->is_zero_copy_ = IsZeroCopyBlock(n, continuous); - block->Init(real_size, mem_type, n, out_index, no_align_size); + block->Init(real_size, mem_type, n, out_index, no_align_size, node_op_desc->GetStreamId()); block->stream_id_ = node_op_desc->GetStreamId(); block->ref_count_++; block->continuous_block_ = continuous; + block->batch_label_ = batch_label; if (mem_type == kOutput) { auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); if (iter != anchor_to_symbol_.end()) { @@ -945,6 +1075,11 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec return nullptr; } + if (CheckIsZeroMemNodeType(n->GetType())) { + zero_memory_list_.emplace_back(n, kOutput, index); + continue; + } + int64_t size = 0; if (ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS) { GELOGI("Get size failed"); @@ -957,9 +1092,7 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec // only apply total size in first block if (index != 0) { zero_memory_list_.emplace_back(n, kOutput, index); - } - - if (index == 0) { + } else { NodeIndexIO node_index_io(n, index, kOut); auto iter = anchor_to_symbol_.find(node_index_io.ToString()); if (iter != anchor_to_symbol_.end()) { @@ -972,6 +1105,10 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec } } + if (total_size == 0) { + return nullptr; + } + auto block_size = GetBlockSize(total_size, ranges); GELOGI("Node[%s] continuous out memory size[%ld] block size[%zu]", node_op_desc->GetName().c_str(), total_size, block_size); @@ -1119,15 +1256,28 @@ bool IsKnownSubgraphData(const NodePtr &node) { return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); } -void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector &reusable_memory) { +void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector &reusable_memory, + bool same_stream) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null."); GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory"); GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory"); --to_release->ref_count_; + if (!same_stream) { + to_release->same_stream_ = false; + } if (to_release->ref_count_ == 0) { - to_release->SetLifeTimeEnd(life_time_); - reusable_memory.emplace_back(to_release); - AddReusableBlockCount(*to_release, reusable_block_counts_); + if (to_release->reuse_mem_ && !to_release->RealSizeList().empty()) { + if (to_release->batch_label_.empty() || (to_release->batch_label_ == max_batch_label_)) { + size_t align_size = to_release->RealSizeList().back(); + AlignMemOffset(align_size); + theory_memory_size_ -= align_size; + } + } + if (to_release->same_stream_) { + to_release->SetLifeTimeEnd(life_time_); + reusable_memory.emplace_back(to_release); + AddReusableBlockCount(*to_release, reusable_block_counts_); + } } } @@ -1167,10 +1317,9 @@ void BlockMemAssigner::ReleaseInputNodeOutMemory(const unordered_mapGetName().c_str()); if ((node_type_indexs.back().node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) && - (node_type_indexs.back().index == static_cast(in_anchor->GetPeerOutAnchor()->GetIdx())) && - (node->GetOpDesc()->GetStreamId() == block->stream_id_)) { - ReleaseMemory(block, reusable_memory); - if (block->ref_count_ == 0) { + (node_type_indexs.back().index == static_cast(in_anchor->GetPeerOutAnchor()->GetIdx()))) { + ReleaseMemory(block, reusable_memory, (node->GetOpDesc()->GetStreamId() == block->stream_id_)); + if (block->ref_count_ == 0 && block->same_stream_) { SetLastUsedInputMemAttr(node, in_anchor->GetIdx()); } } @@ -1267,7 +1416,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector bool no_need_assign_memory = ((size == 0) || CheckIsZeroMemNodeType(node->GetType())); if (!no_need_assign_memory) { out_node_set_continuous_input = - IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index, no_need_assign_memory, reset_zero_copy_flag); + IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index, + no_need_assign_memory, reset_zero_copy_flag); GE_IF_BOOL_EXEC(!no_need_assign_memory, no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input);); } @@ -1328,7 +1478,8 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { iter->second[stream_id].clear(); } vector temp; - GetNodeWorkSpaceSize(n, temp); + int64_t tatal_size = 0; + GetNodeWorkSpaceSize(n, temp, tatal_size); vector workspace_bytes; vector tvm_workspace_memory_type; bool has_tvm_workspace_mem_type_attr = @@ -1349,7 +1500,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { bool workspace_skip_flag = false; if (has_tvm_workspace_mem_type_attr && tvm_workspace_memory_type[i] == RT_MEMORY_L1) { GELOGI( - "fusion: node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]", + "fusion:node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]", node_op_desc->GetName().c_str(), i, tvm_workspace_memory_type[i]); workspace_skip_flag = true; } @@ -1380,9 +1531,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { (void)mem_block; // Fix warning } - bool merge_dynamic_batch = false; - GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), merge_dynamic_batch = MergeDynamicBatchBlocks()); - GE_IF_BOOL_EXEC((!(ge_disable_reuse_mem_env_ == "1") && !merge_dynamic_batch), ReuseBlocksByLifeTime(ranges.size())); + GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), ReuseBlocksByLifeTime(ranges.size())); AssignContinuousBlocks(); ResizeMemoryBlocks(); @@ -1402,92 +1551,19 @@ void BlockMemAssigner::CheckWorkspaceReuse(const vector &workspace_reuse_f } } -void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector &workspace_memory) { +void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector &workspace_memory, + int64_t &total_size) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node->GetOpDesc() == nullptr, return, "Op desc is null."); vector workspace_byte_nums = node->GetOpDesc()->GetWorkspaceBytes(); GELOGD("node[%s] size:%zu", node->GetOpDesc()->GetName().c_str(), workspace_byte_nums.size()); for (int64_t byte_size : workspace_byte_nums) { workspace_memory.emplace_back(byte_size); + total_size += byte_size; GELOGD("push back size:%ld", byte_size); } } -// descending order -static bool CompareBlockMaxSize(MemoryBlock *left, MemoryBlock *right) { - if (left == nullptr || right == nullptr) { - return false; - } - auto left_max_size = std::max_element(left->RealSizeList().begin(), left->RealSizeList().end()); - if (left_max_size != left->RealSizeList().end()) { - auto right_max_size = std::max_element(right->RealSizeList().begin(), right->RealSizeList().end()); - if (right_max_size == right->RealSizeList().end() || (*left_max_size > *right_max_size)) { - return true; - } - } - return false; -} - -void MergeBlocks(std::vector &dest, std::vector &src) { - for (size_t i = 0; i < dest.size(); ++i) { - if (i >= src.size()) { - return; - } - if (dest[i] != nullptr && src[i] != nullptr) { - if (!dest[i]->reuse_mem_ || !src[i]->reuse_mem_) { - GELOGD("Diff batch's workspace can't be reused, i: %zu, dest[i]: %s, stream: %ld, src[i]: %s, stream: %ld.", - i, dest[i]->String().c_str(), dest[i]->stream_id_, src[i]->String().c_str(), src[i]->stream_id_); - continue; - } - for (auto &symbol : src[i]->SymbolList()) { - dest[i]->AddSymbol(symbol); - } - for (size_t j = 0; j < src[i]->NodeTypeIndexList().size(); ++j) { - dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j], - src[i]->RealSizeList()[j], - src[i]->NoAlignSizeList()[j]); - src[i]->deleted_block_ = true; - } - } - } -} - -bool BlockMemAssigner::MergeDynamicBatchBlocks() { - bool merged = false; - std::map> dynamic_batch_blocks; - for (auto block : memory_blocks_) { - if (block == nullptr) { - continue; - } - std::string batch_label; - if (block->IsSameLabel(batch_label)) { - dynamic_batch_blocks[batch_label].emplace_back(block); - } - } - - auto it = dynamic_batch_blocks.begin(); - auto it_max = it; - - // find max block counts - for (; it != dynamic_batch_blocks.end(); ++it) { - if (it->second.size() > it_max->second.size()) { - it_max = it; - } - std::sort(it->second.begin(), it->second.end(), CompareBlockMaxSize); - } - if (it_max != dynamic_batch_blocks.end()) { - GELOGD("MergeDynamicBatch %s block counts %zu", it_max->first.c_str(), it_max->second.size()); - } - for (it = dynamic_batch_blocks.begin(); it != dynamic_batch_blocks.end(); ++it) { - if (it != it_max) { - GELOGD("MergeDynamicBatch from %s to %s", it->first.c_str(), it_max->first.c_str()); - MergeBlocks(it_max->second, it->second); - merged = true; - } - } - return merged; -} - // asending order static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) { if (left == nullptr || right == nullptr) { @@ -1597,38 +1673,93 @@ void BlockMemAssigner::ReuseBlocksByLifeTime(size_t range_size) { } } +void AddBlockMemOffset(size_t &mem_offset, size_t &p2p_mem_offset, MemoryBlock &block) { + if (block.memory_type_ == RT_MEMORY_HBM) { + if (block.first_continuous_block_) { + mem_offset += MEM_ALIGN_SIZE; + } + block.Resize(); + block.SetHeadOffset(mem_offset); + mem_offset += block.Size(); + block.SetTailOffset(mem_offset - 1); + } else if (block.memory_type_ == RT_MEMORY_P2P_DDR) { + if (block.first_continuous_block_) { + p2p_mem_offset += MEM_ALIGN_SIZE; + } + block.Resize(); + block.SetHeadOffset(p2p_mem_offset); + p2p_mem_offset += block.Size(); + block.SetTailOffset(p2p_mem_offset - 1); + } +} + +bool DynamicBatchBlockReuse(MemoryBlock &block) { + return (block.IsSameBatchLabel() && block.reuse_mem_); +} + /// /// @ingroup domi_omg -/// @brief traverse memory size, resize, calculate offset +/// @brief get max batch memory size, others reuse this block memory /// @param [in&out] memory_blocks_ memory block, after calculating offset +/// |-dynamic batch block batch1| +/// |-dynamic batch block batch2----| +/// |-dynamic batch block batch3--| /// -void BlockMemAssigner::ResizeMemoryBlocks() { - for (auto &memory_block : memory_blocks_) { - if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_) { +void BlockMemAssigner::ResizeDynamicBatchBlocks() { + std::map> dynamic_batch_blocks; + for (auto block : memory_blocks_) { + if (block == nullptr) { continue; } - if (memory_block->memory_type_ == RT_MEMORY_HBM) { - if (memory_block->first_continuous_block_) { - mem_offset_ += MEM_ALIGN_SIZE; - } + // when memory is not reuseable, it can't be reused by different branch + if (DynamicBatchBlockReuse(*block)) { + dynamic_batch_blocks[block->batch_label_].emplace_back(block); + } + } - memory_block->Resize(); - memory_block->SetHeadOffset(mem_offset_); - mem_offset_ += memory_block->Size(); - memory_block->SetTailOffset(mem_offset_ - 1); - } else if (memory_block->memory_type_ == RT_MEMORY_P2P_DDR) { - if (memory_block->first_continuous_block_) { - p2p_mem_offset_ += MEM_ALIGN_SIZE; + size_t max_mem_offset = mem_offset_; + size_t max_p2p_mem_offset = p2p_mem_offset_; + for (auto &batch_blocks : dynamic_batch_blocks) { + size_t mem_offset = mem_offset_; + size_t p2p_mem_offset = p2p_mem_offset_; + for (auto block : batch_blocks.second) { + if (block == nullptr || block->deleted_block_ || block->is_zero_copy_) { + continue; } + AddBlockMemOffset(mem_offset, p2p_mem_offset, *block); + } + if (mem_offset > max_mem_offset) { + max_mem_offset = mem_offset; + } + if (p2p_mem_offset > max_p2p_mem_offset) { + max_p2p_mem_offset = p2p_mem_offset; + } + GELOGI("Batch[%s] offset[%zu] p2p_offset[%zu]", batch_blocks.first.c_str(), mem_offset, p2p_mem_offset); + } + mem_offset_ = max_mem_offset; + p2p_mem_offset_ = max_p2p_mem_offset; +} - memory_block->Resize(); - memory_block->SetHeadOffset(p2p_mem_offset_); - p2p_mem_offset_ += memory_block->Size(); - memory_block->SetTailOffset(p2p_mem_offset_ - 1); +/// +/// @ingroup domi_omg +/// @brief traverse memory size, resize, calculate offset +/// @param [in&out] memory_blocks_ memory block, after calculating offset +/// |-not dynamic batch block-||-dynamic batch block batch1| |-zero copy block-| +/// |-not dynamic batch block-||-dynamic batch block batch2----||-zero copy block-| +/// |-not dynamic batch block-||-dynamic batch block batch3--| |-zero copy block-| +/// +void BlockMemAssigner::ResizeMemoryBlocks() { + for (auto &memory_block : memory_blocks_) { + if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_ + || DynamicBatchBlockReuse(*memory_block)) { + continue; } + + AddBlockMemOffset(mem_offset_, p2p_mem_offset_, *memory_block); } - GELOGD("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu.", - mem_offset_, p2p_mem_offset_); + ResizeDynamicBatchBlocks(); + GELOGI("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu," + "theory_min_memory_size %zu", mem_offset_, p2p_mem_offset_, theory_min_memory_size_); } /// @@ -1641,7 +1772,7 @@ void BlockMemAssigner::ResizeMemoryBlocks() { /// @return Status result /// void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, - size_t real_size, size_t no_align_size, bool child_block) { + size_t real_size, size_t no_align_size, int32_t child_block_level) { ge::OpDescPtr op_desc = node_type.node->GetOpDesc(); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null."); string graph_name = node_type.node->GetOwnerComputeGraph()->GetName(); @@ -1689,14 +1820,15 @@ void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, } op_desc->SetWorkspace(workspace_list); } - GELOGI("[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu]" - " noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d] isref[%d].", graph_name.c_str(), + GELOGI("[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu] noalignsize[%zu] " + "life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d:%d] isref[%d] batch[%s]", graph_name.c_str(), op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, op_desc->GetStreamId(), - block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block, block->reuse_mem_, - block->continuous_block_, block->deleted_block_, node_type.ref_input); + block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block_level, block->reuse_mem_, + block->continuous_block_, block->is_zero_copy_, block->same_stream_, node_type.ref_input, + block->batch_label_.c_str()); } -void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) { +void SetBlockOpMemOffset(MemoryBlock *block, int32_t child_block_level) { if (block == nullptr) { return; } @@ -1709,9 +1841,14 @@ void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) { real_size = block->RealSizeList()[index]; no_align_size = block->NoAlignSizeList()[index]; } - SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block); + SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block_level); index++; } + + child_block_level++; + for (MemoryBlock *child_block : block->ChildBlockList()) { + SetBlockOpMemOffset(child_block, child_block_level); + } } void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { @@ -1724,16 +1861,13 @@ void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { continue; } - SetBlockOpMemOffset(memory_block, false); - for (MemoryBlock *child_block : memory_block->ChildBlockList()) { - SetBlockOpMemOffset(child_block, true); - } + SetBlockOpMemOffset(memory_block, 0); } if (!is_zero_copy) { for (const NodeTypeIndex &node_type_index : zero_memory_list_) { MemoryBlock block(0, 0); - SetOffsetSize(node_type_index, &block, 0, 0, false); + SetOffsetSize(node_type_index, &block, 0, 0, 0); } } } diff --git a/ge/graph/build/memory/block_mem_assigner.h b/ge/graph/build/memory/block_mem_assigner.h old mode 100755 new mode 100644 index f3d26c1d..d514ca34 --- a/ge/graph/build/memory/block_mem_assigner.h +++ b/ge/graph/build/memory/block_mem_assigner.h @@ -65,6 +65,7 @@ class MemoryBlock { stream_id_(stream_id), deleted_block_(false), reuse_mem_(reuse_mem), + same_stream_(true), input_index_(0), continuous_block_(false), first_continuous_block_(false), @@ -85,10 +86,14 @@ class MemoryBlock { symbol_list_.clear(); } - void Init(size_t real_size, OpMemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size) { + void Init(size_t real_size, OpMemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size, + int64_t stream_id) { real_size_list_.emplace_back(real_size); no_align_size_list_.emplace_back(no_align_size); node_type_index_list_.emplace_back(node, type, out_index, false); + if (stream_id != stream_id_) { + same_stream_ = false; + } } size_t Size() const { return block_size_; } @@ -106,6 +111,12 @@ class MemoryBlock { node_type_index_list_.emplace_back(node_type_index); real_size_list_.emplace_back(real_size); no_align_size_list_.emplace_back(no_align_size); + if ((node_type_index.node != nullptr) && (node_type_index.node->GetOpDesc() != nullptr)) { + auto stream_id = node_type_index.node->GetOpDesc()->GetStreamId(); + if (stream_id != stream_id_) { + same_stream_ = false; + } + } } void AddSymbol(const std::string &symbol) { @@ -122,7 +133,7 @@ class MemoryBlock { std::string String(); - bool IsSameLabel(std::string &first_batch_label); + bool IsSameBatchLabel(); void AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life); @@ -142,6 +153,7 @@ class MemoryBlock { int64_t stream_id_; bool deleted_block_; bool reuse_mem_; + bool same_stream_; uint32_t input_index_; bool continuous_block_; bool first_continuous_block_; @@ -149,6 +161,7 @@ class MemoryBlock { bool is_zero_copy_; std::map depend_stream_life_; int64_t memory_type_; + std::string batch_label_; private: size_t block_size_; std::vector real_size_list_; @@ -209,7 +222,7 @@ class BlockMemAssigner : public MemAssigner { void GetOutAndWorkSpaceMem(std::vector &all_memory_size); - void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector &workspace_memory); + void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector &workspace_memory, int64_t &total_size); /// /// @ingroup GE @@ -353,7 +366,7 @@ class BlockMemAssigner : public MemAssigner { /// @return void /// @author /// - void ReleaseMemory(MemoryBlock *to_release, vector &reusable_memory); + void ReleaseMemory(MemoryBlock *to_release, vector &reusable_memory, bool same_stream = true); /// /// @ingroup GE @@ -379,11 +392,11 @@ class BlockMemAssigner : public MemAssigner { /// /// @ingroup GE - /// @brief Merge memory blocks between different batchs + /// @brief Resize memory blocks for each batchs /// @return merge or not /// @author /// - bool MergeDynamicBatchBlocks(); + void ResizeDynamicBatchBlocks(); void AssignContinuousBlocks(); @@ -436,6 +449,17 @@ class BlockMemAssigner : public MemAssigner { int64_t atomic_addr_clean_id_ = 0; + size_t theory_min_memory_size_ = 0; + + size_t theory_memory_size_ = 0; + + std::string max_batch_label_; + + /// + /// @ [stream1][nodeid] + /// @[nodeid] [stream2][nodeid] + /// @ [stream2][nodeid] + /// DependStreamLife total_node_depend_stream_life_; }; } // namespace ge diff --git a/ge/graph/build/memory/graph_mem_assigner.cc b/ge/graph/build/memory/graph_mem_assigner.cc old mode 100755 new mode 100644 index ad0235d5..16d5d38f --- a/ge/graph/build/memory/graph_mem_assigner.cc +++ b/ge/graph/build/memory/graph_mem_assigner.cc @@ -419,7 +419,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, GE_IF_BOOL_EXEC(is_peer_output_continuous && (peer_output_size != 1), std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + - " requires continuous output. There may be conflict between the two. This node is not supported now."; + " requires continuous output. There may be conflict between the two." + + "This node is not supported now."; GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); return PARAM_INVALID;); @@ -429,7 +430,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, GE_IF_BOOL_EXEC(is_peer_reference, std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + - " requires continuous output. There may be conflict between the two. This node is not supported now."; + " requires continuous output. There may be conflict between the two." + + "This node is not supported now."; GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); return PARAM_INVALID;); @@ -1646,9 +1648,9 @@ ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &node, const ve } string atomic_mem_size_str = ss.str(); - GELOGI("[IMAS]SetAtomicCleanAttr : Set graph[%s] atomic_node[%s] output offset [%s] size[%s] streamid[%ld]", + GELOGI("[IMAS]SetAtomicCleanAttr : Set %s atomic_node name[%s] output[0] offset to [%s] streamid[%ld] size[%s]", node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(), - atomic_mem_start_str.c_str(), atomic_mem_size_str.c_str(), node->GetOpDesc()->GetStreamId()); + atomic_mem_start_str.c_str(), node->GetOpDesc()->GetStreamId(), atomic_mem_size_str.c_str()); } return SUCCESS; } diff --git a/ge/graph/build/memory/graph_mem_assigner.h b/ge/graph/build/memory/graph_mem_assigner.h old mode 100755 new mode 100644 diff --git a/ge/graph/build/memory/hybrid_mem_assigner.cc b/ge/graph/build/memory/hybrid_mem_assigner.cc old mode 100755 new mode 100644 diff --git a/ge/graph/build/memory/hybrid_mem_assigner.h b/ge/graph/build/memory/hybrid_mem_assigner.h old mode 100755 new mode 100644 diff --git a/ge/graph/build/memory/mem_assigner.h b/ge/graph/build/memory/mem_assigner.h old mode 100755 new mode 100644 diff --git a/ge/graph/build/memory/memory_assigner.cc b/ge/graph/build/memory/memory_assigner.cc old mode 100755 new mode 100644 diff --git a/ge/graph/build/memory/module.mk b/ge/graph/build/memory/module.mk old mode 100755 new mode 100644 diff --git a/ge/graph/build/memory/var_mem_assign_util.cc b/ge/graph/build/memory/var_mem_assign_util.cc old mode 100755 new mode 100644 diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc old mode 100755 new mode 100644 index d7039cfb..37eb499a --- a/ge/graph/build/model_builder.cc +++ b/ge/graph/build/model_builder.cc @@ -282,7 +282,7 @@ Status ModelBuilder::SetInputOutputDesc() { void ModelBuilder::AddNodeInputProperty() { for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = node->GetOpDesc(); - GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return ); + GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return); vector src_name_list; vector src_index_list; for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { @@ -309,10 +309,10 @@ void ModelBuilder::AddNodeInputProperty() { for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = node->GetOpDesc(); - GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return ); + GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return); GE_IF_BOOL_EXEC(node_op_desc->GetType() == NETOUTPUT, continue); auto out_control_anchor = node->GetOutControlAnchor(); - GE_IF_BOOL_EXEC(out_control_anchor == nullptr, GELOGW("out_control_anchor is nullptr"); return ); + GE_IF_BOOL_EXEC(out_control_anchor == nullptr, GELOGW("out_control_anchor is nullptr"); return); vector dst_name_list; vector dst_index_list; string dst_name_temp; @@ -330,7 +330,7 @@ void ModelBuilder::AddNodeInputProperty() { dst_name_temp = ""; int64_t dst_index = kWrongIndex; // assign an impossible value to dst_index. for (const auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC(in_data_anchor == nullptr, GELOGW("in_data_anchor is nullptr"); return ); + GE_IF_BOOL_EXEC(in_data_anchor == nullptr, GELOGW("in_data_anchor is nullptr"); return); ge::NodePtr dst_node = in_data_anchor->GetOwnerNode(); dst_name_temp = dst_name_temp.empty() ? dst_node->GetName() : dst_name_temp + ":" + dst_node->GetName(); dst_index = in_data_anchor->GetIdx(); diff --git a/ge/graph/build/run_context.h b/ge/graph/build/run_context.h old mode 100755 new mode 100644 diff --git a/ge/graph/build/stream_allocator.cc b/ge/graph/build/stream_allocator.cc index 4378f71b..a1cda506 100644 --- a/ge/graph/build/stream_allocator.cc +++ b/ge/graph/build/stream_allocator.cc @@ -49,7 +49,8 @@ inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string & } bool IsHcclOp(const string &op_type) { - const set hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE}); + const set hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, + ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE}); return hccl_op_types.find(op_type) != hccl_op_types.end(); } } // namespace diff --git a/ge/graph/build/stream_graph_optimizer.cc b/ge/graph/build/stream_graph_optimizer.cc index 582c080b..2933d413 100644 --- a/ge/graph/build/stream_graph_optimizer.cc +++ b/ge/graph/build/stream_graph_optimizer.cc @@ -38,7 +38,7 @@ void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Grap continue; } for (ge::NodePtr &node : subgraph->GetDirectNode()) { - GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return ); + GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return); if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) { node->GetOpDesc()->SetId(static_cast(node_size)); node_size++; diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc old mode 100755 new mode 100644 index 41607f1f..b506f945 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -49,8 +49,6 @@ const char *const kIsLastNode = "is_last_node"; const char *const kIsInputVar = "INPUT_IS_VAR"; const char *const kIsOutputVar = "OUTPUT_IS_VAR"; const char *const kProfilingMode = "PROFILING_MODE"; -const char *const kProfilingFpPoint = "FP_POINT"; -const char *const kProfilingBpPoint = "BP_POINT"; const uint32_t kProfilingArStep = 2; const uint64_t kProfilingFpStartLogid = 1; const uint64_t kProfilingBpEndLogid = 2; @@ -810,35 +808,23 @@ Status TaskGenerator::GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint vector &all_reduce_nodes, std::string &fp_point_str, std::string &bp_point_str) const { - if (ge::GetContext().GetOption(OPTION_EXEC_PROFILING_FPPONIT_OPTIONS, fp_point_str) == SUCCESS && - ge::GetContext().GetOption(OPTION_EXEC_PROFILING_BPPONIT_OPTIONS, bp_point_str) == SUCCESS && - !fp_point_str.empty() && !bp_point_str.empty()) { - return SUCCESS; - } + ProfilingManager::Instance().GetFpBpPoint(fp_point_str, bp_point_str); Status ret = SUCCESS; - const char *fp_point = std::getenv(kProfilingFpPoint); - if (fp_point == nullptr) { + if (fp_point_str.empty()) { ret = AutoFindFpOpIndex(graph, profiling_point); if (ret != SUCCESS) { GELOGW("First forward profiling op_index not set and FindFpOpIndex failed."); return FAILED; } - } else { - fp_point_str = string(fp_point); - GELOGI("Get fp_point_str from env %s", fp_point_str.c_str()); } - const char *bp_point = std::getenv(kProfilingBpPoint); - if (bp_point == nullptr) { + if (bp_point_str.empty()) { ret = AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes); if (ret != SUCCESS) { GELOGW("Last backward profiling op_index not set and FindBpOpIndex failed."); return FAILED; } - } else { - bp_point_str = string(bp_point); - GELOGI("Get bp_point_str from env %s", bp_point_str.c_str()); } return SUCCESS; diff --git a/ge/graph/build/task_generator.h b/ge/graph/build/task_generator.h old mode 100755 new mode 100644 diff --git a/ge/graph/common/transop_util.cc b/ge/graph/common/transop_util.cc old mode 100755 new mode 100644 diff --git a/ge/graph/execute/graph_execute.cc b/ge/graph/execute/graph_execute.cc old mode 100755 new mode 100644 diff --git a/ge/graph/execute/graph_execute.h b/ge/graph/execute/graph_execute.h old mode 100755 new mode 100644 diff --git a/ge/graph/label/case_label_maker.h b/ge/graph/label/case_label_maker.h index 1078a906..3dbfb2bc 100644 --- a/ge/graph/label/case_label_maker.h +++ b/ge/graph/label/case_label_maker.h @@ -86,7 +86,6 @@ | Node | +------------+ *******************************************************************************/ - namespace ge { class CaseOpLabelMaker : public LabelMaker { public: diff --git a/ge/graph/label/if_label_maker.h b/ge/graph/label/if_label_maker.h index 0807f549..8b07eb96 100644 --- a/ge/graph/label/if_label_maker.h +++ b/ge/graph/label/if_label_maker.h @@ -70,7 +70,6 @@ | Node | +------------+ *******************************************************************************/ - namespace ge { class IfOpLabelMaker : public LabelMaker { public: diff --git a/ge/graph/label/partitioned_call_label_maker.h b/ge/graph/label/partitioned_call_label_maker.h index b89cb94c..3944aabd 100644 --- a/ge/graph/label/partitioned_call_label_maker.h +++ b/ge/graph/label/partitioned_call_label_maker.h @@ -54,7 +54,6 @@ | c | +---------------+ *******************************************************************************/ - namespace ge { class PartitionedCallLabelMaker : public LabelMaker { public: diff --git a/ge/graph/label/while_label_maker.h b/ge/graph/label/while_label_maker.h index 0eb0deee..6c30475b 100644 --- a/ge/graph/label/while_label_maker.h +++ b/ge/graph/label/while_label_maker.h @@ -70,7 +70,6 @@ | Node | +------------+ *******************************************************************************/ - namespace ge { class WhileOpLabelMaker : public LabelMaker { public: diff --git a/ge/graph/load/graph_loader.cc b/ge/graph/load/graph_loader.cc old mode 100755 new mode 100644 index aa825a5d..44556422 --- a/ge/graph/load/graph_loader.cc +++ b/ge/graph/load/graph_loader.cc @@ -283,7 +283,8 @@ Status GraphLoader::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asyn std::vector &output_desc) { auto model_manager = ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); - Status ret = model_manager->ExecuteModel(model_id, stream, async_mode, input_data, input_desc, output_data, output_desc); + Status ret = model_manager->ExecuteModel(model_id, stream, async_mode, + input_data, input_desc, output_data, output_desc); if (ret != SUCCESS) { GELOGE(ret, "Execute model failed, model_id:%u.", model_id); return ret; diff --git a/ge/graph/load/graph_loader.h b/ge/graph/load/graph_loader.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/aipp_utils.cc b/ge/graph/load/new_model_manager/aipp_utils.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/aipp_utils.h b/ge/graph/load/new_model_manager/aipp_utils.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/data_dumper.cc b/ge/graph/load/new_model_manager/data_dumper.cc index 4534fe73..b331d780 100644 --- a/ge/graph/load/new_model_manager/data_dumper.cc +++ b/ge/graph/load/new_model_manager/data_dumper.cc @@ -919,11 +919,11 @@ Status DataDumper::DumpExceptionInfo(const std::vector exceptio ReplaceStringElem(op_name); ReplaceStringElem(op_type); string dump_file_path = - "./" + op_type + "." + op_name + "." + to_string(op_desc_info.task_id) + "." + to_string(now_time); + "./" + op_type + "." + op_name + "." + std::to_string(op_desc_info.task_id) + "." + std::to_string(now_time); GELOGI("The exception dump file path is %s", dump_file_path.c_str()); uint64_t proto_size = dump_data.ByteSizeLong(); - unique_ptr proto_msg(new (std::nothrow) char[proto_size]); + std::unique_ptr proto_msg(new (std::nothrow) char[proto_size]); bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); if (!ret || proto_size == 0) { GELOGE(PARAM_INVALID, "Dump data proto serialize failed"); diff --git a/ge/graph/load/new_model_manager/data_dumper.h b/ge/graph/load/new_model_manager/data_dumper.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/data_inputer.cc b/ge/graph/load/new_model_manager/data_inputer.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/data_inputer.h b/ge/graph/load/new_model_manager/data_inputer.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/davinci_model.cc b/ge/graph/load/new_model_manager/davinci_model.cc old mode 100755 new mode 100644 index 81d47b3b..bc755e07 --- a/ge/graph/load/new_model_manager/davinci_model.cc +++ b/ge/graph/load/new_model_manager/davinci_model.cc @@ -16,7 +16,6 @@ #include "graph/load/new_model_manager/davinci_model.h" -#include #include #include #include @@ -84,7 +83,7 @@ const uint32_t kAddrLen = sizeof(void *); const int kDecimal = 10; const int kBytes = 8; const uint32_t kDataMemAlignSizeCompare = 64; -const uint32_t kDumpL1FusionOpMByteSize = 2 * 1024 * 1024; +const uint32_t kDumpL1FusionOpMByteSize = 2097152; // 2 * 1024 * 1024 const uint32_t kDumpFlagOfL1Fusion = 0; const char *const kDefaultBatchLable = "Batch_default"; const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node"; @@ -331,8 +330,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size); return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED; } - GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, - mem_base_, data_size); + GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", + runtime_param_.graph_id, mem_base_, data_size); if (!is_inner_weight_base_) { weights_mem_base_ = mem_base_; @@ -713,7 +712,7 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size // collect profiling for ge auto &profiling_manager = ProfilingManager::Instance(); if (profiling_manager.ProfilingModelLoadOn()) { - Status p_ret = ReportProfilingData(!profiling_manager.IsAclApiMode()); + Status p_ret = ReportProfilingData(); if (p_ret != SUCCESS) { GELOGE(p_ret, "Report profiling data failed."); return p_ret; @@ -724,14 +723,14 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size return ret; } -Status DavinciModel::ReportProfilingData(bool check_device) { +Status DavinciModel::ReportProfilingData() { std::vector compute_graph_desc_info; Status ret = GetComputeGraphInfo(compute_graph_desc_info); if (ret != SUCCESS) { GELOGE(ret, "GetComputeGraphInfo failed."); return ret; } - ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo(), compute_graph_desc_info, check_device); + ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo(), compute_graph_desc_info); GE_CHK_STATUS(SinkModelProfile(), "Sink model profiler failed."); op_list_.clear(); @@ -1544,7 +1543,8 @@ Status DavinciModel::LoadWithQueue() { } if (output_queue_ids_.size() != new_output_data_info_.size()) { - GELOGE(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID, "Output queue ids not match model: output_queue=%zu output_data=%zu", + GELOGE(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID, + "Output queue ids not match model: output_queue=%zu output_data=%zu", output_queue_ids_.size(), new_output_data_info_.size()); return ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID; } @@ -2186,8 +2186,9 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data const std::vector &blobs = input_data.blobs; for (const auto &data : new_input_data_info_) { if (data.first >= blobs.size()) { - GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), - new_input_data_info_.size(), data.first, data.second.GetDataInfo().at(0).first); + GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld, op_name(%s)", blobs.size(), + new_input_data_info_.size(), data.first, data.second.GetDataInfo().at(0).first, + data.second.GetOpName().c_str()); return FAILED; } @@ -2198,13 +2199,14 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data } uint64_t data_size = data.second.GetDataSize(); GE_CHK_BOOL_RET_STATUS(data_size >= data_buf.length, PARAM_INVALID, - "input data size(%lu) does not match model required size(%lu), ret failed.", data_buf.length, - data_size); + "input data size(%lu) does not match model required size(%lu), op_name(%s) ret failed.", + data_buf.length, data_size, data.second.GetOpName().c_str()); void *mem_addr = data.second.GetBasicAddr(); void *data_buf_addr = reinterpret_cast(reinterpret_cast(data_buf.data)); uint64_t data_buf_length = data_buf.length; - GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]", - runtime_param_.graph_id, data.first, mem_addr, data_buf_addr, data_size, data_buf_length); + GELOGI("CopyPlainData memcpy graph_%u type[F] input[%s] rank[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]", + runtime_param_.graph_id, data.second.GetOpName().c_str(), data.first, mem_addr, data_buf_addr, data_size, + data_buf_length); GE_CHK_RT_RET(rtMemcpy(mem_addr, data_size, data_buf_addr, data_buf_length, kind)); } @@ -2248,10 +2250,8 @@ inline int64_t SumSize(const vector &size_list) { Status DavinciModel::SinkModelProfile() { // profiling plugin must be registered - Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS); - - Msprof::Engine::ReporterData reporter_data{}; + auto &prof_mgr = ProfilingManager::Instance(); + ReporterData reporter_data{}; // report model data tag name std::string tag_name; tag_name.append("model_load_info_").append(std::to_string(this->Id())); @@ -2269,32 +2269,32 @@ Status DavinciModel::SinkModelProfile() { reporter_data.deviceId = device_id_; reporter_data.data = (unsigned char *)&name_len; reporter_data.dataLen = sizeof(int32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); reporter_data.data = (unsigned char *)name.c_str(); reporter_data.dataLen = name.size(); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); uint32_t model_id = this->Id(); reporter_data.data = (unsigned char *)&model_id; reporter_data.dataLen = sizeof(uint32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); // Load Start/End Time int64_t start_time = this->GetLoadBeginTime(); reporter_data.data = (unsigned char *)&start_time; reporter_data.dataLen = sizeof(int64_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); int64_t end_time = this->GetLoadEndTime(); reporter_data.data = (unsigned char *)&end_time; reporter_data.dataLen = sizeof(int64_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); int32_t task_num = task_list_.size(); std::multimap op_id_map; @@ -2308,6 +2308,7 @@ Status DavinciModel::SinkModelProfile() { uint32_t op_num = fusion_op_info->original_op_names.size(); uint32_t task_id = task->GetTaskID(); if (op_num > 0) { + GELOGI("task.id = %u, opNum = %u", task_id, op_num); op_id_map.insert(std::make_pair(fusion_op_info->op_index, task_id)); } } @@ -2350,39 +2351,39 @@ Status DavinciModel::SinkModelProfile() { int32_t fusion_op_name_len = fusion_op_name.size() == 0 ? 1 : fusion_op_name.size(); reporter_data.data = (unsigned char *)&fusion_op_name_len; reporter_data.dataLen = sizeof(int32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); reporter_data.data = (unsigned char *)fusion_op_name.c_str(); reporter_data.dataLen = fusion_op_name_len; - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); // original op name before fusion reporter_data.data = (unsigned char *)&op_num; reporter_data.dataLen = sizeof(int32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); for (uint32_t k = 0; k < op_num; k++) { std::string op_name = fusion_op_info->original_op_names[k]; int32_t op_name_len = op_name.size() == 0 ? 1 : op_name.size(); reporter_data.data = (unsigned char *)&op_name_len; reporter_data.dataLen = sizeof(int32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); reporter_data.data = (unsigned char *)op_name.c_str(); reporter_data.dataLen = op_name_len; - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); } // stream id info uint32_t streamId = task->GetStreamId(); reporter_data.data = (unsigned char *)&streamId; reporter_data.dataLen = sizeof(int32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); // memory info struct memoryInfo memory_info; @@ -2398,22 +2399,22 @@ Status DavinciModel::SinkModelProfile() { memory_info.weight_size + memory_info.input_size + memory_info.output_size + memory_info.workspace_size; reporter_data.data = (unsigned char *)&memory_info; reporter_data.dataLen = sizeof(struct memoryInfo); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); // task info reporter_data.data = (unsigned char *)&task_count; reporter_data.dataLen = sizeof(uint32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); Range task_range = op_id_map.equal_range(op_id); for (CIT idx = task_range.first; idx != task_range.second; ++idx) { uint32_t task_id = idx->second; reporter_data.data = (unsigned char *)&task_id; reporter_data.dataLen = sizeof(uint32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); } } } @@ -2422,10 +2423,8 @@ Status DavinciModel::SinkModelProfile() { Status DavinciModel::SinkTimeProfile(const InputData ¤t_data) { // profiling plugin must be registered - Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS); - - Msprof::Engine::ReporterData reporter_data{}; + auto &prof_mgr = ProfilingManager::Instance(); + ReporterData reporter_data{}; // report model data tag name std::string tag_name; tag_name.append("model_time_info_") @@ -2448,33 +2447,33 @@ Status DavinciModel::SinkTimeProfile(const InputData ¤t_data) { size_t name_len = name.size(); reporter_data.data = (unsigned char *)&name_len; reporter_data.dataLen = sizeof(int32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); reporter_data.data = (unsigned char *)name.c_str(); reporter_data.dataLen = name.size(); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - this->Id()); + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, + "Reporter data fail, model id:%u.", this->Id()); // request id uint64_t request_id = current_data.request_id; reporter_data.data = (unsigned char *)&request_id; reporter_data.dataLen = sizeof(uint32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); // thread id int32_t thread_id = GetDataInputTid(); reporter_data.data = (unsigned char *)&thread_id; reporter_data.dataLen = sizeof(int32_t); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); // time info time_info_.modelId = this->Id(); reporter_data.data = (unsigned char *)&time_info_; reporter_data.dataLen = sizeof(struct timeInfo); - GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, + GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, "Reporter data fail, model id:%u, data index:%u.", this->Id(), current_data.index); return SUCCESS; @@ -2696,8 +2695,9 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b is_getnext_sink_dynamic_ = true; cur_dynamic_dims_.clear(); cur_dynamic_dims_.resize(shape_of_cur_dynamic_dims_); - GE_CHK_RT_RET(rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int64_t), - netoutput_last_input_addr_, netoutput_last_input_size_, RT_MEMCPY_DEVICE_TO_HOST)); + auto ret = rtMemcpy(cur_dynamic_dims_.data(), shape_of_cur_dynamic_dims_ * sizeof(int64_t), + netoutput_last_input_addr_, netoutput_last_input_size_, RT_MEMCPY_DEVICE_TO_HOST); + GE_CHK_RT_RET(ret); } GELOGD("Cur dynamic dims is %s.", formats::JoinToString(cur_dynamic_dims_).c_str()); if (GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { @@ -2801,76 +2801,42 @@ void *DavinciModel::Run(DavinciModel *model) { reinterpret_cast(shape_data_buffer_data) + shape_data_buffer_length / sizeof(int64_t)); GELOGD("Data: cur dynamic dims is %s", formats::JoinToString(model->cur_dynamic_dims_).c_str()); - delete[] (int64_t *)current_data.blobs.back().data; + delete[] reinterpret_cast(current_data.blobs.back().data); current_data.blobs.pop_back(); } GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_PRE_PROC_END)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_START)); - if (ProfilingManager::Instance().ProfilingOpTraceOn()) { - GELOGI("GetOpTraceIterNum:%d", ProfilingManager::Instance().GetOpTraceIterNum()); - for (int32_t i = 0; i < ProfilingManager::Instance().GetOpTraceIterNum(); i++) { - if (!ProfilingManager::Instance().ProfilingLoadFlag()) { - vector prof_device_id_vec = ProfilingManager::Instance().GetProfilingDeviceId(); - for (size_t j = 0; j < prof_device_id_vec.size(); ++j) { - // just profiling, no need to check value - (void)ProfilingManager::Instance().StartProfiling(i, prof_device_id_vec[j]); - } - } - - GELOGI("rtModelExecute start."); - rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; - (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); - continue); // [No need to check value] - GELOGI("rtModelExecute end"); - - GELOGI("rtStreamSynchronize start."); - rt_ret = rtStreamSynchronize(model->rt_model_stream_); - if (rt_ret == kModelAbortNormal || rt_ret == kModelAbortNormalNew) { - GELOGI("The model with multiple datasets aborts normally."); - } else { - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; - (void)model->ReturnResult(current_data.index, false, seq_end_flag, data_wrapper->GetOutput()); - continue); // [No need to check value] - } - - GELOGI("rtStreamSynchronize end."); - (void)ProfilingManager::Instance().StopProfiling(); // just profiling, no need to check value - } + GE_TIMESTAMP_START(rtModelExecute); + GELOGI("rtModelExecute start."); + rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; + (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); + CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); + continue); + GELOGI("rtModelExecute end"); + GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(rtModelExecute, "GraphExcute::rtModelExecute")); + + GE_TIMESTAMP_START(rtStreamSynchronize); + GELOGI("rtStreamSynchronize start."); + rt_ret = rtStreamSynchronize(model->rt_model_stream_); + if (rt_ret == kEndOfSequence || rt_ret == kEndOfSequenceNew) { + seq_end_flag = true; + } + if (rt_ret == kModelAbortNormal || rt_ret == kModelAbortNormalNew) { + GELOGI("The model with multiple datasets aborts normally."); } else { - GE_TIMESTAMP_START(rtModelExecute); - GELOGI("rtModelExecute start."); - rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; - (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); - CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); - continue); - GELOGI("rtModelExecute end"); - GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(rtModelExecute, "GraphExcute::rtModelExecute")); - - GE_TIMESTAMP_START(rtStreamSynchronize); - GELOGI("rtStreamSynchronize start."); - rt_ret = rtStreamSynchronize(model->rt_model_stream_); - if (rt_ret == kEndOfSequence || rt_ret == kEndOfSequenceNew) { - seq_end_flag = true; - } - if (rt_ret == kModelAbortNormal || rt_ret == kModelAbortNormalNew) { - GELOGI("The model with multiple datasets aborts normally."); - } else { - GE_IF_BOOL_EXEC( - rt_ret != RT_ERROR_NONE, rslt_flg = false; GELOGI("seq_end_flg: %d", seq_end_flag); - (void)model->ReturnResult(current_data.index, false, seq_end_flag, - data_wrapper->GetOutput()); // [No need to check value] - CsaInteract::GetInstance().StoreInternalErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); - continue); - } - - GELOGI("rtStreamSynchronize end."); - GE_IF_BOOL_EXEC(model->is_first_execute_, - GE_TIMESTAMP_EVENT_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize")); - GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_END)); + GE_IF_BOOL_EXEC( + rt_ret != RT_ERROR_NONE, rslt_flg = false; GELOGI("seq_end_flg: %d", seq_end_flag); + (void)model->ReturnResult(current_data.index, false, seq_end_flag, + data_wrapper->GetOutput()); // [No need to check value] + CsaInteract::GetInstance().StoreInternalErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); + continue); } + GELOGI("rtStreamSynchronize end."); + GE_IF_BOOL_EXEC(model->is_first_execute_, + GE_TIMESTAMP_EVENT_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize")); + GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_END)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_AFTER_PROC_START)); GE_TIMESTAMP_START(ReturnResult3); @@ -3170,21 +3136,29 @@ Status DavinciModel::DistributeTask() { const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { + auto &task_def = model_task_def->task(task_index); auto &task = task_list_.at(task_index); GE_CHK_STATUS_RET(task->Distribute(), "Task[%zu] distribute fail", task_index); // for data dump - auto op_index = std::max(model_task_def->task(task_index).kernel().context().op_index(), - model_task_def->task(task_index).kernel_ex().op_index()); + auto op_index = std::max(task_def.kernel().context().op_index(), + task_def.kernel_ex().op_index()); OpDescPtr op = GetOpByIndex(op_index); GE_CHECK_NOTNULL(op); - SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId()); if (reinterpret_cast(task->GetDumpArgs()) != nullptr) { bool call_dump = GetDumpProperties().IsLayerNeedDump(name_, om_name_, op->GetName()) && task->CallSaveDumpInfo(); if (call_dump || is_op_debug_reg_) { SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); } } + + auto task_type = static_cast(task_def.type()); + bool no_need_profiling = (task_type != RT_MODEL_TASK_KERNEL) + && (task_type != RT_MODEL_TASK_KERNEL_EX) + && (task_type != RT_MODEL_TASK_HCCL); + GE_IF_BOOL_EXEC(no_need_profiling, continue); + + SaveDumpOpInfo(runtime_param_, op, task->GetTaskID(), task->GetStreamId()); // Load task info for profiling TaskDescInfo task_desc_info; if (!om_name_.empty()) { @@ -3193,7 +3167,7 @@ Status DavinciModel::DistributeTask() { task_desc_info.model_name = name_; } task_desc_info.op_name = op->GetName(); - task_desc_info.block_dim = model_task_def->task(task_index).kernel().block_dim(); + task_desc_info.block_dim = task_def.kernel().block_dim(); task_desc_info.task_id = task->GetTaskID(); task_desc_info.stream_id = task->GetStreamId(); task_desc_info_.emplace_back(task_desc_info); @@ -3391,14 +3365,14 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 /// Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic) { if (UpdateIoTaskArgs(new_input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { - GELOGE(PARAM_INVALID, "[ZCPY] Update input data to model failed."); - return PARAM_INVALID; + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[ZCPY] Update input data to model failed."); + return ACL_ERROR_GE_PARAM_INVALID; } if (UpdateIoTaskArgs(new_output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { - GELOGE(PARAM_INVALID, "[ZCPY] Update output data to model failed."); - return PARAM_INVALID; + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[ZCPY] Update output data to model failed."); + return ACL_ERROR_GE_PARAM_INVALID; } for (ZeroCopyTask &task : zero_copy_tasks_) { @@ -3444,7 +3418,7 @@ Status DavinciModel::UpdateIoTaskArgs(const std::map & } if (!CheckInputAndModelSize(buffer.length, data.second.GetDataSize(), is_dynamic)) { - GELOGE(FAILED, "Check input size and model size failed"); + GELOGE(FAILED, "Check input size and model size failed, op[%s]", data.second.GetOpName().c_str()); return FAILED; } @@ -3861,7 +3835,8 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa if (!is_async_mode_) { GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_START)); ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Copy Output data to user failed."); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ACL_ERROR_GE_INTERNAL_ERROR, + "Copy Output data to user failed."); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_END)); } @@ -4061,7 +4036,7 @@ void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &compute_graph) { data_dumper_.SetDeviceId(device_id); // set loop count addr - auto get_var_addr = [](const OpDescPtr &op, const RuntimeParam &runtime_param) -> void * { + auto get_var_addr = [](const OpDescPtr &op, const RuntimeParam &runtime_param) -> void *{ if (op != nullptr) { auto v_output_size = ModelUtils::GetOutputSize(op); auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param, op); diff --git a/ge/graph/load/new_model_manager/davinci_model.h b/ge/graph/load/new_model_manager/davinci_model.h old mode 100755 new mode 100644 index 650f19eb..19888e1f --- a/ge/graph/load/new_model_manager/davinci_model.h +++ b/ge/graph/load/new_model_manager/davinci_model.h @@ -440,7 +440,7 @@ class DavinciModel { Status SinkTimeProfile(const InputData ¤t_data); - Status ReportProfilingData(bool check_device = true); + Status ReportProfilingData(); void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) { data_dumper_.SaveDumpOpInfo(model_param, op, task_id, stream_id); diff --git a/ge/graph/load/new_model_manager/davinci_model_parser.h b/ge/graph/load/new_model_manager/davinci_model_parser.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/model_manager.cc b/ge/graph/load/new_model_manager/model_manager.cc old mode 100755 new mode 100644 index d6cdf42d..4c2d4530 --- a/ge/graph/load/new_model_manager/model_manager.cc +++ b/ge/graph/load/new_model_manager/model_manager.cc @@ -40,9 +40,7 @@ const int kCmdParSize = 2; const int kDumpCmdPairSize = 2; const std::size_t kProfCmdParaMaxSize = 1000; const std::size_t kProfStartCmdParaSize = 2; -const std::string kCmdTypeProfile = "profile"; const std::string kCmdTypeDump = "dump"; -const std::string kCmdTypeProfiling = "profiling"; const std::string kCmdTypeProfInit = "prof_init"; const std::string kCmdTypeProfFinalize = "prof_finalize"; const std::string kCmdTypeProfStart = "prof_start"; @@ -51,6 +49,9 @@ const std::string kCmdTypeProfModelSubscribe = "prof_model_subscribe"; const std::string kCmdTypeProfModelUnsubscribe = "prof_model_cancel_subscribe"; const char *const kBatchLoadBuf = "batchLoadsoFrombuf"; const char *const kDeleteCustOp = "deleteCustOp"; +const int kTimeSpecNano = 1000000000; +const int kTimeSpecMiro = 1000000; +const int kSessionMaxBias = 100; struct CustAicpuSoBuf { uint64_t kernelSoBuf; uint32_t kernelSoBufLen; @@ -224,7 +225,7 @@ ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id); - std::lock_guard lock(sess_ids_mutex_); + std::lock_guard lock(map_mutex_); std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_KERNEL_DESTROY, session_id, model_id); @@ -237,7 +238,7 @@ ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_ } ge::Status ModelManager::CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id) { - std::lock_guard lock(sess_ids_mutex_); + std::lock_guard lock(map_mutex_); std::vector v_aicpu_kernel; std::string model_key = std::to_string(session_id) + "_" + std::to_string(model_id); if (model_aicpu_kernel_.find(model_key) != model_aicpu_kernel_.end()) { @@ -345,7 +346,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrSetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + + davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetProfileTime(MODEL_LOAD_END); } while (0); @@ -629,8 +630,7 @@ Status ModelManager::Stop(uint32_t model_id) { /// Status ModelManager::HandleCommand(const Command &command) { static const std::map> cmds = { - {kCmdTypeProfile, HandleProfileCommand}, {kCmdTypeDump, HandleDumpCommand}, - {kCmdTypeProfiling, HandleAclProfilingCommand}, {kCmdTypeProfInit, HandleProfInitCommand}, + {kCmdTypeDump, HandleDumpCommand}, {kCmdTypeProfInit, HandleProfInitCommand}, {kCmdTypeProfFinalize, HandleProfFinalizeCommand}, {kCmdTypeProfStart, HandleProfStartCommand}, {kCmdTypeProfStop, HandleProfStopCommand}, {kCmdTypeProfModelSubscribe, HandleProfModelSubscribeCommand}, @@ -645,21 +645,6 @@ Status ModelManager::HandleCommand(const Command &command) { } } -Status ModelManager::HandleAclProfilingCommand(const Command &command) { - if (command.cmd_params.size() < kCmdParSize) { - GELOGE(PARAM_INVALID, "When the cmd_type is 'profiling', the size of cmd_params must larger than 2."); - return PARAM_INVALID; - } - - std::string map_key = command.cmd_params[0]; - std::string value = command.cmd_params[1]; - if (map_key == PROFILE_CONFIG) { - ProfilingManager::Instance().SetProfilingConfig(value); - } - - return SUCCESS; -} - Status ModelManager::GetModelByCmd(const Command &command, std::shared_ptr &davinci_model) { if (command.cmd_params.size() < kCmdParSize) { @@ -806,29 +791,6 @@ Status ModelManager::HandleProfStopCommand(const Command &command) { return SUCCESS; } -Status ModelManager::HandleProfileCommand(const Command &command) { - if (command.cmd_params.size() < kCmdParSize) { - GELOGE(PARAM_INVALID, "When the cmd_type is 'profile', the size of cmd_params must larger than 2."); - return PARAM_INVALID; - } - - std::string map_key = command.cmd_params[0]; - std::string value = command.cmd_params[1]; - - GELOGI("Profiling mode, Command key:%s , value:%s ", map_key.c_str(), value.c_str()); - - auto iter = PROFILE_COMPONENT_MAP.find(map_key); - if (iter != PROFILE_COMPONENT_MAP.end()) { - std::string property_value = (value == "on") ? "1" : "0"; - PropertiesManager::Instance().SetPropertyValue(iter->second, property_value); - } - - if ((map_key == PROFILER_JOBCTX || map_key == PROFILER_TARGET_PATH || map_key == RTS_PROFILE_PATH)) { - PropertiesManager::Instance().SetPropertyValue(map_key, value); - } - return SUCCESS; -} - static Status ParserPara(const Command &command, const string &dump_key, string &dump_value) { auto iter = std::find(command.cmd_params.begin(), command.cmd_params.end(), dump_key); if (iter != command.cmd_params.end()) { @@ -1072,12 +1034,12 @@ Status ModelManager::GenSessionId(uint64_t &session_id) { GELOGE(INTERNAL_ERROR, "Failed to get current time."); return INTERNAL_ERROR; } - session_id = static_cast(tv.tv_sec * 1000000 + tv.tv_usec); // 1000000us + session_id = static_cast(tv.tv_sec * kTimeSpecMiro + tv.tv_usec); // 1000000us session_id_bias_++; // max bais 100. - session_id_bias_ = session_id_bias_ % 100; - session_id = session_id * 100 + session_id_bias_; + session_id_bias_ = session_id_bias_ % kSessionMaxBias; + session_id = session_id * kSessionMaxBias + session_id_bias_; GELOGD("Generate new session id: %lu.", session_id); return SUCCESS; @@ -1086,8 +1048,7 @@ Status ModelManager::GenSessionId(uint64_t &session_id) { Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model, shared_ptr listener, void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { GE_CHK_BOOL_RET_STATUS(model.key.empty() || mmAccess2(model.key.c_str(), M_F_OK) == EN_OK, - ACL_ERROR_GE_PARAM_INVALID, - "input key file path %s is invalid, %s", model.key.c_str(), strerror(errno)); + ACL_ERROR_GE_PARAM_INVALID, "input key file path %s is invalid, %s", model.key.c_str(), strerror(errno)); GenModelId(&model_id); shared_ptr davinci_model = nullptr; @@ -1148,7 +1109,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model GELOGI("Parse model %u success.", model_id); - davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + + davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetProfileTime(MODEL_LOAD_END); @@ -1252,7 +1213,8 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy } std::shared_ptr davinci_model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid model id %u.", model_id); + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, + "Invalid model id %u, check weather model has been loaded or not.", model_id); if (davinci_model->NeedDestroyAicpuKernel()) { GELOGI("Start to destroy specified aicpu kernel."); @@ -1289,13 +1251,13 @@ Status ModelManager::CreateAicpuSession(uint64_t session_id) { return SUCCESS; } -Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name) { - GELOGI("LoadCustAicpuSo in, op name %s, so name %s", op_desc->GetName().c_str(), so_name.c_str()); +Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name, bool &loaded) { + GELOGD("LoadCustAicpuSo in, op name %s, so name %s", op_desc->GetName().c_str(), so_name.c_str()); std::lock_guard lock(cust_aicpu_mutex_); CustAICPUKernelPtr aicpu_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_CUSTAICPU_KERNEL, CustAICPUKernelPtr()); if (aicpu_kernel == nullptr) { - GELOGE(INTERNAL_ERROR, "cust aicpu op %s can't find kernel!", op_desc->GetName().c_str()); - return INTERNAL_ERROR; + GELOGI("cust aicpu op %s has no corresponding kernel!", op_desc->GetName().c_str()); + return SUCCESS; } // get current context @@ -1313,18 +1275,24 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_ std::map new_so_name; new_so_name.insert({so_name, aicpu_kernel}); cust_aicpu_so_[resource_id] = new_so_name; - GELOGI("LoadCustAicpuSo new aicpu so resource id %lu", resource_id); + loaded = false; + GELOGD("LoadCustAicpuSo new aicpu so name %s, resource id %lu", so_name.c_str(), resource_id); return SUCCESS; } auto it_so_name = it->second.find(so_name); if (it_so_name == it->second.end()) { it->second.insert({so_name, aicpu_kernel}); - GELOGI("LoadCustAicpuSo add aicpu so resource id %lu", resource_id); + loaded = false; + GELOGD("LoadCustAicpuSo add aicpu so name %s, resource id %lu", so_name.c_str(), resource_id); + return SUCCESS; } + loaded = true; + GELOGD("LoadCustAicpuSo so name %s has been loaded.", so_name.c_str()); return SUCCESS; } Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) { + GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str()); std::lock_guard lock(cust_aicpu_mutex_); if (cust_aicpu_so_.size() == 0) return SUCCESS; // get current context diff --git a/ge/graph/load/new_model_manager/model_manager.h b/ge/graph/load/new_model_manager/model_manager.h old mode 100755 new mode 100644 index e3780d5b..fc98d9c2 --- a/ge/graph/load/new_model_manager/model_manager.h +++ b/ge/graph/load/new_model_manager/model_manager.h @@ -169,8 +169,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// @brief comment handle function /// ge::Status HandleCommand(const Command &command); - static ge::Status HandleAclProfilingCommand(const Command &command); - static ge::Status HandleProfileCommand(const Command &command); static ge::Status HandleDumpCommand(const Command &command); static ge::Status HandleProfModelSubscribeCommand(const Command &command); static ge::Status HandleProfModelUnsubscribeCommand(const Command &command); @@ -289,7 +287,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status DestroyAicpuSessionForInfer(uint32_t model_id); - ge::Status LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name); + ge::Status LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name, bool &loaded); ge::Status LaunchCustAicpuSo(); diff --git a/ge/graph/load/new_model_manager/model_utils.cc b/ge/graph/load/new_model_manager/model_utils.cc old mode 100755 new mode 100644 index 34fb7ff3..22a657ad --- a/ge/graph/load/new_model_manager/model_utils.cc +++ b/ge/graph/load/new_model_manager/model_utils.cc @@ -61,7 +61,7 @@ vector ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); continue); - GELOGI("[IMAS]GetInputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); + GELOGI("GetInputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); v_input_size.push_back(tensor_size); } @@ -96,7 +96,7 @@ vector ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) { GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i); continue); - GELOGI("[IMAS]GetOutputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); + GELOGI("GetOutputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); v_output_size.push_back(tensor_size); } diff --git a/ge/graph/load/new_model_manager/model_utils.h b/ge/graph/load/new_model_manager/model_utils.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc b/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/event_record_task_info.h b/ge/graph/load/new_model_manager/task_info/event_record_task_info.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc b/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/event_wait_task_info.h b/ge/graph/load/new_model_manager/task_info/event_wait_task_info.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc b/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.h b/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc b/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.h b/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc index b09a4fce..4fb64aab 100644 --- a/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc +++ b/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc @@ -279,9 +279,10 @@ Status HcclTaskInfo::SetAddrs(const std::shared_ptr &op_desc, output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i]; } kernel_hccl_infos[i].inputDataAddr = input_data_addr; - if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER || hccl_type == HCOMREDUCE) { + if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) { kernel_hccl_infos[i].outputDataAddr = output_data_addr; - } else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) { + } else if (hccl_type == HCOMALLREDUCE || + hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE || hccl_type == HCOMREDUCE) { GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), "davinci_model: GetHcomOperationType fail!"); kernel_hccl_infos[i].outputDataAddr = output_data_addr; diff --git a/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc b/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc old mode 100755 new mode 100644 index 04607c02..74faeb24 --- a/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc +++ b/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc @@ -43,6 +43,13 @@ const char *kIsLastNode = "is_last_node"; const char *kIsFirstNode = "is_first_node"; const int64_t kCloseSkt = 100; const uint32_t kAddrLen = sizeof(void *); +const int kBaseInt = 10; +const int kStrtolFail = 0; +const int kArgsInputDesc = 0; +const int kArgsInputAddr = 1; +const int kArgsOutputDesc = 2; +const int kArgsOutputAddr = 3; +const int kArgsAttrHandle = 4; } // namespace namespace ge { @@ -66,7 +73,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci // get opcontext stored in model const domi::KernelContext &context = kernel_def.context(); // get kernel_type - kernel_type_ = static_cast(context.kernel_type()); + kernel_type_ = static_cast(context.kernel_type()); // get opdesc op_desc_ = davinci_model_->GetOpByIndex(context.op_index()); GE_CHECK_NOTNULL(op_desc_); @@ -88,13 +95,13 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci // get bin_file_key const char *bin_file_key = davinci_model_->GetRegisterStub(op_desc_->GetName(), session_graph_model_id); // new aicpu kernel(rtCpuKernelLaunch) no need to check function - if (kernel_type_ == cce::ccKernelType::CCE_AI_CORE) { + if (kernel_type_ == ccKernelType::CCE_AI_CORE) { rtError_t rt_ret; rt_ret = rtGetFunctionByName(const_cast(kernel_def.stub_func().c_str()), &stub_func_); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s", kernel_def.stub_func().c_str()); return RT_ERROR_TO_GE_STATUS(rt_ret);); - } else if (kernel_type_ == cce::ccKernelType::TE) { + } else if (kernel_type_ == ccKernelType::TE) { rtError_t rt_ret; rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, @@ -111,7 +118,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci ctx_.opIndex2[i] = context.origin_op_index(i); } ctx_.opCount = context.origin_op_index_size(); - if (kernel_type_ == cce::ccKernelType::TE) { + if (kernel_type_ == ccKernelType::TE) { ctx_.opIndex = context.op_index(); uint16_t *args_offset_tmp = reinterpret_cast(const_cast(context.args_offset().data())); if (context.args_offset().size() / sizeof(uint16_t) < 1) { @@ -120,9 +127,9 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci } ret = InitTVMTask(args_offset_tmp[0], kernel_def); - } else if (kernel_type_ == cce::ccKernelType::CUSTOMIZED) { + } else if (kernel_type_ == ccKernelType::CUSTOMIZED) { ret = InitAICPUCustomTask(context.op_index(), kernel_def); - } else if (kernel_type_ == cce::ccKernelType::AI_CPU || kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { + } else if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) { ret = InitAicpuTask(context.op_index(), kernel_def); } else { if (kernel_def.args().empty() || args_size_ == 0) { @@ -371,9 +378,9 @@ Status KernelTaskInfo::Distribute() { rtError_t rt_ret = RT_ERROR_NONE; char skt_enable_env[MMPA_MAX_PATH] = { 0x00 }; INT32 res = mmGetEnv("SKT_ENABLE", skt_enable_env, MMPA_MAX_PATH); - int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, 10) : 0; + int64_t env_flag = (res == EN_OK) ? strtol(skt_enable_env, nullptr, kBaseInt) : kStrtolFail; bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); - if (kernel_type_ == cce::ccKernelType::AI_CPU || kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { + if (kernel_type_ == ccKernelType::AI_CPU || kernel_type_ == ccKernelType::CUST_AI_CPU) { GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_); // blockDim is reserved parameter, set to 1 rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name_.c_str()), @@ -749,15 +756,15 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel return FAILED; } } - *(reinterpret_cast(args + ctx_.argsOffset[0])) = + *(reinterpret_cast(args + ctx_.argsOffset[kArgsInputDesc])) = static_cast(reinterpret_cast(custom_info_.input_descs)); // arg 0 - *(reinterpret_cast(args + ctx_.argsOffset[1])) = + *(reinterpret_cast(args + ctx_.argsOffset[kArgsInputAddr])) = static_cast(reinterpret_cast(custom_info_.input_addrs)); // arg 1 - *(reinterpret_cast(args + ctx_.argsOffset[2])) = + *(reinterpret_cast(args + ctx_.argsOffset[kArgsOutputDesc])) = static_cast(reinterpret_cast(custom_info_.output_descs)); // arg 2 - *(reinterpret_cast(args + ctx_.argsOffset[3])) = + *(reinterpret_cast(args + ctx_.argsOffset[kArgsOutputAddr])) = static_cast(reinterpret_cast(custom_info_.output_addrs)); // arg 3 - *(reinterpret_cast(args + ctx_.argsOffset[4])) = + *(reinterpret_cast(args + ctx_.argsOffset[kArgsAttrHandle])) = static_cast(reinterpret_cast(custom_info_.attr_handle)); // arg 4 rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); @@ -874,8 +881,10 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k return INTERNAL_ERROR; } - if (kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { - GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name_), "launch cust aicpu so failed"); + if (kernel_type_ == ccKernelType::CUST_AI_CPU) { + bool loaded = false; + GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name_, loaded), + "launch cust aicpu so failed"); } // copy args to new host memory @@ -946,7 +955,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k GELOGI("Op debug is open in aicpu task info"); dump_args_ = static_cast(args_) + sizeof(aicpu::AicpuParamHead); } - if (kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { + if (kernel_type_ == ccKernelType::CUST_AI_CPU) { dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; } @@ -1076,7 +1085,7 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector &input_d Status KernelTaskInfo::SetContext(const domi::KernelDef &kernel_def) { const domi::KernelContext &context = kernel_def.context(); - ctx_.kernelType = static_cast(context.kernel_type()); + ctx_.kernelType = static_cast(context.kernel_type()); ctx_.opId = context.op_id(); ctx_.kernelFuncId = context.kernel_func_id(); ctx_.isFlowtable = context.is_flowtable(); @@ -1161,10 +1170,10 @@ Status KernelTaskInfo::CceUpdateKernelArgs(const domi::KernelContext &context, u GELOGE(GE_PLGMGR_SO_NOT_EXIST, "Failed in dlopen %s! ", error); return FAILED; } - cce::ccStatus_t cc_ret; + ccStatus_t cc_ret; std::string update_kernel_args = "ccUpdateKernelArgs"; - auto cceUpdateKernelArgs = (cce::ccStatus_t(*)(cce::ccOpContext &, uint64_t, uint64_t, uint64_t, void *, uint64_t, - void *))mmDlsym(handle, const_cast(update_kernel_args.c_str())); + auto cceUpdateKernelArgs = (ccStatus_t(*)(ccOpContext &, uint64_t, uint64_t, + uint64_t, void *, uint64_t, void *))mmDlsym(handle, const_cast(update_kernel_args.c_str())); if (cceUpdateKernelArgs == nullptr) { GELOGE(FAILED, "Failed to invoke function ccUpdateKernelArgs"); if (mmDlclose(handle) != 0) { @@ -1189,7 +1198,7 @@ Status KernelTaskInfo::CceUpdateKernelArgs(const domi::KernelContext &context, u GELOGW("Failed to close handle %s", error); return FAILED; } - if (cc_ret != cce::CC_STATUS_SUCCESS) { + if (cc_ret != CC_STATUS_SUCCESS) { GELOGE(CCE_FAILED, "Call cce api failed, ret: 0x%X", cc_ret); return CCE_FAILED; } diff --git a/ge/graph/load/new_model_manager/task_info/kernel_task_info.h b/ge/graph/load/new_model_manager/task_info/kernel_task_info.h index f2945b0b..1f90ede1 100644 --- a/ge/graph/load/new_model_manager/task_info/kernel_task_info.h +++ b/ge/graph/load/new_model_manager/task_info/kernel_task_info.h @@ -43,7 +43,7 @@ class KernelTaskInfo : public TaskInfo { stream_id_(0), so_name_(""), kernel_name_(""), - kernel_type_(cce::ccKernelType::CCE_AI_CORE), + kernel_type_(ccKernelType::CCE_AI_CORE), dump_flag_(RT_KERNEL_DEFAULT), dump_args_(nullptr), op_desc_(nullptr), @@ -75,7 +75,7 @@ class KernelTaskInfo : public TaskInfo { Status Release() override; - cce::ccOpContext *GetCtx() override { return &ctx_; } + ccOpContext *GetCtx() override { return &ctx_; } FusionOpInfo *GetFusionOpInfo() override { return &fusion_op_info_; } @@ -92,7 +92,7 @@ class KernelTaskInfo : public TaskInfo { bool CallSaveDumpInfo() override { return call_save_dump_; }; - cce::ccOpContext ctx_; + ccOpContext ctx_; FusionOpInfo fusion_op_info_; private: @@ -153,7 +153,7 @@ class KernelTaskInfo : public TaskInfo { uint32_t stream_id_; std::string so_name_; std::string kernel_name_; - cce::ccKernelType kernel_type_; + ccKernelType kernel_type_; uint32_t dump_flag_; void *dump_args_; OpDescPtr op_desc_; diff --git a/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc b/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h b/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc b/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc b/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h b/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc b/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.h b/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc b/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/stream_active_task_info.h b/ge/graph/load/new_model_manager/task_info/stream_active_task_info.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h b/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h old mode 100755 new mode 100644 index 89642cf8..a72d7de2 --- a/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h +++ b/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h @@ -41,7 +41,7 @@ class StreamSwitchTaskInfo : public TaskInfo { Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; private: - void SetInputAndValuePtr(DavinciModel *davinci_model, const vector &input_data_addrs); + void SetInputAndValuePtr(DavinciModel *davinci_model, const std::vector &input_data_addrs); void *input_ptr_; rtCondition_t cond_; void *value_ptr_; @@ -49,7 +49,7 @@ class StreamSwitchTaskInfo : public TaskInfo { uint32_t true_stream_id_; rtSwitchDataType_t data_type_; static const uint32_t kInputNum = 2; - vector fixed_addr_offset_; + std::vector fixed_addr_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCH_TASK_INFO_H_ diff --git a/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc b/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h b/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc b/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc index 63f29f84..65dca3b3 100644 --- a/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc +++ b/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc @@ -25,10 +25,11 @@ Status SuperKernel::Launch(rtStream_t stream, uint32_t dump_flag) { const void *args[] = {this->GetNavTablePtr(), reinterpret_cast(static_cast(this->GetNavTableSize()))}; - rtError_t rt_ret = rtMalloc((void **)&(device_args_addr_), sizeof(args), RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failied. error: 0x%X", rt_ret); return - RT_ERROR_TO_GE_STATUS(rt_ret);) - rt_ret = rtMemcpy((void *)device_args_addr_, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); + rtError_t rt_ret = rtMalloc(reinterpret_cast(&device_args_addr_), sizeof(args), RT_MEMORY_HBM); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failied. error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) + rt_ret = rtMemcpy(reinterpret_cast(device_args_addr_), sizeof(args), reinterpret_cast(args), + sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failied. error: 0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret);) rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr_, sizeof(args), NULL, stream, diff --git a/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc b/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc index 69f7b159..4e22cd7c 100644 --- a/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc +++ b/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc @@ -19,6 +19,8 @@ namespace ge { namespace skt { +const size_t kFusedKernelMinimumSize = 2; +const size_t kFusedKernelSizeUnit = 2; SuperKernelFactory &SuperKernelFactory::GetInstance() { static SuperKernelFactory factory; return factory; @@ -79,17 +81,17 @@ Status SuperKernelFactory::FuseKernels(const std::vector &stub_func_list return FAILED; } - if (super_kernel_size < 2) { + if (super_kernel_size < kFusedKernelMinimumSize) { GELOGW( "SKT: the number of kernels being fused must be greater than or " "equal to 2"); return FAILED; } GELOGI("SKT: superkernel start fuse, superkernel size %zu.", stub_func_list.size()); - const size_t nav_table_len = 2 * stub_func_list.size(); + const size_t nav_table_len = kFusedKernelSizeUnit * stub_func_list.size(); std::unique_ptr nav_table(new(std::nothrow) uint64_t[nav_table_len]); GE_CHECK_NOTNULL(nav_table); - uint64_t nav_table_size = 2 * stub_func_list.size() * sizeof(int64_t); + uint64_t nav_table_size = kFusedKernelSizeUnit * stub_func_list.size() * sizeof(int64_t); rtError_t rt_ret; void *hbm_nav_table_addr = nullptr; @@ -101,21 +103,21 @@ Status SuperKernelFactory::FuseKernels(const std::vector &stub_func_list GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); // store two uint64_t address // address divided by 4 because of 32bits encoding, call offset will *4 when calculating - nav_table[i * 2] = static_cast(reinterpret_cast(sub_device_func)) / 4; - GELOGD("SKT: CALL offet %lu", nav_table[i * 2]); - nav_table[i * 2 + 1] = static_cast(reinterpret_cast(args_addr_list[i])); - GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); + nav_table[i * kFusedKernelSizeUnit] = static_cast(reinterpret_cast(sub_device_func)) / 4; + GELOGD("SKT: CALL offet %lu", nav_table[i * kFusedKernelSizeUnit]); + nav_table[i * kFusedKernelSizeUnit + 1] = static_cast(reinterpret_cast(args_addr_list[i])); + GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * kFusedKernelSizeUnit + 1]); } - rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); + rt_ret = rtMalloc(reinterpret_cast(&hbm_nav_table_addr), nav_table_size, RT_MEMORY_HBM); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failed. error: 0x%X", rt_ret); return RT_ERROR_TO_GE_STATUS(rt_ret);) - rt_ret = - rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table.get(), nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); + rt_ret = rtMemcpy(reinterpret_cast(hbm_nav_table_addr), nav_table_size, + reinterpret_cast(nav_table.get()), nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failed. error: 0x%X", rt_ret); GE_CHK_RT(rtFree(hbm_nav_table_addr)); return RT_ERROR_TO_GE_STATUS(rt_ret);) // Create the necessary metadata for the super kernel - h = std::unique_ptr( - new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim)); + h = + std::unique_ptr(new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim)); return SUCCESS; } } // namespace skt diff --git a/ge/graph/load/new_model_manager/task_info/task_info.cc b/ge/graph/load/new_model_manager/task_info/task_info.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/task_info/task_info.h b/ge/graph/load/new_model_manager/task_info/task_info.h index d296d29e..26f22564 100644 --- a/ge/graph/load/new_model_manager/task_info/task_info.h +++ b/ge/graph/load/new_model_manager/task_info/task_info.h @@ -20,7 +20,7 @@ #include #include "cce/customize.h" -#include "cce/taskdown_common.hpp" +#include "framework/common/taskdown_common.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/load/new_model_manager/ts_mem_mall.h" #include "graph/load/new_model_manager/task_info/task_info_factory.h" @@ -63,8 +63,8 @@ struct RuntimeParam { }; typedef struct FusionOpInfo { - vector original_op_names; - string op_name; + std::vector original_op_names; + std::string op_name; uint32_t op_index; uint32_t stream_id; } FusionOpInfo; @@ -87,7 +87,7 @@ class TaskInfo { virtual Status Release() { return SUCCESS; } - virtual cce::ccOpContext *GetCtx() { return nullptr; } + virtual ccOpContext *GetCtx() { return nullptr; } virtual uint32_t GetTaskID() { return 0xFFFFFFFF; } diff --git a/ge/graph/load/new_model_manager/tbe_handle_store.cc b/ge/graph/load/new_model_manager/tbe_handle_store.cc old mode 100755 new mode 100644 diff --git a/ge/graph/load/new_model_manager/ts_mem_mall.h b/ge/graph/load/new_model_manager/ts_mem_mall.h index 42ad3957..64a64930 100644 --- a/ge/graph/load/new_model_manager/ts_mem_mall.h +++ b/ge/graph/load/new_model_manager/ts_mem_mall.h @@ -25,7 +25,7 @@ #include "framework/common/debug/ge_log.h" namespace { -constexpr uint32_t kMaxTsMemBlock = 2 * 1024 * 1024; // Max block 2M +constexpr uint32_t kMaxTsMemBlock = 2097152; // Max block 2M 2 * 1024 * 1024 constexpr uint32_t kTsMemAligment = 64; // Malloc for 64 bits align constexpr uint32_t kTsMemAlignMask = kTsMemAligment - 1; } diff --git a/ge/graph/load/new_model_manager/zero_copy_offset.cc b/ge/graph/load/new_model_manager/zero_copy_offset.cc index 970b292c..9cd3f30b 100644 --- a/ge/graph/load/new_model_manager/zero_copy_offset.cc +++ b/ge/graph/load/new_model_manager/zero_copy_offset.cc @@ -35,6 +35,7 @@ Status ZeroCopyOffset::InitInputDataInfo(int64_t output_size, void *virtual_addr GELOGI("[ZCPY] Start to InitInputDataInfo of %s, total_data_size is %ld, virtual_addr is %p", op_desc->GetName().c_str(), output_size, virtual_addr); basic_addr_ = virtual_addr; + op_name_ = op_desc->GetName(); (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, @@ -82,6 +83,7 @@ Status ZeroCopyOffset::InitOutputDataInfo(const vector &input_size_list GELOGD("Tensor data size: GetSize=%ld, GetTensorSizeInBytes=%ld", input_size_list[idx], size); basic_addr_ = virtual_addr_list[idx]; + op_name_ = op_desc->GetName(); (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, diff --git a/ge/graph/load/new_model_manager/zero_copy_offset.h b/ge/graph/load/new_model_manager/zero_copy_offset.h index 025d1b14..fa80f28b 100644 --- a/ge/graph/load/new_model_manager/zero_copy_offset.h +++ b/ge/graph/load/new_model_manager/zero_copy_offset.h @@ -66,9 +66,12 @@ class ZeroCopyOffset { int64_t GetDataSize() const { return data_size_; } // value of *outside_addrs_ from davinci_model std::vector>> &GetOutsideAddrs() { return outside_addrs_; } + // name of op + std::string GetOpName() const { return op_name_; } private: void *basic_addr_ = nullptr; + std::string op_name_; uint32_t data_count_ = 0; std::vector> data_info_; vector relative_offset_; @@ -80,4 +83,4 @@ class ZeroCopyOffset { std::vector zero_copy_relative_offset_; }; } // namespace ge -#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_OFFSET_H_ \ No newline at end of file +#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_OFFSET_H_ diff --git a/ge/graph/load/new_model_manager/zero_copy_task.cc b/ge/graph/load/new_model_manager/zero_copy_task.cc old mode 100755 new mode 100644 index 9b42d563..2609cb4b --- a/ge/graph/load/new_model_manager/zero_copy_task.cc +++ b/ge/graph/load/new_model_manager/zero_copy_task.cc @@ -131,7 +131,7 @@ Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr, const ma auto dst_addr = static_cast(buffer_addr); GELOGI("[ZCPY] %s update task, args_addr: %p, size: %zu, offset: %zu, virtual_addr: 0x%lx, user_data_addr: %p", name_.c_str(), args_addr_, args_size_, offset, addr, buffer_addr); - *(uintptr_t *)(args_info + offset) = reinterpret_cast(dst_addr); + *reinterpret_cast(args_info + offset)= reinterpret_cast(dst_addr); is_updated_ = true; } } diff --git a/ge/graph/manager/graph_caching_allocator.cc b/ge/graph/manager/graph_caching_allocator.cc index 4ba39ca8..d6027a08 100644 --- a/ge/graph/manager/graph_caching_allocator.cc +++ b/ge/graph/manager/graph_caching_allocator.cc @@ -25,13 +25,13 @@ namespace ge { const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, - 8 * kMByteSize, - 32 * kMByteSize, - 128 * kMByteSize, + kBinSizeUnit8 * kMByteSize, + kBinSizeUnit32 * kMByteSize, + kBinSizeUnit128 * kMByteSize, kGByteSize, - 4 * kGByteSize, - 16 * kGByteSize, - 26 * kGByteSize}; + kBinSizeUnit4 * kGByteSize, + kBinSizeUnit16 * kGByteSize, + kBinSizeUnit26 * kGByteSize}; static bool BlockComparator(const Block *left, const Block *right) { if (left->size != right->size) { diff --git a/ge/graph/manager/graph_caching_allocator.h b/ge/graph/manager/graph_caching_allocator.h index dc4af753..e024d5cd 100644 --- a/ge/graph/manager/graph_caching_allocator.h +++ b/ge/graph/manager/graph_caching_allocator.h @@ -34,10 +34,17 @@ namespace ge { constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes +constexpr size_t kBinSizeUnit4 = 4; +constexpr size_t kBinSizeUnit8 = 8; +constexpr size_t kBinSizeUnit16 = 16; +constexpr size_t kBinSizeUnit26 = 26; +constexpr size_t kBinSizeUnit32 = 32; +constexpr size_t kBinSizeUnit128 = 128; + constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold constexpr size_t kKByteSize = 1024; -constexpr size_t kMByteSize = 1024 * 1024; -constexpr size_t kGByteSize = 1024 * 1024 * 1024; +constexpr size_t kMByteSize = 1048576; // 1024 * 1024 +constexpr size_t kGByteSize = 1073741824; // 1024 * 1024 * 1024 static const uint32_t kNumBins = 8; diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc old mode 100755 new mode 100644 index 87070e79..bd476ad5 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -68,6 +68,7 @@ #include "graph/passes/iterator_op_pass.h" #include "graph/passes/link_gen_mask_nodes_pass.h" #include "graph/passes/mark_graph_unknown_status_pass.h" +#include "graph/passes/dynamic_single_op_reset_shape_pass.h" #include "graph/passes/merge_pass.h" #include "graph/passes/merge_input_memcpy_pass.h" #include "graph/passes/merge_to_stream_merge_pass.h" @@ -533,9 +534,8 @@ Status GraphManager::CopySubGraphAndMarkFusion(const ComputeGraphPtr &compute_gr return SUCCESS; } -Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph, - Graph2SubGraphInfoList &sub_graph_map, - uint64_t session_id) { +Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph, + Graph2SubGraphInfoList &sub_graph_map, uint64_t session_id) { GE_CHECK_NOTNULL(compute_graph); // use default 16 multi thread const uint32_t thread_num = 16; @@ -550,14 +550,14 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr (void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); } std::future f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, - compute_graph->GetGraphID(), subgraph, compute_graph, session_id, GetThreadLocalContext()); + compute_graph->GetGraphID(), subgraph, compute_graph, session_id, + GetThreadLocalContext()); if (!f.valid()) { GELOGE(FAILED, "Future is invalid"); return FAILED; } vector_future.emplace_back(std::move(f)); } - for (auto &function_graph : compute_graph->GetAllSubgraphs()) { auto subgraph_list = sub_graph_map[function_graph]; for (const auto &subgraph : subgraph_list) { @@ -650,63 +650,25 @@ Status GraphManager::ReplaceSubgraphWithOriGraph(const ComputeGraphPtr &compute_ Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph, GraphPartitioner &partitioner) { GE_CHECK_NOTNULL(compute_graph); - auto sub_graph_map = partitioner.GetSubGraphMap(); - std::string buffer_optimize; - graphStatus graph_status = ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); - bool need_lx_fusion = (graph_status == GRAPH_SUCCESS) && (buffer_optimize != kOffOptimize); - if (options_.build_mode.empty() && need_lx_fusion) { - GELOGI("Enter normal mode with buffer_optimize:%s.", buffer_optimize.c_str()); - /// 1. Copy subgraph for buffer optimize while lx fusion failed. - /// 2. Set graph with attr "lx_fusion" for fusion optimize. - std::unordered_map copy_graphs; - GE_TIMESTAMP_START(CopySubGraphAndMarkFusion); - Status ret = CopySubGraphAndMarkFusion(compute_graph, sub_graph_map, copy_graphs); - GE_TIMESTAMP_EVENT_END(CopySubGraphAndMarkFusion, "SetSubgraph:CopySubGraphAndMarkFusion"); - if (ret != SUCCESS) { - GELOGE(ret, "CopySubGraphAndMarkFusion failed."); - return ret; - } - - // Multiply optimize subgraph with lx fusion - ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); - if (ret != SUCCESS) { - GELOGE(ret, "Multiply optimize subgraph with lx fusion failed."); - return ret; - } - - // Check whether all subgraph lx fusion success - GE_TIMESTAMP_START(CheckAllFusionOptimizeSuccess); - if (CheckAllFusionOptimizeSuccess(compute_graph, sub_graph_map)) { - GE_TIMESTAMP_EVENT_END(CheckAllFusionOptimizeSuccess, "SetSubgraph:CheckAllFusionOptimizeSuccess"); - return SUCCESS; - } - - // Replace subgraph with original graph for lx buffer - ret = ReplaceSubgraphWithOriGraph(compute_graph, sub_graph_map, copy_graphs); - if (ret != SUCCESS) { - GELOGE(ret, "Replace subgraph with original graph failed."); - return ret; - } + PassManager pass_for_dynamic_shape_reset_optimize; + GE_CHK_STATUS_RET(pass_for_dynamic_shape_reset_optimize.AddPass( + "SetSubgraph::AfterSetSubgraph::DynamicSingleOpResetShapePass", new (std::nothrow) DynamicSingleOpResetShapePass)) + GE_TIMESTAMP_START(pass_for_dynamic_shape_reset_optimize); + Status ret = pass_for_dynamic_shape_reset_optimize.Run(compute_graph); + GE_TIMESTAMP_END(pass_for_dynamic_shape_reset_optimize, "SetSubgraph::AfterSetSubgraph"); + if (ret != SUCCESS && ret != NOT_CHANGED) { + GELOGE(ret, "Run passes when optimize subgraph failed"); + return ret; + } - // Multiply optimize subgraph with lx buffer - ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); - if (ret != SUCCESS) { - GELOGE(ret, "Multiply optimize subgraph with lx buffer failed."); - return ret; - } - } else { - /// Multiply optimize subgraph: - /// 1. run lx buffer while build_mode is normal and buffer_optimize is empty or "off_optimize"; - /// 2. run lx fusion or buffer according build_mode and build_step in fe. - GELOGD("Directly optimize subgraph with build mode:%s, and step:%s, buffer_optimize:%s.", - options_.build_mode.c_str(), - options_.build_step.c_str(), - buffer_optimize.c_str()); - Status ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); - if (ret != SUCCESS) { - GELOGE(ret, "Multiply optimize subgraph with lx buffer"); - return ret; - } + auto sub_graph_map = partitioner.GetSubGraphMap(); + GELOGD("Directly optimize subgraph with build mode:%s, and step:%s.", + options_.build_mode.c_str(), + options_.build_step.c_str()); + ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); + if (ret != SUCCESS) { + GELOGE(ret, "Multiply optimize subgraph failed"); + return ret; } return SUCCESS; } @@ -2515,7 +2477,6 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager GetContext().SetSessionId(session_id); GetThreadLocalContext() = ge_context; graph_manager->UpdateLocalOmgContext(root_graph_id); - ComputeGraphPtr compute_graph_tmp = sub_graph_info_ptr->GetSubGraph(); const std::string &engine_name = sub_graph_info_ptr->GetEngineName(); GELOGD("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu", @@ -2523,6 +2484,10 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager pthread_self()); GE_DUMP(compute_graph_tmp, "OptimizeSubGraphBefore"); GE_CHECK_NOTNULL(compute_graph_tmp); + if (!AttrUtils::SetInt(*compute_graph_tmp, ATTR_NAME_ROOT_GRAPH_ID, root_graph_id)) { + GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_ID for subgraph, graph_id: %u.", root_graph_id); + return FAILED; + } compute_graph_tmp->SetSessionID(session_id); Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp, compute_graph, @@ -2688,9 +2653,7 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { } // it will not execute graph preprocess, optimize, parition, build if the graph has built successful. - GELOGI("Start for run graph async."); - GeRootModelPtr ge_root_model = nullptr; if (graph_manager->IsGraphNeedBuild(graph_node)) { if (graph_node->GetBuildFlag()) { diff --git a/ge/graph/manager/graph_mem_allocator.cc b/ge/graph/manager/graph_mem_allocator.cc old mode 100755 new mode 100644 diff --git a/ge/graph/manager/graph_var_manager.cc b/ge/graph/manager/graph_var_manager.cc old mode 100755 new mode 100644 index be7d4eb2..84a07069 --- a/ge/graph/manager/graph_var_manager.cc +++ b/ge/graph/manager/graph_var_manager.cc @@ -280,9 +280,9 @@ Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uin return PARAM_INVALID; } uint64_t free_size = total_size_ - var_mem_size_; - if (free_size < (size + kSessionMemAlignSize * 2)) { + if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) { GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]", - size + kSessionMemAlignSize * 2 + var_mem_size_, total_size_); + size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_); return PARAM_INVALID; } diff --git a/ge/graph/manager/graph_var_manager.h b/ge/graph/manager/graph_var_manager.h old mode 100755 new mode 100644 index b4f6aca3..fcbc92c5 --- a/ge/graph/manager/graph_var_manager.h +++ b/ge/graph/manager/graph_var_manager.h @@ -42,6 +42,7 @@ const size_t kGraphMemoryBuffer = 4UL * 1024UL * 1024UL * 1024UL; const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL; const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; const uint64_t kSessionMemAlignSize = 512; +const size_t kSessionMemAlignUnit = 2; enum MemStatus { NORMAL = 0, diff --git a/ge/graph/manager/host_mem_manager.cc b/ge/graph/manager/host_mem_manager.cc index d4aceddd..c99c9e87 100644 --- a/ge/graph/manager/host_mem_manager.cc +++ b/ge/graph/manager/host_mem_manager.cc @@ -106,7 +106,7 @@ Status HostMemManager::QueryVarMemInfo(const string &op_name, uint64_t &base_add GELOGE(INTERNAL_ERROR, "Find host base base_addr failed,node name:%s!", op_name.c_str()); return INTERNAL_ERROR; } - base_addr = reinterpret_cast(reinterpret_cast(var_memory_base_map_[op_name].device_address)); + base_addr = static_cast(reinterpret_cast(var_memory_base_map_[op_name].device_address)); data_size = var_memory_base_map_[op_name].mem_size; return SUCCESS; } diff --git a/ge/graph/manager/trans_var_data_utils.h b/ge/graph/manager/trans_var_data_utils.h old mode 100755 new mode 100644 diff --git a/ge/graph/manager/util/debug.cc b/ge/graph/manager/util/debug.cc index 45c070c6..2c930d1f 100644 --- a/ge/graph/manager/util/debug.cc +++ b/ge/graph/manager/util/debug.cc @@ -32,7 +32,8 @@ Debug::~Debug() = default; void Debug::DumpProto(const Message &proto, const char *file) { std::string file_path = RealPath(file); - int fd = mmOpen2(file_path.c_str(), M_WRONLY | M_CREAT | O_TRUNC, M_IRUSR | M_IWUSR | M_UMASK_GRPREAD | M_UMASK_OTHREAD); + int fd = mmOpen2(file_path.c_str(), M_WRONLY | M_CREAT | O_TRUNC, M_IRUSR | M_IWUSR | M_UMASK_GRPREAD | + M_UMASK_OTHREAD); if (fd == -1) { GELOGW("Write %s failed", file_path.c_str()); return; diff --git a/ge/graph/manager/util/debug.h b/ge/graph/manager/util/debug.h old mode 100755 new mode 100644 diff --git a/ge/graph/manager/util/hcom_util.cc b/ge/graph/manager/util/hcom_util.cc index 487b24af..50fa9936 100644 --- a/ge/graph/manager/util/hcom_util.cc +++ b/ge/graph/manager/util/hcom_util.cc @@ -263,7 +263,8 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); - if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) { + if (op_desc->GetType() == HCOMBROADCAST || + op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) { GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); int64_t root_id = 0; Status dmrt = GetHcclRootId(op_desc, root_id); diff --git a/ge/graph/optimize/graph_optimize.h b/ge/graph/optimize/graph_optimize.h old mode 100755 new mode 100644 diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index 2fabc035..dfc6c9df 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -26,6 +26,13 @@ namespace { using namespace ge; const int kIdentityAnchorIndex = 0; +const size_t kSerialStringVecSize = 4; + +const int kCaseReadOnly = 0; +const int kCaseScopeWriteable = 2; +const int kCaseWriteable = 3; +const int kCaseInvalidRWType = 5; + // rw type of input. enum class InputRWType { kReadOnly, // Normal op input only read @@ -55,7 +62,7 @@ thread_local map node_rwtype_map_; /// @return rw_type_name /// static std::string InputRWTypeToSerialString(InputRWType rw_type) { - const static char *names[4] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"}; + const static char *names[kSerialStringVecSize] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"}; return names[static_cast(rw_type)]; } @@ -65,7 +72,7 @@ static std::string InputRWTypeToSerialString(InputRWType rw_type) { /// @return rw_type_name /// static std::string OutputRWTypeToSerialString(OutputRWType rw_type) { - const static char *names[4] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"}; + const static char *names[kSerialStringVecSize] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"}; return names[static_cast(rw_type)]; } @@ -118,13 +125,13 @@ InputRWType GetInputRwTypeInConflict(const std::set &rw_type_set) { } switch (total_rw_type) { - case 0: + case kCaseReadOnly: return InputRWType::kReadOnly; // all input rw type is readonly - case 2: + case kCaseScopeWriteable: return InputRWType::kScopeWriteable; // readonly 2 scope_writeable - case 3: + case kCaseWriteable: return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable - case 5: + case kCaseInvalidRWType: return InputRWType::kInvalidRWType; // writeable 2 scope_writeable default: return InputRWType::kInvalidRWType; @@ -643,7 +650,7 @@ Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); GE_CHK_STATUS_RET(ret, "Fail to insert identity."); GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), - pre_node->GetName().c_str(), node->GetName().c_str()); + pre_node->GetName().c_str(), node->GetName().c_str()); } } } diff --git a/ge/graph/partition/dynamic_shape_partition.cc b/ge/graph/partition/dynamic_shape_partition.cc old mode 100755 new mode 100644 diff --git a/ge/graph/partition/engine_place.cc b/ge/graph/partition/engine_place.cc old mode 100755 new mode 100644 diff --git a/ge/graph/partition/engine_place.h b/ge/graph/partition/engine_place.h old mode 100755 new mode 100644 diff --git a/ge/graph/partition/graph_partition.cc b/ge/graph/partition/graph_partition.cc old mode 100755 new mode 100644 index 6a1fbb34..fbc13920 --- a/ge/graph/partition/graph_partition.cc +++ b/ge/graph/partition/graph_partition.cc @@ -614,32 +614,32 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vectorSetParentNode(compute_graph->GetParentNode()); - (void) AttrUtils::SetStr(*sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName()); - auto sgi = MakeShared(); - if (sgi == nullptr) { - GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: MakeShared sub graph info failed."); - return FAILED; - } - // set engine name - sgi->SetEngineName(engine_name); - // set stream label - string sub_graph_stream; - if (AttrUtils::GetStr(sub_graph->GetDirectNode().at(0)->GetOpDesc(), ATTR_NAME_STREAM_LABEL, sub_graph_stream)) { - sgi->SetStreamLabel(sub_graph_stream); - } - /// for now inputFlag is the same before and after partition. It should - /// be changed according to the real partition - std::vector sub_graph_input(graph_info_.input_size_, true); - std::vector sub_graph_output(graph_info_.output_size_, true); - sgi->SetSubGraph(sub_graph); - sgi->SetOutputFlag(sub_graph_output); - sgi->SetInputFlag(sub_graph_input); - sgi->SetOutputContext(graph_info_.output_name_); - AddEndPldInformationToSubGraphInfo(sgi); - GELOGI("[GraphPartitioner]: subGraph engine name is %s, graph name is %s, stream label is %s", - engine_name.c_str(), - sub_graph->GetName().c_str(), - sgi->GetStreamLabel().empty() ? "null" : sgi->GetStreamLabel().c_str()); + (void)AttrUtils::SetStr(*sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName()); + GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(), + compute_graph->GetName().c_str()); + auto sgi = MakeShared(); + if (sgi == nullptr) { + GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: MakeShared sub graph info failed."); + return FAILED; + } + // set engine name + sgi->SetEngineName(engine_name); + // set stream label + string sub_graph_stream; + if (AttrUtils::GetStr(sub_graph->GetDirectNode().at(0)->GetOpDesc(), ATTR_NAME_STREAM_LABEL, sub_graph_stream)) { + sgi->SetStreamLabel(sub_graph_stream); + } + /// for now inputFlag is the same before and after partition. It should + /// be changed according to the real partition + std::vector sub_graph_input(graph_info_.input_size_, true); + std::vector sub_graph_output(graph_info_.output_size_, true); + sgi->SetSubGraph(sub_graph); + sgi->SetOutputFlag(sub_graph_output); + sgi->SetInputFlag(sub_graph_input); + sgi->SetOutputContext(graph_info_.output_name_); + AddEndPldInformationToSubGraphInfo(sgi); + GELOGI("[GraphPartitioner]: subGraph engine name is %s, graph name is %s, stream label is %s", engine_name.c_str(), + sub_graph->GetName().c_str(), sgi->GetStreamLabel().empty() ? "null" : sgi->GetStreamLabel().c_str()); if (engine_name != input_subgraph_name) { // do not add Data subGraph into SubGraphInfo output_subgraphs.push_back(sgi); } else { diff --git a/ge/graph/passes/aicpu_constant_folding_pass.h b/ge/graph/passes/aicpu_constant_folding_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/assert_pass.h b/ge/graph/passes/assert_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/atomic_addr_clean_pass.cc b/ge/graph/passes/atomic_addr_clean_pass.cc old mode 100755 new mode 100644 index 60742eb1..7c6ed8ce --- a/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/ge/graph/passes/atomic_addr_clean_pass.cc @@ -74,10 +74,88 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { return SUCCESS; } +// just hccl may mark atomic from ops kernel now, and hccl's atomic if for all input +bool AtomicAddrCleanPass::CheckAtomicFromOpsKernel(const NodePtr &node) { + // 1.Check if isAtomic attrs exist for HCOM + std::shared_ptr instance_ptr = GELib::GetInstance(); + if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { + GELOGW("GELib not initialized, atomic from ops kernel judge false, node_name: %s", node->GetName().c_str()); + return false; + } + + OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); + vector op_info_vec = ops_kernel_manager.GetOpsKernelInfo(node->GetType()); + for (const auto &op_info : op_info_vec) { + if (op_info.isAtomic) { + // check peer input is DATA + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + if (in_data_anchor->GetPeerOutAnchor() != nullptr && + in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) { + auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); + if (peer_in_node->GetType() == DATA) { + GELOGI("Recognized atomic op %s from %s engine and input is DATA.", node->GetName().c_str(), + op_info.engine.c_str()); + return false; + } + } + } + GELOGI("Recognized atomic op %s from %s engine.", node->GetName().c_str(), op_info.engine.c_str()); + hcom_node_vec_.push_back(node); + return true; + } + } + return false; +} + +bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int64_t output_index) { + auto out_data_anchor = node->GetAllOutDataAnchors().at(output_index); + if (out_data_anchor == nullptr) { + return false; + } + + for (auto input_anchor : out_data_anchor->GetPeerInDataAnchors()) { + auto output_node = input_anchor->GetOwnerNode(); + // just hccl may mark atomic from ops kernel now, and hccl's atomic if for all input + // hccl's attr ATOMIC_ATTR_INPUT_INDEX mark on CalcOpRunningParam, can't be get here + if (CheckAtomicFromOpsKernel(output_node)) { + return true; + } + } + return false; +} + +bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) { + OpDescPtr op_desc = node->GetOpDesc(); + std::map> node_workspace_offset; + bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); + bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); + node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); + if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) { + std::vector atomic_output_index; + (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); + bool is_all_output_peer_also_atomic = true; + for (const auto &output_index : atomic_output_index) { + if (!IsOutputIndexPeerInputAtomic(node, output_index)) { + is_all_output_peer_also_atomic = false; + break; + } + } + if (is_all_output_peer_also_atomic) { + GELOGI("all out peer node input atomic, skip this out atomic process, node name: %s", node->GetName().c_str()); + return true; + } + } + return false; +} + Status AtomicAddrCleanPass::HandleLoopGraph(ComputeGraphPtr &graph, const vector &atomic_node_vec) { // Loop graph , insert clean node follow atomic node int index = 0; for (const auto &node : atomic_node_vec) { + if (CheckSkipInsertInLoopGraph(node)) { + continue; + } + // Insert atomic clean op NodePtr clean_addr_node = InsertAtomicAddrCleanNode(graph); if (clean_addr_node == nullptr) { @@ -249,32 +327,10 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { return false; } // 1.Check if isAtomic attrs exist for HCOM - std::shared_ptr instance_ptr = GELib::GetInstance(); - if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { - GELOGW("GELib not initialized"); - return false; + if (CheckAtomicFromOpsKernel(node)) { + return true; } - OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); - vector op_info_vec = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); - for (const auto &op_info : op_info_vec) { - if (op_info.isAtomic) { - GELOGI("Recognized atomic op %s from DNN_HCCL engine.", op_desc->GetName().c_str()); - // check peer input is DATA - for (auto &in_data_anchor : node->GetAllInDataAnchors()) { - if (in_data_anchor->GetPeerOutAnchor() != nullptr && - in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) { - auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); - if (peer_in_node->GetType() == DATA) { - GELOGI("Recognized atomic op %s from DNN_HCCL engine and input is DATA.", op_desc->GetName().c_str()); - return false; - } - } - } - hcom_node_vec_.push_back(node); - return true; - } - } // 2.Check atomic attr in node std::map> node_workspace_offset; bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); diff --git a/ge/graph/passes/atomic_addr_clean_pass.h b/ge/graph/passes/atomic_addr_clean_pass.h old mode 100755 new mode 100644 index ad60b7b5..8138d511 --- a/ge/graph/passes/atomic_addr_clean_pass.h +++ b/ge/graph/passes/atomic_addr_clean_pass.h @@ -84,6 +84,11 @@ class AtomicAddrCleanPass : public GraphPass { Status HandleDispersedAtomicNodes(ComputeGraphPtr &graph, const std::vector &atomic_node_vec, std::vector &common_atomic_nodes); + bool CheckAtomicFromOpsKernel(const NodePtr &node); + + bool IsOutputIndexPeerInputAtomic(const NodePtr &node, int64_t output_index); + + bool CheckSkipInsertInLoopGraph(const NodePtr &node); vector hcom_node_vec_; bool is_loop_graph_ = false; diff --git a/ge/graph/passes/attach_stream_label_pass.cc b/ge/graph/passes/attach_stream_label_pass.cc index b04643a4..c0e0f669 100644 --- a/ge/graph/passes/attach_stream_label_pass.cc +++ b/ge/graph/passes/attach_stream_label_pass.cc @@ -24,11 +24,7 @@ Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { FindNodes(graph); for (const auto &node : need_label_nodes_) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { - GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); - } + GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); } GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); @@ -55,13 +51,15 @@ Status AttachStreamLabelPass::ClearStatus() { /// void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { for (const NodePtr &node : graph->GetDirectNode()) { - const std::string &type = node->GetType(); - if (type == STREAMSWITCH) { + const auto &op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + const std::string &type = op_desc->GetType(); + if ((type == STREAMSWITCH) && op_desc->HasAttr(ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG)) { stream_switch_nodes_.emplace_back(node); - } else if (type == STREAMMERGE) { - if ((node->GetOpDesc() != nullptr) && !node->GetOpDesc()->HasAttr(ATTR_NAME_NEXT_ITERATION)) { - need_label_nodes_.emplace_back(node); - } + } else if ((type == STREAMMERGE) && !op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { + need_label_nodes_.emplace_back(node); } else if ((type == ENTER) || (type == REFENTER)) { enter_nodes_.emplace_back(node); } @@ -83,11 +81,15 @@ void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { /// Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { std::string stream_label; + if (AttachFlag(node, stream_label) != SUCCESS) { + GELOGE(FAILED, "Attach flag for node %s failed.", node->GetName().c_str()); + return FAILED; + } + std::unordered_set branch_nodes; std::unordered_set visited; std::stack nodes; nodes.push(node); - static const std::set end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; while (!nodes.empty()) { NodePtr cur_node = nodes.top(); @@ -95,10 +97,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { if (visited.count(cur_node) > 0) { continue; } - if (AttachFlag(cur_node, stream_label) != SUCCESS) { - GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str()); - return FAILED; - } const std::string &type = cur_node->GetType(); for (const auto &out_node : cur_node->GetOutAllNodes()) { @@ -115,10 +113,6 @@ Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { visited.insert(cur_node); } - if (node->GetType() == STREAMSWITCH) { - GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); - } - for (const NodePtr &tmp_node : branch_nodes) { GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); @@ -148,11 +142,10 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); stream_label += (value ? "_t" : "_f"); + GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); } else if (type == STREAMMERGE) { stream_label = node->GetName(); GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); - } else if ((type == EXIT) || (type == REFEXIT)) { - GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); } return SUCCESS; @@ -166,12 +159,13 @@ Status AttachStreamLabelPass::UpdateEnterNode() { std::unordered_map> enter_active_map; for (const auto &enter_node : enter_nodes_) { for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { - if (out_ctrl_node->GetType() == STREAMACTIVE) { - if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { - enter_active_map[out_ctrl_node] = {enter_node}; - } else { - enter_active_map[out_ctrl_node].emplace_back(enter_node); - } + if (out_ctrl_node->GetType() != STREAMACTIVE) { + continue; + } + if (enter_active_map.find(out_ctrl_node) == enter_active_map.end()) { + enter_active_map[out_ctrl_node] = {enter_node}; + } else { + enter_active_map[out_ctrl_node].emplace_back(enter_node); } } } @@ -226,9 +220,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_no std::string stream_label; GE_CHECK_NOTNULL(active_node); (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); - if (stream_label.empty()) { - GELOGW("stream_label of enter_active & enter_nodes is empty."); + GELOGD("stream_label of enter_active %s is empty.", active_node->GetName().c_str()); return SUCCESS; } @@ -238,7 +231,6 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_no GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); } } - GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed."); return SUCCESS; } diff --git a/ge/graph/passes/attach_stream_label_pass.h b/ge/graph/passes/attach_stream_label_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/cast_translate_pass.h b/ge/graph/passes/cast_translate_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/compile_nodes_pass.cc b/ge/graph/passes/compile_nodes_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/cond_remove_pass.cc b/ge/graph/passes/cond_remove_pass.cc index e8d1493f..bf2e1170 100644 --- a/ge/graph/passes/cond_remove_pass.cc +++ b/ge/graph/passes/cond_remove_pass.cc @@ -37,6 +37,12 @@ Status CondRemovePass::Run(NodePtr &node) { OutDataAnchorPtr cond_out_anchor = nullptr; InDataAnchorPtr cond_in_anchor = nullptr; Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); + if (ret == NOT_CHANGED) { + return SUCCESS; + } else if (ret != SUCCESS) { + GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str()); + return FAILED; + } int32_t cond_index = 0; GELOGD("Handle cond remove for node %s.", node->GetOpDesc()->GetName().c_str()); bool if_cond_const = CheckIfCondConstInput(cond_out_anchor, cond_in_anchor, cond_index); @@ -322,11 +328,11 @@ Status CondRemovePass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, std::string type = node->GetType(); if ((kIfOpTypes.count(type) != 0) || (kCaseOpTypes.count(type) != 0)) { if (GetCondInfoForIfCase(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { - GELOGE(FAILED, "Get cond_info for if node failed."); + GELOGE(FAILED, "Get cond_info for if/case node failed."); return FAILED; } } else { - GELOGD("no need cond_pass for node %s.", node->GetName().c_str()); + GELOGD("no need cond_remove_pass for node %s.", node->GetName().c_str()); return NOT_CHANGED; } diff --git a/ge/graph/passes/constant_fuse_same_pass.h b/ge/graph/passes/constant_fuse_same_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/control_trigger_pass.h b/ge/graph/passes/control_trigger_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/ctrl_edge_transfer_pass.cc b/ge/graph/passes/ctrl_edge_transfer_pass.cc old mode 100755 new mode 100644 index f53dc7be..a538a10c --- a/ge/graph/passes/ctrl_edge_transfer_pass.cc +++ b/ge/graph/passes/ctrl_edge_transfer_pass.cc @@ -38,7 +38,6 @@ namespace ge { * \ / * B */ - Status CtrlEdgeTransferPass::Run(ge::ComputeGraphPtr graph) { GELOGD("CtrlEdgeTransferPass start running"); GE_CHECK_NOTNULL(graph); diff --git a/ge/graph/passes/ctrl_edge_transfer_pass.h b/ge/graph/passes/ctrl_edge_transfer_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/data_pass.cc b/ge/graph/passes/data_pass.cc index 4ec8743e..5bbd2fb1 100644 --- a/ge/graph/passes/data_pass.cc +++ b/ge/graph/passes/data_pass.cc @@ -21,6 +21,7 @@ namespace ge { namespace { +const int kDataIndexOffset = 2; Status MappingSubgraphInput(const ComputeGraphPtr &graph, const std::function &input) { for (const auto &node : graph->GetDirectNode()) { if (node->GetType() != DATA) { @@ -111,7 +112,7 @@ Status ParseSubgraphPostFnWhile(const string &subgraph_name, const ComputeGraphP Status ParseSubgraphPostFnFor(const string &subgraph_name, const ComputeGraphPtr &graph) { return MappingSubgraphIndex(graph, - [](int data_index) { return (data_index == 0) ? 0 : data_index + 2; }, + [](int data_index) { return (data_index == 0) ? 0 : data_index + kDataIndexOffset; }, [](int retval_index) { return retval_index; }); } diff --git a/ge/graph/passes/dimension_adjust_pass.cc b/ge/graph/passes/dimension_adjust_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/dimension_adjust_pass.h b/ge/graph/passes/dimension_adjust_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/dimension_compute_pass.cc b/ge/graph/passes/dimension_compute_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/dropout_pass.h b/ge/graph/passes/dropout_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc b/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc new file mode 100644 index 00000000..3e6377c7 --- /dev/null +++ b/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc @@ -0,0 +1,142 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/dynamic_single_op_reset_shape_pass.h" +#include "common/ge_inner_error_codes.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/debug/ge_attr_define.h" + +namespace ge { +namespace { +const int64_t kDynamicShapeDim = -2; +const char *const kEngineNameAiCpu = "DNN_VM_AICPU_ASCEND"; +const char *const kEngineNameAiCpuTf = "DNN_VM_AICPU"; +} // namespace +Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + + std::shared_ptr instance = ge::GELib::GetInstance(); + if (instance == nullptr || !instance->InitFlag()) { + GELOGE(ge::GE_CLI_GE_NOT_INITIALIZED, "Run CompileNodesPass failed."); + return ge::GE_CLI_GE_NOT_INITIALIZED; + } + + // pass if graph has not aicpu node. + bool is_not_aicpu = false; + if (CheckAllAicpuNodes(graph, is_not_aicpu) != SUCCESS) { + GELOGE(ge::GE_CLI_GE_NOT_INITIALIZED, "Check if graph has not aicpu node failed."); + return ge::GE_CLI_GE_NOT_INITIALIZED; + } + if (is_not_aicpu) { + GELOGI("The graph [%s] has not aicpu node, whose aicpu nodes would not be reset dynamic shape", + graph->GetName().c_str()); + return SUCCESS; + } + + for (const auto &node : graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + // pass input and output node + if (node->GetType() == DATA || node->GetType() == CONSTANT || node->GetType() == CONSTANTOP || + node->GetType() == NETOUTPUT) { + continue; + } + + // pass node without attr: ATTR_DYNAMIC_SHAPE_SINGLE_AICPU + bool single_aicpu_unknown = false; + if (!AttrUtils::GetBool(node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, single_aicpu_unknown) || + !single_aicpu_unknown) { + continue; + } + + // reset aicpu shape to unknown shape + auto op_desc = node->GetOpDesc(); + if (ResetOpShape(op_desc) != SUCCESS) { + GELOGE(ge::GE_CLI_GE_NOT_INITIALIZED, "Reset node[%s] dynamic shapr failed.", node->GetName().c_str()); + return ge::GE_CLI_GE_NOT_INITIALIZED; + } + GELOGD("Reset dynamic aicpu node [%s] shape success!", node->GetName().c_str()); + } + + GELOGD("Reset dynamic aicpu nodes shape of graph [%s] success!", graph->GetName().c_str()); + return SUCCESS; +} + +Status DynamicSingleOpResetShapePass::CheckAllAicpuNodes(const ComputeGraphPtr &graph, bool &is_not_aicpu) { + is_not_aicpu = false; + for (const auto &node : graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + // pass input and output node + if (node->GetType() == DATA || node->GetType() == CONSTANT || node->GetType() == CONSTANTOP || + node->GetType() == NETOUTPUT) { + continue; + } + + // find if there are aicpu nodes. + auto op_desc = node->GetOpDesc(); + string engine_name = op_desc->GetOpEngineName(); + if (engine_name.empty()) { + GELOGE(GRAPH_FAILED, "Get engine failed of node[%s].", node->GetName().c_str()); + return GRAPH_FAILED; + } + if (engine_name != kEngineNameAiCpu && engine_name != kEngineNameAiCpuTf) { + is_not_aicpu = true; + return SUCCESS; + } + } + return SUCCESS; +} + +bool DynamicSingleOpResetShapePass::CheckIfConstInput(const GeTensorDescPtr &input_tensor_desc) { + bool is_const = false; + (void)AttrUtils::GetBool(input_tensor_desc, CONST_ATTR_NAME_INPUT, is_const); + return is_const; +} + +Status DynamicSingleOpResetShapePass::ResetOpShape(OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(op_desc); + std::vector dynamic_shape_dims = {kDynamicShapeDim}; + GeShape dynamic_shape(dynamic_shape_dims); + for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) { + auto input_desc = op_desc->MutableInputDesc(static_cast(i)); + GE_CHECK_NOTNULL(input_desc); + // pass scalar input desc + auto dims_ori = input_desc->GetShape().GetDims(); + if (dims_ori.size() == 0) { + continue; + } + // pass const input + if (CheckIfConstInput(input_desc)) { + continue; + } + input_desc->SetShape(dynamic_shape); + } + for (size_t i = 0; i < op_desc->GetAllOutputsDesc().size(); i++) { + auto output_desc = op_desc->MutableOutputDesc(static_cast(i)); + GE_CHECK_NOTNULL(output_desc); + // pass scalar input desc + auto output_dims_ori = output_desc->GetShape().GetDims(); + if (output_dims_ori.size() == 0) { + continue; + } + output_desc->SetShape(dynamic_shape); + } + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/ge/graph/passes/dynamic_single_op_reset_shape_pass.h b/ge/graph/passes/dynamic_single_op_reset_shape_pass.h new file mode 100644 index 00000000..659bed9c --- /dev/null +++ b/ge/graph/passes/dynamic_single_op_reset_shape_pass.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_DYNAMIC_SINGLE_OP_RESET_SHAPE_PASS_H_ +#define GE_GRAPH_PASSES_DYNAMIC_SINGLE_OP_RESET_SHAPE_PASS_H_ +#include "graph/graph.h" +#include "inc/graph_pass.h" +#include "init/gelib.h" + +namespace ge { +class DynamicSingleOpResetShapePass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph) override; + + private: + Status ResetOpShape(OpDescPtr &op_desc); + Status CheckAllAicpuNodes(const ComputeGraphPtr &graph, bool &is_not_aicpu); + bool CheckIfConstInput(const GeTensorDescPtr &input_tensor_desc); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_DYNAMIC_SINGLE_OP_RESET_SHAPE_PASS_H_ diff --git a/ge/graph/passes/end_of_sequence_add_control_pass.cc b/ge/graph/passes/end_of_sequence_add_control_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/enter_pass.cc b/ge/graph/passes/enter_pass.cc index 206d271c..afeca78f 100644 --- a/ge/graph/passes/enter_pass.cc +++ b/ge/graph/passes/enter_pass.cc @@ -16,6 +16,7 @@ #include "graph/passes/enter_pass.h" +#include "graph/debug/ge_attr_define.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "graph/utils/graph_utils.h" @@ -72,33 +73,25 @@ Status EnterPass::Run(NodePtr &node) { } Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) { - auto out_nodes_of_in_node = in_node->GetOutAllNodes(); - if (out_nodes_of_in_node.size() != kOutNodesNum) { + if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) { return SUCCESS; } - - if (!node->GetOutControlNodes().empty()) { + bool is_constant_flag = true; + (void)AttrUtils::GetBool(node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant_flag); + if (!is_constant_flag) { return SUCCESS; } - for (const auto &out_node : node->GetOutDataNodes()) { - GE_CHECK_NOTNULL(out_node); - if (out_node->GetType() == MERGE) { - return SUCCESS; - } - } - GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0)); GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0))); - auto out_data_anchor = node->GetOutDataAnchor(0); + const auto &out_data_anchor = node->GetOutDataAnchor(0); GE_CHECK_NOTNULL(out_data_anchor); - for (auto peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor)); GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor)); } - - auto graph = node->GetOwnerComputeGraph(); - GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, node)) + GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node)); + AddNodeDeleted(node); AddRePassNodesWithInOut(in_node); return SUCCESS; diff --git a/ge/graph/passes/flow_ctrl_pass.cc b/ge/graph/passes/flow_ctrl_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/flow_ctrl_pass.h b/ge/graph/passes/flow_ctrl_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/folding_pass.cc b/ge/graph/passes/folding_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/folding_pass.h b/ge/graph/passes/folding_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/for_pass.cc b/ge/graph/passes/for_pass.cc index f3caea35..31dee390 100644 --- a/ge/graph/passes/for_pass.cc +++ b/ge/graph/passes/for_pass.cc @@ -37,6 +37,7 @@ namespace { const uint32_t kSubgraphLoopVarInputIndex = 0; const uint32_t kSubgraphInputIndex = 1; const uint32_t kWhileOutputIndex = 5; + const size_t kIDiffValue = 2; const std::string kAbs = "Abs"; } @@ -137,7 +138,7 @@ Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &n for_info.ctrl_inputs = std::move(ctrl_inputs); for_info.ctrl_outputs = std::move(ctrl_outputs); - GELOGI("Build for_info for node %s succ.", node->GetName().c_str()); + GELOGI("Build for_info for node %s success.", node->GetName().c_str()); return SUCCESS; } @@ -159,13 +160,7 @@ OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index return nullptr; } - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index); - return nullptr; - } - - return peer_out_anchor; + return in_data_anchor->GetPeerOutAnchor(); } /// @@ -186,20 +181,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vectorGetAllInDataAnchorsSize(); for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) { InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); - if (in_data_anchor == nullptr) { - GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index); - return FAILED; - } - GE_IF_BOOL_EXEC(in_data_anchor->GetPeerOutAnchor() == nullptr, - GELOGW("Get null input by index %d from node %s ", - in_data_anchor->GetIdx(), node->GetName().c_str()); - continue); + GE_CHECK_NOTNULL(in_data_anchor); data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor()); } - for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { std::vector peer_in_data_anchors; - for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { peer_in_data_anchors.emplace_back(peer_in_data_anchor); } data_outputs.emplace_back(peer_in_data_anchors); @@ -207,13 +195,13 @@ Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vectorGetInControlAnchor(); GE_CHECK_NOTNULL(in_ctrl_anchor); - for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { + for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { ctrl_inputs.emplace_back(peer_out_ctrl_anchor); } OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor(); GE_CHECK_NOTNULL(out_ctrl_anchor); - for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { + for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { ctrl_outputs.emplace_back(peer_in_ctrl_anchor); } @@ -707,7 +695,7 @@ Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) { } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) { continue; } else { - input_mapping[i] = i - 2; + input_mapping[i] = i - kIDiffValue; } } for_body->UpdateInputMapping(input_mapping); diff --git a/ge/graph/passes/get_original_format_pass.h b/ge/graph/passes/get_original_format_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/global_step_insert_pass.cc b/ge/graph/passes/global_step_insert_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/global_step_insert_pass.h b/ge/graph/passes/global_step_insert_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/guarantee_const_pass.h b/ge/graph/passes/guarantee_const_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/hccl_memcpy_pass.cc b/ge/graph/passes/hccl_memcpy_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/hccl_memcpy_pass.h b/ge/graph/passes/hccl_memcpy_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/identity_pass.cc b/ge/graph/passes/identity_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/input_output_connection_identify_pass.h b/ge/graph/passes/input_output_connection_identify_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/isolated_op_remove_pass.h b/ge/graph/passes/isolated_op_remove_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/link_gen_mask_nodes_pass.cc b/ge/graph/passes/link_gen_mask_nodes_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/mark_agnostic_pass.cc b/ge/graph/passes/mark_agnostic_pass.cc index 8c9a0451..30fa1742 100644 --- a/ge/graph/passes/mark_agnostic_pass.cc +++ b/ge/graph/passes/mark_agnostic_pass.cc @@ -19,6 +19,8 @@ #include "graph/utils/tensor_utils.h" namespace ge { +const size_t kTwoInputNodesSize = 2; + Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { for (const auto &node : graph->GetDirectNode()) { auto node_type = NodeUtils::GetNodeType(*node); @@ -52,7 +54,7 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { /// Enter-----------+ /// +-> Merge /// NextIteration---+ - if (input_nodes.size() == 2) { + if (input_nodes.size() == kTwoInputNodesSize) { if (input_nodes.at(0)->GetType() == ENTER && input_nodes.at(1)->GetType() == NEXTITERATION) { continue; } diff --git a/ge/graph/passes/memcpy_addr_async_pass.cc b/ge/graph/passes/memcpy_addr_async_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/memcpy_addr_async_pass.h b/ge/graph/passes/memcpy_addr_async_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/merge_pass.cc b/ge/graph/passes/merge_pass.cc index d2340037..26d82820 100644 --- a/ge/graph/passes/merge_pass.cc +++ b/ge/graph/passes/merge_pass.cc @@ -21,18 +21,16 @@ #include #include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" #include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/passes/pass_utils.h" -using domi::PARAM_INVALID; -using domi::SUCCESS; - namespace ge { const int kValueIndexOutputIndex = 1; +const size_t kCaseNoInput = 0; +const size_t kCaseOneInput = 1; Status MergePass::Run(NodePtr &node) { GELOGD("MergePass running"); @@ -47,15 +45,14 @@ Status MergePass::Run(NodePtr &node) { return SUCCESS; } - auto out_data_anchors = node->GetAllOutDataAnchors(); - if (out_data_anchors.empty()) { + if (node->GetAllOutDataAnchors().empty()) { GELOGE(PARAM_INVALID, "[%s] Merge node output anchor is empty", node->GetName().c_str()); return PARAM_INVALID; } - auto in_data_nodes = node->GetInDataNodes(); + const auto &in_data_nodes = node->GetInDataNodes(); switch (in_data_nodes.size()) { - case 0: { + case kCaseNoInput: { /// Case A: input_count = 0, the output of merge node is inactive as well /// In which case the output branch can be removed /// until another merge node is met @@ -70,7 +67,7 @@ Status MergePass::Run(NodePtr &node) { } return ret; } - case 1: { // Case B: input_count = 1, the merge node can be optimized out + case kCaseOneInput: { // Case B: input_count = 1, the merge node can be optimized out std::vector merge_io_map = {PassUtils::GetUniqueInDataAnchorIndex(node), -1}; if (merge_io_map[0] != -1 && IsNeedChangeIndexToConstant(node)) { int index = merge_io_map[0]; diff --git a/ge/graph/passes/merge_pass.h b/ge/graph/passes/merge_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/multi_batch_pass.cc b/ge/graph/passes/multi_batch_pass.cc index c7034612..74f7e30e 100644 --- a/ge/graph/passes/multi_batch_pass.cc +++ b/ge/graph/passes/multi_batch_pass.cc @@ -22,9 +22,6 @@ #include "graph/common/omg_util.h" #include "graph/utils/type_utils.h" -using std::string; -using std::vector; - namespace ge { Status MultiBatchPass::Run(ComputeGraphPtr graph) { GELOGD("MultiBatchPass Enter"); @@ -53,7 +50,7 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { return FAILED; } std::vector> batch_shape; - vector> combined_batch; + std::vector> combined_batch; if (!CheckSwitchN(batch_shape, combined_batch)) { GELOGE(FAILED, "CheckSwitchN failed."); return FAILED; @@ -104,6 +101,7 @@ Status MultiBatchPass::ClearStatus() { /// Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) { const auto &func_desc = case_node->GetOpDesc(); + GE_CHECK_NOTNULL(func_desc); if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) { GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str()); return SUCCESS; @@ -114,7 +112,7 @@ Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]); GE_CHECK_NOTNULL(subgraph); - const string batch_label = "Batch_" + std::to_string(i); + const std::string batch_label = "Batch_" + std::to_string(i); for (const auto &node : subgraph->GetDirectNode()) { (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); } @@ -139,12 +137,12 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor continue; } - InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); + const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); if (in_data_anchor == nullptr) { GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); return FAILED; } - OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor(); + const auto &pred_input = in_data_anchor->GetPeerOutAnchor(); if (pred_input == nullptr) { GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); return FAILED; @@ -178,12 +176,10 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor /// @return Status /// Status MultiBatchPass::GetDynamicType() { - for (const auto &switchn : switch_n_nodes_) { - auto switchn_desc = switchn->GetOpDesc(); - GE_CHECK_NOTNULL(switchn_desc); + for (const auto &switch_n : switch_n_nodes_) { int32_t dynamic_type = static_cast(FIXED); - if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) { - GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str()); + if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) { + GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str()); return FAILED; } if (dynamic_type == static_cast(FIXED)) { @@ -191,7 +187,7 @@ Status MultiBatchPass::GetDynamicType() { return FAILED; } if (dynamic_type_ != static_cast(FIXED) && dynamic_type_ != dynamic_type) { - GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.", + GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.", dynamic_type, dynamic_type_); return FAILED; } @@ -212,21 +208,19 @@ Status MultiBatchPass::GetDynamicType() { Status MultiBatchPass::GetUserDesignateShape() { data_name_order_.clear(); bool first_check = true; - for (const auto &switchn : switch_n_nodes_) { - auto switchn_desc = switchn->GetOpDesc(); - GE_CHECK_NOTNULL(switchn_desc); - vector cur_switchn_data_name_order; - if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) { - GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str()); + for (const auto &switch_n : switch_n_nodes_) { + std::vector cur_data_name_order; + if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) { + GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str()); return FAILED; } if (first_check) { - data_name_order_ = cur_switchn_data_name_order; + data_name_order_ = cur_data_name_order; first_check = false; } else { - if (data_name_order_ != cur_switchn_data_name_order) { + if (data_name_order_ != cur_data_name_order) { GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", - switchn->GetName().c_str()); + switch_n->GetName().c_str()); return FAILED; } } @@ -245,7 +239,8 @@ Status MultiBatchPass::GetUserDesignateShape() { /// @param [out] combined_batch /// @return bool /// -bool MultiBatchPass::CheckSwitchN(vector> &batch_shape, vector> &combined_batch) { +bool MultiBatchPass::CheckSwitchN(std::vector> &batch_shape, + std::vector> &combined_batch) { // Check if output_num of different SwitchN is same uint32_t batch_num = 0; for (const NodePtr &node : switch_n_nodes_) { @@ -281,7 +276,8 @@ bool MultiBatchPass::CheckSwitchN(vector> &batch_shape, vector> &batch_shape, vector> &batch_shape, - vector> &combined_batch) { +bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector> &batch_shape, + std::vector> &combined_batch) { // Check if output_shape of different SwitchN is same - vector> idx_batch_shape; - vector> idx_combined_batch; + std::vector> idx_batch_shape; + std::vector> idx_combined_batch; for (uint32_t i = 0; i < batch_num; i++) { idx_batch_shape.clear(); idx_combined_batch.clear(); @@ -310,7 +306,7 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector> &b GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); return false; } - vector output_dims; + std::vector output_dims; if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); return false; @@ -385,8 +381,8 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { /// @return Status /// Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, - const vector> &batch_shape, - const vector> &combined_batch) { + const std::vector> &batch_shape, + const std::vector> &combined_batch) { NodePtr pred_value_node = pred_value->GetOwnerNode(); // Create SwitchCase node const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; @@ -429,31 +425,11 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s return false; } - size_t num = output_shape.size(); - size_t dim_num = output_shape[0].size(); - for (size_t i = 1; i < num; i++) { - size_t tmp_dim_num = output_shape[i].size(); - if (dim_num != tmp_dim_num) { - GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num); + for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) { + if (output_shape[0] != *iter) { return false; } } - - if (dim_num == 0) { - return true; - } - - for (size_t i = 0; i < dim_num; i++) { - int64_t dim_value = output_shape[0][i]; - for (size_t j = 1; j < num; j++) { - int64_t tmp_dim_value = output_shape[j][i]; - if (dim_value != tmp_dim_value) { - GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i, - dim_value, j, tmp_dim_value); - return false; - } - } - } return true; } @@ -468,8 +444,8 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s /// NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, const OutDataAnchorPtr &pred_value, - const vector> &batch_shape, - const vector> &combined_batch) { + const std::vector> &batch_shape, + const std::vector> &combined_batch) { OpDescPtr op_desc = MakeShared(name, STREAMSWITCHN); if (op_desc == nullptr) { GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); @@ -512,7 +488,7 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } - const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); + const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); return nullptr; diff --git a/ge/graph/passes/next_iteration_pass.h b/ge/graph/passes/next_iteration_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/no_use_reshape_remove_pass.h b/ge/graph/passes/no_use_reshape_remove_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/parallel_concat_start_op_pass.cc b/ge/graph/passes/parallel_concat_start_op_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/parallel_concat_start_op_pass.h b/ge/graph/passes/parallel_concat_start_op_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/pass_utils.cc b/ge/graph/passes/pass_utils.cc index 5359ff63..3adfbde3 100644 --- a/ge/graph/passes/pass_utils.cc +++ b/ge/graph/passes/pass_utils.cc @@ -37,10 +37,6 @@ #include "graph/utils/type_utils.h" namespace ge { -namespace { -const uint32_t kShapeDimSize = 1; -const uint32_t DIM_SIZE_TWO = 2; -} // namespace Status PassUtils::ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector &data, std::vector &v_output, const bool scalar_output) { diff --git a/ge/graph/passes/pass_utils.h b/ge/graph/passes/pass_utils.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/permute_pass.h b/ge/graph/passes/permute_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/prevent_gradient_pass.h b/ge/graph/passes/prevent_gradient_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/print_op_pass.cc b/ge/graph/passes/print_op_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/print_op_pass.h b/ge/graph/passes/print_op_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/prune_pass.h b/ge/graph/passes/prune_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/reshape_remove_pass.cc b/ge/graph/passes/reshape_remove_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/resource_pair_add_control_pass.cc b/ge/graph/passes/resource_pair_add_control_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/resource_pair_remove_control_pass.cc b/ge/graph/passes/resource_pair_remove_control_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/same_transdata_breadth_fusion_pass.h b/ge/graph/passes/same_transdata_breadth_fusion_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/save_pass.cc b/ge/graph/passes/save_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/save_pass.h b/ge/graph/passes/save_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/shape_operate_op_remove_pass.cc b/ge/graph/passes/shape_operate_op_remove_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/shape_operate_op_remove_pass.h b/ge/graph/passes/shape_operate_op_remove_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/stop_gradient_pass.h b/ge/graph/passes/stop_gradient_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/subexpression_migration_pass.cc b/ge/graph/passes/subexpression_migration_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/subexpression_migration_pass.h b/ge/graph/passes/subexpression_migration_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/subgraph_const_migration_pass.h b/ge/graph/passes/subgraph_const_migration_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/subgraph_pass.cc b/ge/graph/passes/subgraph_pass.cc old mode 100755 new mode 100644 index 88e661a7..d1111d52 --- a/ge/graph/passes/subgraph_pass.cc +++ b/ge/graph/passes/subgraph_pass.cc @@ -149,10 +149,10 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node // 5. While->NetOutput in known subgraph std::string op_type; bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || - IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || - ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) || - (!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) && - (kWhileOpTypes.count(in_node->GetType()) != 0)); + IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || + ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) || + (!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) && + (kWhileOpTypes.count(in_node->GetType()) != 0)); if (insert_flag) { GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index 529480a6..f75a104f 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -72,25 +72,26 @@ Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &gra std::unordered_map> cond_switch_map; for (const NodePtr &node : graph->GetDirectNode()) { GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); - if ((type == SWITCH) || (type == REFSWITCH)) { - InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); - GE_CHECK_NOTNULL(in_cond_anchor); - OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) { - GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); - return FAILED; - } + if ((type != SWITCH) && (type != REFSWITCH)) { + continue; + } + InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); + GE_CHECK_NOTNULL(in_cond_anchor); + OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + if (FindSwitchCondInput(peer_out_anchor) != SUCCESS) { + GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); + return FAILED; + } - NodePtr cond_node = peer_out_anchor->GetOwnerNode(); - auto iter = cond_switch_map.find(cond_node); - if (iter == cond_switch_map.end()) { - cond_switch_map[cond_node] = { node }; - } else { - iter->second.emplace_back(node); - } - switch_nodes_.emplace_back(node); + NodePtr cond_node = peer_out_anchor->GetOwnerNode(); + auto iter = cond_switch_map.find(cond_node); + if (iter == cond_switch_map.end()) { + cond_switch_map[cond_node] = { node }; + } else { + iter->second.emplace_back(node); } + switch_nodes_.emplace_back(node); } MarkCycleDependence(cond_switch_map); @@ -241,10 +242,6 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou if (idx == SWITCH_DATA_INPUT) { peer_data_anchor = peer_out_anchor; } else { - if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) { - GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str()); - return FAILED; - } peer_cond_anchor = peer_out_anchor; } } @@ -254,15 +251,14 @@ Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, Ou /// /// @brief Find Switch cond input -/// @param [in] pass_switch_flag /// @param [out] peer_cond_anchor /// @return Status /// -Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) { +Status SwitchToStreamSwitchPass::FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor) { NodePtr tmp_node = nullptr; - string type; - bool need_pass_type = true; - while (need_pass_type) { + std::string type; + bool pass_flag = true; + while (pass_flag) { if (tmp_node == nullptr) { tmp_node = peer_cond_anchor->GetOwnerNode(); } else { @@ -274,7 +270,7 @@ Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutD } GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); - need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH))); + pass_flag = ((type == SWITCH) || (type == REFSWITCH)); } return SUCCESS; @@ -369,7 +365,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ } } else { int64_t switch_group_id = GetGroupId(stream_switch); - map>> switch_group_map; + std::map>> switch_group_map; std::list false_node_list; std::list true_node_list; std::list &node_list = true_branch_flag ? true_node_list : false_node_list; @@ -389,7 +385,7 @@ Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_ /// @return group_id /// int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { - string tailing_optimization_option; + std::string tailing_optimization_option; bool is_tailing_optimization = false; if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { // "1" means it's True from frontend option @@ -400,7 +396,7 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { return 0; } - string hccl_group_id; + std::string hccl_group_id; if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); return 0; @@ -432,6 +428,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); OutDataAnchorPtr peer_cond_anchor = iter->first; + GE_CHECK_NOTNULL(peer_cond_anchor); NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); @@ -549,6 +546,7 @@ NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, con NodePtr cast_node = graph->AddNode(cast_desc); GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); + // Cast node has and only has one input GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); return cast_node; @@ -614,24 +612,24 @@ Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_no return INTERNAL_ERROR; } - for (const NodePtr &in_ctl_node : switch_node->GetInControlNodes()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), + for (const NodePtr &in_ctrl_node : switch_node->GetInControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), "Remove ctl edge failed."); - GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), + GE_IF_BOOL_EXEC(!in_ctrl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), "Add ctl edge failed."); }); - GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); - if (same_cond_switch.count(in_ctl_node) > 0) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), + GE_IF_BOOL_EXEC(in_ctrl_node->GetType() != STREAMSWITCH, continue); + if (same_cond_switch.count(in_ctrl_node) > 0) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), "Remove ctl edge failed."); continue; } - auto find_res1 = switch_node_map_.find(in_ctl_node); + auto find_res1 = switch_node_map_.find(in_ctrl_node); GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { - GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctrl_node->GetName().c_str()); return INTERNAL_ERROR; }); auto find_res2 = find_res1->second.find(orig_switch_name); diff --git a/ge/graph/passes/switch_to_stream_switch_pass.h b/ge/graph/passes/switch_to_stream_switch_pass.h index 48725230..05628871 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.h +++ b/ge/graph/passes/switch_to_stream_switch_pass.h @@ -42,9 +42,9 @@ namespace ge { +-----------+ +-----------+ | Const | | VariableV2| +-----------+ +-----------+ -*/ -/* Switch branch op optimize, Switches in same case merge to one StreamSwitch, update following nodes' input + + Switch branch op optimize, Switches in same case merge to one StreamSwitch, update following nodes' input +-----------+ / | task2 | \ @@ -131,11 +131,10 @@ class SwitchToStreamSwitchPass : public GraphPass { /// /// @brief Find Switch cond input - /// @param [in] pass_switch_flag /// @param [out] peer_cond_anchor /// @return Status /// - Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor); + Status FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor); /// /// @brief Create StreamSwitch Node diff --git a/ge/graph/passes/transop_breadth_fusion_pass.cc b/ge/graph/passes/transop_breadth_fusion_pass.cc index 689510f0..654c3822 100644 --- a/ge/graph/passes/transop_breadth_fusion_pass.cc +++ b/ge/graph/passes/transop_breadth_fusion_pass.cc @@ -70,8 +70,10 @@ std::string TransOpBreadthFusionPass::GetNodeId(const int anchor_index, const No trans_data_type = true; trans_format = true; trans_shape = true; - } else if (node->GetType() == RESHAPE) { + } else if (node->GetType() == RESHAPE || node->GetType() == EXPANDDIMS || node->GetType() == SQUEEZE) { trans_shape = true; + } else if (node->GetType() == REFORMAT) { + trans_format = true; } id << node->GetType() << '-' << anchor_index; diff --git a/ge/graph/passes/transop_breadth_fusion_pass.h b/ge/graph/passes/transop_breadth_fusion_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/transop_depth_fusion_pass.cc b/ge/graph/passes/transop_depth_fusion_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/transop_depth_fusion_pass.h b/ge/graph/passes/transop_depth_fusion_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/transop_nearby_allreduce_fusion_pass.h b/ge/graph/passes/transop_nearby_allreduce_fusion_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/ge/graph/passes/transop_without_reshape_fusion_pass.cc index d2b3f1b1..6bea9edc 100644 --- a/ge/graph/passes/transop_without_reshape_fusion_pass.cc +++ b/ge/graph/passes/transop_without_reshape_fusion_pass.cc @@ -63,7 +63,7 @@ void TransOpWithoutReshapeFusionPass::SetRemainNode( continue; } GELOGI("SetRemainNode node is %s", op_desc->GetName().c_str()); - GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return ); + GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return); } } @@ -594,7 +594,7 @@ void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int inde auto out_owner_node = out_peer_anchor->GetOwnerNode(); GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node); auto out_peer_op_desc = out_owner_node->GetOpDesc(); - GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_peer_op_desc is nullptr"); return ); + GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_peer_op_desc is nullptr"); return); out_desc = out_peer_op_desc->GetInputDesc(out_peer_anchor->GetIdx()); auto in_peer_anchor = nodes_anchor.back().first; @@ -602,7 +602,7 @@ void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int inde auto in_owner_node = in_peer_anchor->GetOwnerNode(); GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node); auto in_peer_op_desc = in_owner_node->GetOpDesc(); - GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_peer_op_desc is nullptr"); return ); + GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_peer_op_desc is nullptr"); return); in_desc = in_peer_op_desc->GetOutputDesc(in_peer_anchor->GetIdx()); } @@ -734,10 +734,14 @@ void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &g continue; } - GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return ); + GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return); GELOGI("remove node:%s", node->GetName().c_str()); - if (graph->RemoveNode(node) != GRAPH_SUCCESS) { - GELOGW("remove node failed!node:%s", node->GetName().c_str()); + if (GraphUtils::IsolateNode(node, {0}) != GRAPH_SUCCESS) { + GELOGW("Isolate node: %s failed.", node->GetName().c_str()); + continue; + } + if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { + GELOGW("Remove node: %s failed.", node->GetName().c_str()); continue; } } diff --git a/ge/graph/passes/transop_without_reshape_fusion_pass.h b/ge/graph/passes/transop_without_reshape_fusion_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/transpose_transdata_pass.cc b/ge/graph/passes/transpose_transdata_pass.cc index 7348f143..2178eac7 100644 --- a/ge/graph/passes/transpose_transdata_pass.cc +++ b/ge/graph/passes/transpose_transdata_pass.cc @@ -217,11 +217,11 @@ void TransposeTransDataPass::CopyInputEdges(NodePtr &origin_node, NodePtr &new_n } OutDataAnchorPtr out_anchor = origin_node->GetInDataAnchor(0)->GetPeerOutAnchor(); new_in_data_anchor->UnlinkAll(); - GE_IF_BOOL_EXEC(new_in_data_anchor->LinkFrom(out_anchor) != GRAPH_SUCCESS, GELOGW("Link failed"); return ); + GE_IF_BOOL_EXEC(new_in_data_anchor->LinkFrom(out_anchor) != GRAPH_SUCCESS, GELOGW("Link failed"); return); // control anchor only link to control anchor GE_IF_BOOL_EXEC( - GraphUtils::CopyInCtrlEdges(origin_node, new_node) != GRAPH_SUCCESS, GELOGW("Copy in ctrl edges failed"); return ); + GraphUtils::CopyInCtrlEdges(origin_node, new_node) != GRAPH_SUCCESS, GELOGW("Copy in ctrl edges failed"); return); } bool TransposeTransDataPass::TransDataCheckAccuracySupported(const OpDescPtr &op_desc) { diff --git a/ge/graph/passes/unused_args_clean_pass.cc b/ge/graph/passes/unused_args_clean_pass.cc old mode 100755 new mode 100644 diff --git a/ge/graph/passes/unused_const_pass.h b/ge/graph/passes/unused_const_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/unused_op_remove_pass.h b/ge/graph/passes/unused_op_remove_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/var_is_initialized_op_pass.h b/ge/graph/passes/var_is_initialized_op_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/variable_format_pass.h b/ge/graph/passes/variable_format_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/variable_op_pass.h b/ge/graph/passes/variable_op_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/passes/variable_op_pass_bak.cc b/ge/graph/passes/variable_op_pass_bak.cc index 3e40e686..c9218296 100644 --- a/ge/graph/passes/variable_op_pass_bak.cc +++ b/ge/graph/passes/variable_op_pass_bak.cc @@ -252,7 +252,6 @@ Status VariableOpPass::RenewTransRoadDesc(const NodePtr &var, VarTransRoad &fusi // case 2: suppose input format of transdata not equal with out format // and input format not equal with var // so we make input format equal with var - for (auto &cur_trans : fusion_road) { if (cur_trans.input.GetFormat() == cur_trans.output.GetFormat()) { cur_trans.output.SetFormat(prev_node_info.output.GetFormat()); @@ -319,8 +318,8 @@ Status VariableOpPass::FusionIfNeed(const NodePtr &var, VarTransRoad &fusion_roa } Status VariableOpPass::UpdateTransRoad(VarTransRoad &fusion_road, vector &first_path_trans_order, - map> &trans_type_to_changed_desc, - map> &trans_type_to_trans_ops){ + map> &trans_type_to_changed_desc, + map> &trans_type_to_trans_ops){ vector delete_trans_type; for (auto &trans_type : first_path_trans_order) { if (trans_type_to_changed_desc.find(trans_type) == trans_type_to_changed_desc.end()) { diff --git a/ge/graph/passes/variable_op_pass_bak.h b/ge/graph/passes/variable_op_pass_bak.h index b9fbb90e..fccd063b 100644 --- a/ge/graph/passes/variable_op_pass_bak.h +++ b/ge/graph/passes/variable_op_pass_bak.h @@ -45,8 +45,8 @@ class VariableOpPass : public GraphPass { private: Status UpdateTransRoad(VarTransRoad &fusion_road, vector &trans_road_order, - map> &trans_type_to_changed_desc, - map> &trans_type_to_trans_ops); + map> &trans_type_to_changed_desc, + map> &trans_type_to_trans_ops); Status DealFusion(const ge::NodePtr &var_node, VarTransRoad &fusion_road, map> trans_type_to_changed_desc, diff --git a/ge/graph/passes/variable_ref_delete_op_pass.h b/ge/graph/passes/variable_ref_delete_op_pass.h old mode 100755 new mode 100644 diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index b899ee83..2ee5e330 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -1621,7 +1621,8 @@ Status GraphPrepare::CheckUserInput(const std::vector &user_input) { for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) { if (desc.GetShape().GetDim(i) < 0) { - std::string situation = "data dim[" + std::to_string(i) + "][" + std::to_string(desc.GetShape().GetDim(i)) + "]" ; + std::string situation = "data dim[" + std::to_string(i) + "][" + + std::to_string(desc.GetShape().GetDim(i)) + "]" ; std::string reason = "it need >= 0"; ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {situation, reason}); GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= 0, real:%ld.", i, @@ -1701,7 +1702,7 @@ Status GraphPrepare::PrepareOptimize() { try { (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); - (void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass" , new MarkAgnosticPass); + (void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass", new MarkAgnosticPass); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR; @@ -1796,6 +1797,16 @@ Status GraphPrepare::PrepareOptimize() { } void GraphPrepare::TypeConversionOfConstant() { + bool is_acl_compile = false; + for (ge::NodePtr &n : compute_graph_->GetAllNodes()) { + // This can ensure that n is not a null pointer + // No Conversion when called by aclOpCompile + (void)AttrUtils::GetBool(n->GetOpDesc(), ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, is_acl_compile); + if (is_acl_compile) { + return; + } + } + if (options_.train_graph_flag) { GELOGD("trans CONSTANT to CONSTANTOP in train."); for (ge::NodePtr &n : compute_graph_->GetAllNodes()) { diff --git a/ge/graph/preprocess/graph_preprocess.h b/ge/graph/preprocess/graph_preprocess.h old mode 100755 new mode 100644 diff --git a/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/ge/graph/preprocess/insert_op/ge_aipp_op.cc old mode 100755 new mode 100644 index 98712a82..7c8d9073 --- a/ge/graph/preprocess/insert_op/ge_aipp_op.cc +++ b/ge/graph/preprocess/insert_op/ge_aipp_op.cc @@ -408,7 +408,7 @@ Status AippOp::ConvertRelatedInputNameToRank() { GE_CHECK_NOTNULL(aipp_params_); string related_input_name = aipp_params_->related_input_name(); - if(related_input_name.empty()) { + if (related_input_name.empty()) { return SUCCESS; } diff --git a/ge/graph/preprocess/insert_op/ge_aipp_op.h b/ge/graph/preprocess/insert_op/ge_aipp_op.h old mode 100755 new mode 100644 diff --git a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc old mode 100755 new mode 100644 index 1b926e4b..3b37003f --- a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc +++ b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc @@ -470,7 +470,7 @@ Status InsertNewOpUtil::UpdateDataBySwitchN(const NodePtr &switchn, const NodePt } } if (max_index >= switchn->GetOpDesc()->GetOutputsSize()) { - string error_msg = "No max size found from switchn node[" + switchn->GetName()+ "]"; + string error_msg = "No max size found from switchn node[" + switchn->GetName() + "]"; GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error_msg.c_str()); return INTERNAL_ERROR; } diff --git a/ge/host_cpu_engine/CMakeLists.txt b/ge/host_cpu_engine/CMakeLists.txt index 97b5a0f5..d5ed7674 100644 --- a/ge/host_cpu_engine/CMakeLists.txt +++ b/ge/host_cpu_engine/CMakeLists.txt @@ -193,6 +193,7 @@ target_compile_options(host_cpu_opskernel_builder_static PRIVATE target_compile_definitions(host_cpu_opskernel_builder_static PRIVATE google=ascend_private + LOG_CPP ) target_include_directories(host_cpu_opskernel_builder_static PRIVATE diff --git a/ge/host_kernels/add_kernel.h b/ge/host_kernels/add_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/broadcast_args_kernel.h b/ge/host_kernels/broadcast_args_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/broadcast_gradient_args_kernel.h b/ge/host_kernels/broadcast_gradient_args_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/cast_kernel.h b/ge/host_kernels/cast_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/concat_offset_kernel.h b/ge/host_kernels/concat_offset_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/concat_v2_kernel.cc b/ge/host_kernels/concat_v2_kernel.cc index a9f0da81..234d8c8a 100644 --- a/ge/host_kernels/concat_v2_kernel.cc +++ b/ge/host_kernels/concat_v2_kernel.cc @@ -120,7 +120,7 @@ Status ConcatV2Kernel::ConcatV2PreCompute(const std::vector &i int &tidx, ConstGeTensorPtr &tensor) { size_t input_size = input.size(); - // N >= 2 and N + 1 >= 3 + // N + 1 is greater than or equal to 3 if (input_size < kConcatV2InputNum) { GELOGI("The number of input for ConcatV2 must not be less than %zu.", kConcatV2InputNum); return NOT_CHANGED; diff --git a/ge/host_kernels/concat_v2_kernel.h b/ge/host_kernels/concat_v2_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/empty_kernel.h b/ge/host_kernels/empty_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/expanddims_kernel.h b/ge/host_kernels/expanddims_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/fill_kernel.h b/ge/host_kernels/fill_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/floordiv_kernel.cc b/ge/host_kernels/floordiv_kernel.cc index e254af09..df381212 100644 --- a/ge/host_kernels/floordiv_kernel.cc +++ b/ge/host_kernels/floordiv_kernel.cc @@ -112,8 +112,8 @@ void FloorDivKernel::ShapeCal(const std::vector &input, Ge template T FloorDivKernel::DivCal(const T &x_i, const T &y_i) { if ((x_i < static_cast(0)) != (y_i < static_cast(0))) { - T abs_x_i = std::abs(x_i); - T abs_y_i = std::abs(y_i); + T abs_x_i = x_i < 0 ? -x_i : x_i; + T abs_y_i = y_i < 0 ? -y_i : y_i; return static_cast(static_cast(-(abs_x_i + abs_y_i - 1) / abs_y_i)); } else { return static_cast(static_cast(x_i / y_i)); diff --git a/ge/host_kernels/floordiv_kernel.h b/ge/host_kernels/floordiv_kernel.h old mode 100755 new mode 100644 index d3dc3ff7..b8f6dd12 --- a/ge/host_kernels/floordiv_kernel.h +++ b/ge/host_kernels/floordiv_kernel.h @@ -40,10 +40,6 @@ class FloorDivKernel : public Kernel { template Status DataCal(const std::vector &input, ge::GeTensorPtr output_ptr); Status ComputeByDataType(DataType data_type, const std::vector &input, GeTensorPtr output_ptr); - - int64_t axis_dim_; - int64_t head_dim_; - int64_t end_dim_; }; } // namespace ge diff --git a/ge/host_kernels/floormod_kernel.h b/ge/host_kernels/floormod_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/gather_v2_kernel.cc b/ge/host_kernels/gather_v2_kernel.cc index e52b4534..ee73626b 100644 --- a/ge/host_kernels/gather_v2_kernel.cc +++ b/ge/host_kernels/gather_v2_kernel.cc @@ -40,6 +40,10 @@ const size_t kGatherV2InpotNum = 3; const size_t kMaxIndicatesDims = 1; // only support scalar and 1 dims indicates_ const std::set supported_type = {DT_FLOAT16, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64}; +const int64_t DIM_AXIS_0 = 0; +const int64_t DIM_AXIS_1 = 1; +const int64_t DIM_AXIS_2 = 2; +const int64_t DIM_AXIS_3 = 3; } // namespace template Status GatherV2Kernel::ProcessAxis0(ConstGeTensorPtr tensor_x, GeTensorPtr output) { @@ -191,16 +195,16 @@ Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x Status ret = SUCCESS; switch (axis) { - case 0: + case DIM_AXIS_0: ret = ProcessAxis0(tensor_x, output); break; - case 1: + case DIM_AXIS_1: ret = ProcessAxis1(tensor_x, output); break; - case 2: + case DIM_AXIS_2: ret = ProcessAxis2(tensor_x, output); break; - case 3: + case DIM_AXIS_3: ret = ProcessAxis3(tensor_x, output); break; default: diff --git a/ge/host_kernels/gather_v2_kernel.h b/ge/host_kernels/gather_v2_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/greater_kernel.h b/ge/host_kernels/greater_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/kernel_utils.cc b/ge/host_kernels/kernel_utils.cc old mode 100755 new mode 100644 diff --git a/ge/host_kernels/kernel_utils.h b/ge/host_kernels/kernel_utils.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/maximum_kernel.h b/ge/host_kernels/maximum_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/mul_kernel.h b/ge/host_kernels/mul_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/pack_kernel.h b/ge/host_kernels/pack_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/permute_kernel.cc b/ge/host_kernels/permute_kernel.cc old mode 100755 new mode 100644 diff --git a/ge/host_kernels/permute_kernel.h b/ge/host_kernels/permute_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/range_kernel.cc b/ge/host_kernels/range_kernel.cc index 32a72b47..97254fff 100644 --- a/ge/host_kernels/range_kernel.cc +++ b/ge/host_kernels/range_kernel.cc @@ -32,6 +32,9 @@ namespace ge { namespace { constexpr size_t kRangeInputNum = 3; constexpr uint32_t kRangeDimNum = 0; +constexpr size_t kStartIndex = 0; +constexpr size_t kLimitIndex = 1; +constexpr size_t kDeltaIndex = 2; const std::set kRangeSupportedType = {DT_INT32, DT_FLOAT}; } // namespace @@ -53,9 +56,9 @@ Status RangeKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetTensorDesc().GetDataType(); if (data_type == DT_FLOAT) { if (GetRange(*reinterpret_cast(start->GetData().data()), diff --git a/ge/host_kernels/range_kernel.h b/ge/host_kernels/range_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/rank_kernel.cc b/ge/host_kernels/rank_kernel.cc old mode 100755 new mode 100644 diff --git a/ge/host_kernels/rank_kernel.h b/ge/host_kernels/rank_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/reduce_prod_kernel.h b/ge/host_kernels/reduce_prod_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/reformat_kernel.h b/ge/host_kernels/reformat_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/reshape_kernel.h b/ge/host_kernels/reshape_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/rsqrt_kernel.cc b/ge/host_kernels/rsqrt_kernel.cc old mode 100755 new mode 100644 diff --git a/ge/host_kernels/rsqrt_kernel.h b/ge/host_kernels/rsqrt_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/shape_kernel.h b/ge/host_kernels/shape_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/shape_n_kernel.h b/ge/host_kernels/shape_n_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/size_kernel.h b/ge/host_kernels/size_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/slice_d_kernel.h b/ge/host_kernels/slice_d_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/slice_kernel.h b/ge/host_kernels/slice_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/squeeze_kernel.h b/ge/host_kernels/squeeze_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/ssd_prior_box_kernel.cc b/ge/host_kernels/ssd_prior_box_kernel.cc index b3a0fc3e..3661fa9d 100644 --- a/ge/host_kernels/ssd_prior_box_kernel.cc +++ b/ge/host_kernels/ssd_prior_box_kernel.cc @@ -180,14 +180,18 @@ Status SsdPriorboxKernel::SetVariance(const vector &variance, const int d return SUCCESS; } -Status SsdPriorboxKernel::GetNumPriorAndDimSize(uint aspect_ratios_size, uint min_sizes_size, uint max_sizes_size, - int layer_width, int layer_height, int &num_priors, +Status SsdPriorboxKernel::GetNumPriorAndDimSize(uint32_t aspect_ratios_size, + uint32_t min_sizes_size, + uint32_t max_sizes_size, + int layer_width, + int layer_height, + int &num_priors, int &dim_size) const { if (ge::CheckUint32MulOverflow(min_sizes_size, aspect_ratios_size) != SUCCESS) { return PARAM_INVALID; } - uint tmp_value = aspect_ratios_size * min_sizes_size; + uint32_t tmp_value = aspect_ratios_size * min_sizes_size; if (ge::CheckUint32AddOverflow(tmp_value, max_sizes_size) != SUCCESS) { GELOGW("Failed to get list param."); return PARAM_INVALID; @@ -199,7 +203,7 @@ Status SsdPriorboxKernel::GetNumPriorAndDimSize(uint aspect_ratios_size, uint mi return PARAM_INVALID; } num_priors = static_cast(tmp_value); - + if (ge::CheckIntMulOverflow(layer_width, layer_height) != SUCCESS) { GELOGW("Failed to get list param."); return PARAM_INVALID; @@ -288,7 +292,7 @@ std::unique_ptr SsdPriorboxKernel::BoundaryCalulate(int dim_size, int l } } - return std::move(output_data); + return output_data; } Status SsdPriorboxKernel::Compute(const NodePtr &node, std::vector &v_output) { diff --git a/ge/host_kernels/ssd_prior_box_kernel.h b/ge/host_kernels/ssd_prior_box_kernel.h old mode 100755 new mode 100644 index 0ebf221d..c08217e2 --- a/ge/host_kernels/ssd_prior_box_kernel.h +++ b/ge/host_kernels/ssd_prior_box_kernel.h @@ -100,8 +100,8 @@ class SsdPriorboxKernel : public Kernel { * @return OTHERS: Execution failed * @author */ - Status GetNumPriorAndDimSize(uint aspect_ratios_size, uint min_sizes_size, uint max_sizes_size, int layer_width, - int layer_height, int &num_priors, int &dim_size) const; + Status GetNumPriorAndDimSize(uint32_t aspect_ratios_size, uint32_t min_sizes_size, uint32_t max_sizes_size, + int layer_width, int layer_height, int &num_priors, int &dim_size) const; void DataCalulate(float x, float y, float box_x, float box_y, int img_x, int img_y, vector &result); std::unique_ptr BoundaryCalulate(int dim_size, int layer_width, int layer_height, float step_width, float step_height, int img_width, int img_height, float offset, diff --git a/ge/host_kernels/strided_slice_kernel.cc b/ge/host_kernels/strided_slice_kernel.cc index 2fe74415..b1bfb10a 100644 --- a/ge/host_kernels/strided_slice_kernel.cc +++ b/ge/host_kernels/strided_slice_kernel.cc @@ -272,6 +272,10 @@ Status StridedSliceKernel::InitParamWithAttrs(const std::vector &x_dims) { auto begin_data_type_size = GetSizeByDataType(begin_tensor->GetTensorDesc().GetDataType()); + if (begin_data_type_size == 0) { + GELOGW("Param begin_data_type_size should not be zero."); + return; + } size_t begin_vec_size = begin_tensor->GetData().size() / begin_data_type_size; auto final_dim_num = x_dims_num < begin_vec_size ? begin_vec_size : x_dims_num; for (size_t i = 0; i < final_dim_num; i++) { @@ -284,8 +288,10 @@ void StridedSliceKernel::ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_ten } void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num, - const vector &x_dims, vector &orig_begin_vec, - vector &orig_end_vec, vector &orig_stride_vec) { + const vector &x_dims, + vector &orig_begin_vec, + vector &orig_end_vec, + vector &orig_stride_vec) { if (attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) != 0) { auto end_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK); @@ -308,7 +314,7 @@ void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num, if (orig_begin_vec.size() < x_dims_num) { for (size_t j = 1; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) { orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0); - orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j)); + orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim + j)); orig_stride_vec.insert((orig_stride_vec.begin() + ellipsis_dim + j), 1); } } diff --git a/ge/host_kernels/strided_slice_kernel.h b/ge/host_kernels/strided_slice_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/sub_kernel.h b/ge/host_kernels/sub_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/transdata_kernel.h b/ge/host_kernels/transdata_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/transpose_kernel.cc b/ge/host_kernels/transpose_kernel.cc old mode 100755 new mode 100644 diff --git a/ge/host_kernels/transpose_kernel.h b/ge/host_kernels/transpose_kernel.h old mode 100755 new mode 100644 diff --git a/ge/host_kernels/unpack_kernel.cc b/ge/host_kernels/unpack_kernel.cc old mode 100755 new mode 100644 diff --git a/ge/host_kernels/unpack_kernel.h b/ge/host_kernels/unpack_kernel.h old mode 100755 new mode 100644 diff --git a/ge/hybrid/common/npu_memory_allocator.cc b/ge/hybrid/common/npu_memory_allocator.cc index f506caec..2c38367a 100644 --- a/ge/hybrid/common/npu_memory_allocator.cc +++ b/ge/hybrid/common/npu_memory_allocator.cc @@ -23,6 +23,8 @@ namespace ge { namespace hybrid { +const size_t kPaddingUnit = 2; + size_t kMaxHbmMemorySize = 1024UL * 1024UL * 1024UL * 1024UL; // 1024G std::map> NpuMemoryAllocator::allocators_; @@ -77,7 +79,7 @@ void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { } } // padding up to multiple of padding, and add extra padding - allocate_size = (size + 2 * padding - 1) / padding * padding; + allocate_size = (size + kPaddingUnit * padding - 1) / padding * padding; GELOGD("Padding size %ld by %d. final size = %zu.", size, padding, allocate_size); buffer = MemManager::Instance() .CachingInstance(RT_MEMORY_HBM) diff --git a/ge/hybrid/executor/hybrid_execution_context.h b/ge/hybrid/executor/hybrid_execution_context.h index 0910d2c7..1fe40c77 100644 --- a/ge/hybrid/executor/hybrid_execution_context.h +++ b/ge/hybrid/executor/hybrid_execution_context.h @@ -57,7 +57,8 @@ struct GraphExecutionContext { do { \ if ((context != nullptr) && (context)->profiler != nullptr) { \ if (node_name != nullptr) { \ - context->profiler->RecordEvent(evt_type, "tid:%lu [%s] [%s] " fmt, GeLog::GetTid(), node_name, category, ##__VA_ARGS__);\ + context->profiler->RecordEvent(evt_type, "tid:%lu [%s] [%s] " fmt, GeLog::GetTid(), node_name, category, \ + ##__VA_ARGS__); \ } else { \ context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GeLog::GetTid(), category, ##__VA_ARGS__); \ }\ @@ -77,7 +78,7 @@ do { \ RECORD_PROFILING_EVENT((context), HybridProfiler::EXECUTION, fmt, "Execution", name, ##__VA_ARGS__) #define RECORD_CALLBACK_EVENT(context, name, fmt, ...) \ - RECORD_PROFILING_EVENT((context), HybridProfiler::CALLBACK, fmt, "Callback", name, ##__VA_ARGS__) + RECORD_PROFILING_EVENT((context), HybridProfiler::CALLBACKS, fmt, "Callback", name, ##__VA_ARGS__) } // namespace hybrid } // namespace ge #endif // GE_HYBRID_EXECUTOR_HYBRID_EXECUTION_CONTEXT_H_ diff --git a/ge/hybrid/executor/hybrid_model_async_executor.cc b/ge/hybrid/executor/hybrid_model_async_executor.cc index 91996ab3..ba717a2d 100644 --- a/ge/hybrid/executor/hybrid_model_async_executor.cc +++ b/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -379,11 +379,13 @@ Status HybridModelAsyncExecutor::Execute(const std::vector &inputs, } if (output_real_size > 0) { if (outputs[i].length < static_cast(output_real_size)) { - GELOGE(FAILED, "output idx[%zu], the memory size of output[%lu] given by user should be greater than or equal to the real size of output[%ld]", + GELOGE(FAILED, "output idx[%zu], the memory size of output[%lu] given by " + "user should be greater than or equal to the real size of output[%ld]", i, outputs[i].length, output_real_size); return FAILED; } - GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size, RT_MEMCPY_DEVICE_TO_DEVICE)); + GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size, + RT_MEMCPY_DEVICE_TO_DEVICE)); } outputs[i].length = output_real_size; } diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc old mode 100755 new mode 100644 index 4af34451..8ba687c2 --- a/ge/hybrid/executor/hybrid_model_executor.cc +++ b/ge/hybrid/executor/hybrid_model_executor.cc @@ -82,7 +82,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, Status HybridModelExecutor::Cleanup() { GELOGD("Start to cleanup."); context_.callback_manager->Destroy(); - RuntimeInferenceContext::DestroyContext(to_string(context_.session_id)); + RuntimeInferenceContext::DestroyContext(std::to_string(context_.session_id)); GELOGD("Cleanup successfully."); return SUCCESS; } diff --git a/ge/hybrid/executor/hybrid_profiler.cc b/ge/hybrid/executor/hybrid_profiler.cc index 7228197f..336a633f 100644 --- a/ge/hybrid/executor/hybrid_profiler.cc +++ b/ge/hybrid/executor/hybrid_profiler.cc @@ -25,7 +25,7 @@ namespace ge { namespace hybrid { namespace { const int kMaxEvents = 10000; -const int kEventDescMax = 256; +const int kEventDescMax = 512; const int kMaxEventTypes = 8; const int kIndent = 8; } diff --git a/ge/hybrid/executor/hybrid_profiler.h b/ge/hybrid/executor/hybrid_profiler.h index 62ef9c73..94a042e4 100644 --- a/ge/hybrid/executor/hybrid_profiler.h +++ b/ge/hybrid/executor/hybrid_profiler.h @@ -33,7 +33,7 @@ class HybridProfiler { SHAPE_INFERENCE, COMPILE, EXECUTION, - CALLBACK, + CALLBACKS }; struct Event { diff --git a/ge/hybrid/executor/node_done_manager.cc b/ge/hybrid/executor/node_done_manager.cc index c0b0b17b..f0d4324a 100644 --- a/ge/hybrid/executor/node_done_manager.cc +++ b/ge/hybrid/executor/node_done_manager.cc @@ -21,7 +21,7 @@ namespace ge { namespace hybrid { namespace { -constexpr int kDefaultWaitTimeoutInSec = 60 * 10; +constexpr int kDefaultWaitTimeoutInSec = 600; } bool NodeDoneManager::Cond::Await() { std::unique_lock lk(cond_mu_); diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index 48b2ed72..04f1ee4b 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -27,7 +27,7 @@ namespace ge { namespace hybrid { class NodeTask; -class GraphExecutionContext; +struct GraphExecutionContext; class SubgraphContext; class ShapeFuture { diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index 76a6cc37..5a464f8e 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -93,6 +93,7 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vectorGetName().c_str(), i); GE_CHECK_LE(i + 1, input_desc.size()); const auto &tensor_desc = input_desc[i]; + GE_CHECK_NOTNULL(tensor_desc); auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); GE_CHECK_NOTNULL(node_state); node_state->GetShapeInferenceState().UpdateInputShape(0, tensor_desc->GetOriginShape(), tensor_desc->GetShape()); diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc old mode 100755 new mode 100644 index e6729352..b984eec3 --- a/ge/hybrid/executor/worker/execution_engine.cc +++ b/ge/hybrid/executor/worker/execution_engine.cc @@ -260,8 +260,7 @@ Status NodeDoneCallback::ProfilingReport() { } auto &profiling_manager = ProfilingManager::Instance(); - profiling_manager.ReportProfilingData(model->GetModelId(), task_desc_info, compute_graph_info, - !profiling_manager.IsAclApiMode()); + profiling_manager.ReportProfilingData(model->GetModelId(), task_desc_info, compute_graph_info); return SUCCESS; } diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc old mode 100755 new mode 100644 index bd429b21..1d813526 --- a/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -62,7 +62,8 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { { std::lock_guard lk(mu_); RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); - GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), "Invoke InferShapeAndType failed."); + GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), + "Invoke InferShapeAndType failed."); RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); } // Check again to make sure shape is valid after shape inference @@ -164,7 +165,7 @@ Status ShapeInferenceEngine::InferShapeForSubgraph(const NodeItem &node_item, co for (auto &it : fused_subgraph.input_mapping) { auto parent_tensor_desc = node_item.MutableInputDesc(it.first); GE_CHECK_NOTNULL(parent_tensor_desc); - GELOGD("Start to update shape by input[%u]", it.first); + GELOGD("Start to update shape by input[%d]", it.first); GELOGD("Update shape to [%s]", parent_tensor_desc->GetShape().ToString().c_str()); GELOGD("Update original shape to [%s]", parent_tensor_desc->GetOriginShape().ToString().c_str()); for (auto &tensor_desc : it.second) { @@ -183,12 +184,12 @@ Status ShapeInferenceEngine::InferShapeForSubgraph(const NodeItem &node_item, co } for (auto &it : fused_subgraph.output_mapping) { - uint32_t parent_output_idx = it.first; + int parent_output_idx = it.first; const auto &op_desc = it.second; GELOGD("Update parent output[%d] by [%s]", parent_output_idx, op_desc->GetName().c_str()); auto input_desc = op_desc->MutableInputDesc(0); GE_CHECK_NOTNULL(input_desc); - auto parent_output_tensor_desc = node_item.op_desc->MutableOutputDesc(parent_output_idx); + auto parent_output_tensor_desc = node_item.MutableOutputDesc(parent_output_idx); GE_CHECK_NOTNULL(parent_output_tensor_desc); GELOGD("Update shape to [%s]", input_desc->GetShape().ToString().c_str()); GELOGD("Update original shape to [%s]", input_desc->GetOriginShape().ToString().c_str()); diff --git a/ge/hybrid/executor/worker/task_compile_engine.cc b/ge/hybrid/executor/worker/task_compile_engine.cc old mode 100755 new mode 100644 diff --git a/ge/hybrid/hybrid_davinci_model.cc b/ge/hybrid/hybrid_davinci_model.cc old mode 100755 new mode 100644 index a491c9a5..7009331c --- a/ge/hybrid/hybrid_davinci_model.cc +++ b/ge/hybrid/hybrid_davinci_model.cc @@ -113,8 +113,8 @@ HybridDavinciModel::~HybridDavinciModel() { delete impl_; } -unique_ptr HybridDavinciModel::Create(const GeRootModelPtr &ge_root_model) { - auto instance = unique_ptr(new (std::nothrow)HybridDavinciModel()); +std::unique_ptr HybridDavinciModel::Create(const GeRootModelPtr &ge_root_model) { + auto instance = std::unique_ptr(new (std::nothrow)HybridDavinciModel()); if (instance != nullptr) { instance->impl_ = new (std::nothrow) HybridDavinciModel::Impl(ge_root_model); if (instance->impl_ != nullptr) { diff --git a/ge/hybrid/model/hybrid_model.cc b/ge/hybrid/model/hybrid_model.cc index feb6757b..132b0f8c 100644 --- a/ge/hybrid/model/hybrid_model.cc +++ b/ge/hybrid/model/hybrid_model.cc @@ -176,20 +176,9 @@ Status HybridModel::GetInputOutputDescInfo(vector &input_de return SUCCESS; } -void HybridModel::SetInputDimsAndShapeRangesInfo(const vector &model_input_dims, std::vector> &shape_ranges, - Format &format, InputOutputDescInfo &input) { - uint32_t n, c, h, w; - n = format == FORMAT_NHWC ? NHWC_DIM_N : NCHW_DIM_N; - c = format == FORMAT_NHWC ? NHWC_DIM_C : NCHW_DIM_C; - h = format == FORMAT_NHWC ? NHWC_DIM_H : NCHW_DIM_H; - w = format == FORMAT_NHWC ? NHWC_DIM_W : NCHW_DIM_W; - - if (model_input_dims.size() == static_cast(NORMAL_TENSOR_SIZE)) { - input.shape_info.num = model_input_dims[n]; - input.shape_info.height = model_input_dims[h]; - input.shape_info.width = model_input_dims[w]; - input.shape_info.channel = model_input_dims[c]; - } +void HybridModel::SetInputDimsAndShapeRangesInfo(const vector &model_input_dims, + std::vector> &shape_ranges, + InputOutputDescInfo &input) { for (auto model_input_dim : model_input_dims) { input.shape_info.dims.push_back(model_input_dim); } @@ -197,25 +186,25 @@ void HybridModel::SetInputDimsAndShapeRangesInfo(const vector &model_in return; } -void HybridModel::CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input) { +void HybridModel::CreateInputDimsInfo(const OpDescPtr &op_desc, InputOutputDescInfo &input) { std::vector> shape_ranges; if (is_new_model_desc_ && op_desc->HasAttr(ATTR_NAME_INPUT_DIMS)) { // When static aipp is set, need to get the model input dims which processed by aipp vector model_input_dims; (void)AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_DIMS, model_input_dims); - SetInputDimsAndShapeRangesInfo(model_input_dims, shape_ranges, format, input); + SetInputDimsAndShapeRangesInfo(model_input_dims, shape_ranges, input); return; } // judge if this data is linked dynamic aipp first, multiply batch has been considered if (op_desc->HasAttr("_dynamic_aipp_input_dims")) { vector dynamic_aipp_input_dims; (void)AttrUtils::GetListInt(op_desc, "_dynamic_aipp_input_dims", dynamic_aipp_input_dims); - SetInputDimsAndShapeRangesInfo(dynamic_aipp_input_dims, shape_ranges, format, input); + SetInputDimsAndShapeRangesInfo(dynamic_aipp_input_dims, shape_ranges, input); return; } else { vector input_dims = op_desc->GetInputDescPtr(0)->GetShape().GetDims(); op_desc->GetInputDescPtr(0)->GetShapeRange(shape_ranges); - SetInputDimsAndShapeRangesInfo(input_dims, shape_ranges, format, input); + SetInputDimsAndShapeRangesInfo(input_dims, shape_ranges, input); return; } } @@ -248,7 +237,7 @@ Status HybridModel::GetInputDescInfo(vector &input_desc, st // not support dynamic shape input for now, so input_size here will be not less than zero. input.size = input_size; - CreateInputDimsInfo(op_desc, format, input); + CreateInputDimsInfo(op_desc, input); formats.push_back(format); input_desc.push_back(input); @@ -257,29 +246,15 @@ Status HybridModel::GetInputDescInfo(vector &input_desc, st return SUCCESS; } -void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, InputOutputDescInfo &output_desc_info, uint32_t &format_result) { +void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, + InputOutputDescInfo &output_desc_info, uint32_t &format_result) { GE_IF_BOOL_EXEC(output_desc == nullptr, GELOGE(FAILED, "output desc ptr is nullptr"); return ); Format format = output_desc->GetFormat(); GeShape shape = output_desc->GetShape(); std::vector> shape_ranges; output_desc->GetShapeRange(shape_ranges); DataType data_type = output_desc->GetDataType(); - int64_t dims[] = {1, 1, 1, 1}; format_result = format; - if (format == FORMAT_ND) { // for ND tensor - for (size_t i = 0; i < shape.GetDimNum() && i < (sizeof(dims) / sizeof(dims[0])); i++) { - dims[i] = shape.GetDim(i); - } - } else { // FOR FORMAT_NHWC or FORMAT_NCHW - dims[0] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_N : NCHW_DIM_N); // 0: first dim - dims[1] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_C : NCHW_DIM_C); // 1: second dim - dims[2] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_H : NCHW_DIM_H); // 2: third dim - dims[3] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_W : NCHW_DIM_W); // 3: forth dim - } - output_desc_info.shape_info.num = dims[0]; // 0: first dim - output_desc_info.shape_info.channel = dims[1]; // 1: second dim - output_desc_info.shape_info.height = dims[2]; // 2: third dim - output_desc_info.shape_info.width = dims[3]; // 3: forth dim if (format == FORMAT_FRACTAL_Z) { // FraczToHWCK int64_t k = shape.GetDim(0); // 0: first dim int64_t c = shape.GetDim(1); // 1: second dim @@ -310,7 +285,8 @@ void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, InputOutputDes Status HybridModel::GetOutputDescInfo(vector &output_desc, std::vector &formats) { std::vector output_desc_list; - GE_CHK_STATUS_RET(root_graph_item_->GetOutputDescList(output_desc_list), "get output desc info failed"); // output_desc_list contains vaild input desc + // output_desc_list contains vaild input desc + GE_CHK_STATUS_RET(root_graph_item_->GetOutputDescList(output_desc_list), "get output desc info failed"); vector out_node_names; (void)ge::AttrUtils::GetListStr(ge_root_model_->GetRootGraph(), ATTR_MODEL_OUT_NODES_NAME, out_node_names); @@ -320,7 +296,8 @@ Status HybridModel::GetOutputDescInfo(vector &output_desc, GE_CHECK_NOTNULL(op_desc); auto out_size = static_cast(op_desc->GetInputsSize()); - GE_CHK_BOOL_RET_STATUS(out_size == output_desc_list.size(), FAILED, "output size[%u] not match output_desc_list size[%zu]", out_size, output_desc_list.size()); + GE_CHK_BOOL_RET_STATUS(out_size == output_desc_list.size(), + FAILED, "output size[%u] not match output_desc_list size[%zu]", out_size, output_desc_list.size()); for (uint32_t index = 0; index < out_size; ++index) { string output_name; @@ -328,9 +305,11 @@ Status HybridModel::GetOutputDescInfo(vector &output_desc, std::vector src_index = op_desc->GetSrcIndex(); if (out_size == out_node_names.size()) { bool contains_colon = out_node_names[index].find(":") != std::string::npos; - output_name = contains_colon ? out_node_names[index] : out_node_names[index] + ":" + std::to_string(src_index[index]); + output_name = contains_colon ? out_node_names[index] : out_node_names[index] + + ":" + std::to_string(src_index[index]); } else { - output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] + "_" + std::to_string(src_index[index]); + output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] + + "_" + std::to_string(src_index[index]); } InputOutputDescInfo output_desc_info; diff --git a/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h index 1ec2f8a8..5fd5f8f5 100644 --- a/ge/hybrid/model/hybrid_model.h +++ b/ge/hybrid/model/hybrid_model.h @@ -100,12 +100,13 @@ class HybridModel { Status GetOutputDescInfo(vector &output_desc, std::vector &formats); - void CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input); + void CreateInputDimsInfo(const OpDescPtr &op_desc, InputOutputDescInfo &input); void SetModelDescVersion(bool is_new_model_desc) { is_new_model_desc_ = is_new_model_desc; } - void SetInputDimsAndShapeRangesInfo(const vector &model_input_dims, std::vector> &shape_ranges, - Format &format, InputOutputDescInfo &input); + void SetInputDimsAndShapeRangesInfo(const vector &model_input_dims, + std::vector> &shape_ranges, + InputOutputDescInfo &input); private: friend class HybridModelBuilder; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc old mode 100755 new mode 100644 index f4da3dcf..d519c35b --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -35,7 +35,6 @@ namespace hybrid { namespace { const uint32_t kSubgraphIndex = 0U; const uint32_t kVarOutputIndex = 0U; -const uint32_t kAlignment = 32; const int kBytes = 8; const char *const kOwnerGraphIsUnknown = "OwnerGraphIsUnknown"; @@ -339,9 +338,9 @@ Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { uint32_t parent_index = 0; if (!AttrUtils::GetInt(*op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { GELOGE(INTERNAL_ERROR, - "[%s] Failed to get attr [%s]", - op_desc->GetName().c_str(), - ATTR_NAME_PARENT_NODE_INDEX.c_str()); + "[%s] Failed to get attr [%s]", + op_desc->GetName().c_str(), + ATTR_NAME_PARENT_NODE_INDEX.c_str()); return INTERNAL_ERROR; } @@ -793,7 +792,7 @@ Status HybridModelBuilder::HandleDtString(const GeTensor &tensor, void *var_addr "Shape size is invalid"); auto offset = static_cast(elem_num * kBytes); auto hbm_raw_data_base_addr = - reinterpret_cast(reinterpret_cast(var_addr) + offset); + static_cast(reinterpret_cast(var_addr) + offset); for (int64_t i = elem_num - 1; i >= 0; --i) { buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]); } @@ -987,7 +986,7 @@ Status HybridModelBuilder::IndexTaskDefs() { // index task defs GELOGD("To index tasks for subgraph: %s", name.c_str()); - unordered_map node_map; + std::unordered_map node_map; for (const auto &node : sub_graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 8fac4a73..8fbdc648 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -30,8 +30,8 @@ class NodeTask; class NodeExecutor; struct FusedSubgraph { - std::map> input_mapping; - std::map output_mapping; + std::map> input_mapping; + std::map output_mapping; std::vector nodes; ComputeGraphPtr graph; }; diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc old mode 100755 new mode 100644 index 3b87c8b8..407210cf --- a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc +++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc @@ -15,7 +15,7 @@ */ #include "aicore_node_executor.h" -#include "cce/taskdown_common.hpp" +#include "framework/common/taskdown_common.h" #include "hybrid/executor/hybrid_execution_context.h" namespace ge { diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.h b/ge/hybrid/node_executor/aicore/aicore_node_executor.h old mode 100755 new mode 100644 index 989090e9..9e92a160 --- a/ge/hybrid/node_executor/aicore/aicore_node_executor.h +++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.h @@ -89,7 +89,7 @@ class TaskCompilerFactory { class CompilerFunctionRegistrar { public: - CompilerFunctionRegistrar(CreateFn fn); + explicit CompilerFunctionRegistrar(CreateFn fn); ~CompilerFunctionRegistrar() = default; }; } // namespace hybrid diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/ge/hybrid/node_executor/aicore/aicore_op_task.cc index 7ed14309..80ea579b 100644 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.cc +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.cc @@ -15,7 +15,7 @@ */ #include "hybrid/node_executor/aicore/aicore_op_task.h" -#include "cce/taskdown_common.hpp" +#include "framework/common/taskdown_common.h" #include "framework/common/debug/log.h" #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/node_executor/aicore/aicore_task_builder.h" @@ -38,7 +38,7 @@ Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) } Status AiCoreOpTask::RegisterTbeHandle(const OpDesc &op_desc) { - auto op_desc_ptr = make_shared(op_desc); + auto op_desc_ptr = std::make_shared(op_desc); GE_CHECK_NOTNULL(op_desc_ptr); auto tbe_kernel = op_desc_ptr->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); if (tbe_kernel == nullptr) { @@ -151,8 +151,8 @@ Status AiCoreOpTask::ValidateTaskDef(const domi::TaskDef &task_def) { const domi::KernelDef &kernel_def = task_def.kernel(); const domi::KernelContext &context = kernel_def.context(); - auto kernel_type = static_cast(context.kernel_type()); - if (kernel_type != cce::ccKernelType::TE) { + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type != ccKernelType::TE) { GELOGE(INTERNAL_ERROR, "Invalid kernel type(%d) in AiCore TaskDef.", static_cast(kernel_type)); return INTERNAL_ERROR; } diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.h b/ge/hybrid/node_executor/aicore/aicore_op_task.h old mode 100755 new mode 100644 diff --git a/ge/hybrid/node_executor/aicore/aicore_task_builder.cc b/ge/hybrid/node_executor/aicore/aicore_task_builder.cc old mode 100755 new mode 100644 diff --git a/ge/hybrid/node_executor/aicore/aicore_task_builder.h b/ge/hybrid/node_executor/aicore/aicore_task_builder.h old mode 100755 new mode 100644 diff --git a/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc b/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc old mode 100755 new mode 100644 diff --git a/ge/hybrid/node_executor/aicore/aicore_task_compiler.h b/ge/hybrid/node_executor/aicore/aicore_task_compiler.h old mode 100755 new mode 100644 index bf948349..b6dfd82b --- a/ge/hybrid/node_executor/aicore/aicore_task_compiler.h +++ b/ge/hybrid/node_executor/aicore/aicore_task_compiler.h @@ -26,7 +26,7 @@ namespace hybrid { class AiCoreTaskCompiler : public TaskCompiler { public: AiCoreTaskCompiler() = default; - ~AiCoreTaskCompiler() = default; + ~AiCoreTaskCompiler() override = default; Status CompileOp(const NodePtr &node, std::vector &tasks) override; Status Initialize() override; diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc old mode 100755 new mode 100644 index 1a47e525..7330f616 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc +++ b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc @@ -15,7 +15,7 @@ */ #include "hybrid/node_executor/aicpu/aicpu_node_executor.h" -#include "cce/taskdown_common.hpp" +#include "framework/common/taskdown_common.h" #include "common/formats/formats.h" #include "aicpu/common/aicpu_task_struct.h" #include "graph/load/new_model_manager/model_manager.h" @@ -642,10 +642,14 @@ Status AicpuNodeTask::Init(const HybridModel &model) { const std::string &so_name = kernel_def.so_name(); const OpDescPtr op_desc = node_item_->GetOpDesc(); const auto &context = kernel_def.context(); - auto kernel_type = static_cast(context.kernel_type()); - if (kernel_type == cce::ccKernelType::CUST_AI_CPU) { - GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name), "load cust aicpu so failed."); - GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "Launch cust aicpu so failed."); + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type == ccKernelType::CUST_AI_CPU) { + bool loaded = false; + GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name, loaded), + "load cust aicpu so failed."); + if (!loaded) { + GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "Launch cust aicpu so failed."); + } } GE_CHK_BOOL_RET_STATUS(args.size() == args_size_, FAILED, @@ -723,9 +727,9 @@ Status AicpuNodeTask::UpdateIoAddr(TaskContext &context) { auto io_addr = args_.get() + sizeof(aicpu::AicpuParamHead); // if has input and output, need copy to ioaddr - error_t cpy_ret = memcpy_s(io_addr, args_size_ - sizeof(aicpu::AicpuParamHead), - &io_addrs[0], sizeof(uint64_t) * io_addrs.size()); - GE_CHK_BOOL_RET_STATUS(cpy_ret == EOK, INTERNAL_ERROR, + int cpy_ret = memcpy_s(io_addr, args_size_ - sizeof(aicpu::AicpuParamHead), + &io_addrs[0], sizeof(uint64_t) * io_addrs.size()); + GE_CHK_BOOL_RET_STATUS(cpy_ret == 0, INTERNAL_ERROR, "Node[%s] memcpy io addr to AicpuParamHead failed, ret=%d, args_size=%u, io nums=%zu.", node_name_.c_str(), cpy_ret, args_size_, io_addrs.size()); return SUCCESS; @@ -736,9 +740,9 @@ Status AicpuNodeTask::LaunchTask(TaskContext &context) { const auto &so_name = task_def_.kernel().so_name(); const auto &kernel_name = task_def_.kernel().kernel_name(); const auto &kcontext = task_def_.kernel().context(); - auto kernel_type = static_cast(kcontext.kernel_type()); + auto kernel_type = static_cast(kcontext.kernel_type()); uint32_t flag = RT_KERNEL_DEFAULT; - if (kernel_type == cce::ccKernelType::CUST_AI_CPU) { + if (kernel_type == ccKernelType::CUST_AI_CPU) { flag |= static_cast(RT_KERNEL_CUSTOM_AICPU); } auto rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name.c_str()), diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h index b984cc86..1205b190 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h +++ b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h @@ -37,6 +37,8 @@ class AicpuNodeTaskBase : public NodeTask { ~AicpuNodeTaskBase() override = default; + using NodeTask::Init; + virtual Status Init(const HybridModel &model) = 0; Status UpdateArgs(TaskContext &context) override; diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc old mode 100755 new mode 100644 diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h index fb1966b4..2dde993b 100644 --- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h +++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h @@ -27,7 +27,7 @@ class HybridModel; class KnownNodeTask : public NodeTask { public: - KnownNodeTask(std::shared_ptr davinci_model) + explicit KnownNodeTask(std::shared_ptr davinci_model) : davinci_model_(davinci_model) {} diff --git a/ge/hybrid/node_executor/controlop/control_op_executor.cc b/ge/hybrid/node_executor/controlop/control_op_executor.cc index 83fc09ee..74920b22 100644 --- a/ge/hybrid/node_executor/controlop/control_op_executor.cc +++ b/ge/hybrid/node_executor/controlop/control_op_executor.cc @@ -405,7 +405,7 @@ Status ControlOpNodeExecutor::LoadTask(const HybridModel &model, auto node_item = model.GetNodeItem(node); GE_CHECK_NOTNULL(node_item); - unique_ptr node_task; + std::unique_ptr node_task; auto node_type = node->GetType(); if (node_type == IF || node_type == STATELESSIF) { node_task.reset(new(std::nothrow) IfOpNodeTask()); diff --git a/ge/hybrid/node_executor/controlop/control_op_executor.h b/ge/hybrid/node_executor/controlop/control_op_executor.h index 7520afd1..3becfaaa 100644 --- a/ge/hybrid/node_executor/controlop/control_op_executor.h +++ b/ge/hybrid/node_executor/controlop/control_op_executor.h @@ -25,6 +25,7 @@ namespace ge { namespace hybrid { class ControlOpNodeTask : public NodeTask { public: + using NodeTask::Init; virtual Status Init(const NodePtr &node, const HybridModel &model) = 0; Status UpdateArgs(TaskContext &context) override; diff --git a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc old mode 100755 new mode 100644 index ee45964c..a52e5670 --- a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc +++ b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc @@ -61,18 +61,18 @@ Status RefInputTask::Execute(TaskContext &context) { Status RefInputTask::RefOneByOne(TaskContext &context) { GELOGI("node %s type %s ref input one by one begin.", node_name_.c_str(), node_type_.c_str()); - uint32_t input_num = context.NumInputs(); - uint32_t output_num = context.NumOutputs(); + int input_num = context.NumInputs(); + int output_num = context.NumOutputs(); if (output_num > input_num) { - GELOGE(INTERNAL_ERROR, "node %s type %s has %u outputs but only %u inputs, can't ref one by one.", + GELOGE(INTERNAL_ERROR, "node %s type %s has %d outputs but only %d inputs, can't ref one by one.", node_name_.c_str(), node_type_.c_str(), output_num, input_num); return INTERNAL_ERROR; } - for (uint32_t out_index = 0; out_index < output_num; ++out_index) { + for (uint32_t out_index = 0; out_index < static_cast(output_num); ++out_index) { auto input = context.GetInput(out_index); GE_CHECK_NOTNULL(input); GE_CHK_STATUS_RET(context.SetOutput(out_index, *input)); - GELOGD("node %s type %s output[%u] ref input[%u] addr=%p.", + GELOGD("node %s type %s output[%d] ref input[%d] addr=%p.", node_name_.c_str(), node_type_.c_str(), out_index, out_index, input->GetData()); } GELOGI("node %s type %s ref input one by one end.", node_name_.c_str(), node_type_.c_str()); diff --git a/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc b/ge/hybrid/node_executor/host_cpu/host_cpu_node_executor.cc old mode 100755 new mode 100644 diff --git a/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc b/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc index 3bf71013..01fd391d 100644 --- a/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc +++ b/ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc @@ -20,7 +20,6 @@ #include "hybrid/node_executor/host_cpu/kernel_factory.h" namespace { -const size_t kAssignInputNum = 2; const size_t kAssignRefInputIndex = 0; const size_t kAssignValueInputIndex = 1; const size_t kAssignRefOutputIndex = 0; diff --git a/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc b/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc old mode 100755 new mode 100644 diff --git a/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h b/ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.h old mode 100755 new mode 100644 diff --git a/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc old mode 100755 new mode 100644 index e577f09b..95e50c31 --- a/ge/hybrid/node_executor/node_executor.cc +++ b/ge/hybrid/node_executor/node_executor.cc @@ -34,7 +34,6 @@ const char *const kEngineNameAiCpuTf = "aicpu_tf_kernel"; const char *const kEngineNameHccl = "ops_kernel_info_hccl"; const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE"; -const char *const kOwnerGraphIsUnknown = "OwnerGraphIsUnknown"; } Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); diff --git a/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc b/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc old mode 100755 new mode 100644 diff --git a/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h b/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h index 9ea544a1..73873002 100644 --- a/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h +++ b/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h @@ -41,7 +41,6 @@ class PartitionedCallNodeTask : public NodeTask { const GraphItem *graph_item_; std::unique_ptr subgraph_executor_; - GraphExecutionContext *context_ = nullptr; }; class PartitionedCallNodeExecutor : public NodeExecutor { diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc index b7152878..77004f99 100644 --- a/ge/hybrid/node_executor/task_context.cc +++ b/ge/hybrid/node_executor/task_context.cc @@ -233,9 +233,7 @@ Status TaskContext::AllocateOutput(int index, } else { GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr)); GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", - node_item_->NodeName().c_str(), - index, - outputs_start_[index].GetSize()); + node_item_->NodeName().c_str(), index, outputs_start_[index].GetSize()); } } } diff --git a/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h index 2cff0536..0549a1dc 100644 --- a/ge/hybrid/node_executor/task_context.h +++ b/ge/hybrid/node_executor/task_context.h @@ -29,7 +29,7 @@ namespace ge { namespace hybrid { -class GraphExecutionContext; +struct GraphExecutionContext; class SubgraphContext; class TaskContext { diff --git a/ge/init/gelib.cc b/ge/init/gelib.cc old mode 100755 new mode 100644 index 306a804a..92700179 --- a/ge/init/gelib.cc +++ b/ge/init/gelib.cc @@ -485,11 +485,9 @@ Status GELib::Finalize() { void GELib::ShutDownProfiling() { std::lock_guard lock(status_mutex_); - if (!ProfilingManager::Instance().ProfilingOpTraceOn() && ProfilingManager::Instance().ProfilingOn()) { - ProfilingManager::Instance().StopProfiling(); - } if (ProfilingManager::Instance().ProfilingOn()) { - ProfilingManager::Instance().PluginUnInit(GE_PROFILING_MODULE); + ProfilingManager::Instance().StopProfiling(); + ProfilingManager::Instance().PluginUnInit(); } } diff --git a/ge/ir_build/atc_ir_common.cc b/ge/ir_build/atc_ir_common.cc old mode 100755 new mode 100644 diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index 74aa6a60..f181170c 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -49,6 +49,8 @@ const std::string IR_OPTION_LOG_LEVEL_DEFAULT = "default"; const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize"; const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; +const std::string kInputShape = "input_shape"; +const std::string kInputFormat = "input_format"; } // namespace static graphStatus CheckGlobalOptions(std::map &global_options) { @@ -225,7 +227,9 @@ class Impl { ~Impl() { (void)generator_.Finalize(); }; graphStatus CheckOptions(const std::map &options); graphStatus CreateInputsForIRBuild(const ge::Graph &graph, vector &inputs); - graphStatus Init(const std::map &options); + graphStatus GetDefaultInputShape(const Graph &graph, string &default_shape); + graphStatus UpdateDataOpAttr(const Graph &graph); + graphStatus Init(const Graph &graph, const std::map &options); graphStatus BuildModel(const Graph &graph, const std::map &options, ModelBufferData &ge_models); graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, @@ -240,6 +244,40 @@ class Impl { OmgContext omg_context_; }; +graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { + GELOGD("Enter Update Data Attr Process!"); + if (options_.find(kInputShape) == options_.end()) { + return GRAPH_SUCCESS; + } + unordered_map> shape_map; + vector>> user_shape_map; + GE_CHK_BOOL_EXEC(ParseInputShape(options_[kInputShape], shape_map, user_shape_map, true), + return GRAPH_PARAM_INVALID, "parse input shape failed!"); + auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(input_node); + ge::OpDescPtr op = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op); + if (op->GetType() == DATA) { + auto tensor_input = op->MutableInputDesc(0); + auto tensor_output = op->MutableOutputDesc(0); + GE_CHECK_NOTNULL(tensor_input); + GE_CHECK_NOTNULL(tensor_output); + string data_op_name = op->GetName(); + auto iter = shape_map.find(data_op_name); + if (iter != shape_map.end()) { + tensor_input->SetShape(ge::GeShape(iter->second)); + tensor_output->SetShape(ge::GeShape(iter->second)); + GELOGD("update input [%s] shape info", data_op_name.c_str()); + } else { + GELOGI("no need update input [%s] attr because not found from input_shape.", data_op_name.c_str()); + } + } + } + return GRAPH_SUCCESS; +} + graphStatus Impl::CheckOptions(const std::map &options) { for (auto &ele : options) { auto it = ge::ir_option::ir_builder_suppported_options.find(ele.first); @@ -275,17 +313,61 @@ graphStatus Impl::CheckOptions(const std::map &options return GRAPH_PARAM_INVALID; } } + // Check option EXEC_DISABLE_REUSED_MEMORY + it = options_.find(ge::ir_option::EXEC_DISABLE_REUSED_MEMORY); + if (it != options_.end() && (CheckDisableReuseMemoryParamValid(it->second) != GRAPH_SUCCESS)) { + return GRAPH_PARAM_INVALID; + } + return GRAPH_SUCCESS; +} + +graphStatus Impl::GetDefaultInputShape(const Graph &graph, string &default_shape) { + auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(input_node); + ge::OpDescPtr op = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op); + if (op->GetType() == DATA) { + string data_op_name = op->GetName(); + GELOGD("Data op name: %s, data op inputDesc size: %zu", data_op_name.c_str(), op->GetAllInputsDesc().size()); + ge::GeTensorDesc tensor = op->GetInputDesc(0); + ge::GeShape data_shape = tensor.GetShape(); + GELOGD("Data op get shape from InputDesc in ge ir graph."); + + string tmp_shape_str; + const std::vector &tmp_shape = data_shape.GetDims(); + if (tmp_shape.empty()) { + GELOGW("Data op: %s has zero shape dims!", data_op_name.c_str()); + } else { + tmp_shape_str += data_op_name + ":"; + for (auto tmp_dim : tmp_shape) { + tmp_shape_str += to_string((long)tmp_dim) + ","; + } + tmp_shape_str = tmp_shape_str.substr(0, tmp_shape_str.size() - 1); + tmp_shape_str += ";"; + default_shape += tmp_shape_str; + } + + GELOGD("Data op name: %s, data shape: %s.", data_op_name.c_str(), tmp_shape_str.c_str()); + } + } + default_shape = (default_shape.empty() ? default_shape : default_shape.substr(0, default_shape.size() - 1)); + GELOGI("Get default data op shape: %s from ge ir graph.", default_shape.c_str()); return GRAPH_SUCCESS; } -graphStatus Impl::Init(const std::map &options) { +graphStatus Impl::Init(const Graph &graph, const std::map &options) { // 1. check options graphStatus ret = CheckOptions(options); if (ret != GRAPH_SUCCESS) { GELOGE(ret, "User input options are illegal! Please check!"); return ret; } - + ret = UpdateDataOpAttr(graph); + if (ret != GRAPH_SUCCESS) { + return ret; + } std::string build_mode = (options_.find(BUILD_MODE) == options_.end() || options_[BUILD_MODE] == BUILD_MODE_NORMAL) ? "" : options_[BUILD_MODE]; options_[BUILD_MODE] = build_mode; @@ -296,7 +378,13 @@ graphStatus Impl::Init(const std::map &options) { GE_CHK_BOOL_RET_STATUS_NOLOG(ge::CheckLogParamValidAndSetLogLevel(log) == 0, GRAPH_PARAM_INVALID); options_[ge::ir_option::LOG_LEVEL] = log; - string input_shape = options_.find("input_shape") == options_.end() ? "" : options_["input_shape"]; + string input_shape; + if (options_.find("input_shape") == options_.end()) { + GE_CHK_BOOL_EXEC(GetDefaultInputShape(graph, input_shape) == ge::SUCCESS, + return ge::GRAPH_PARAM_INVALID, "Get default data op shape from graph failed!"); + } else { + input_shape = options_["input_shape"]; + } string input_format = options_.find("input_format") == options_.end() ? "" : options_["input_format"]; string net_format = options_.find("net_format") == options_.end() ? "" : options_["net_format"]; string dynamic_batch_size = options_.find(ge::ir_option::DYNAMIC_BATCH_SIZE) == options_.end() @@ -416,7 +504,7 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector &options, ModelBufferData &model) { // 1. init GeGenerator with user optios - graphStatus ret = Init(options); + graphStatus ret = Init(graph, options); if (ret != GRAPH_SUCCESS) { GELOGE(ret, "Build ir model Init failed!"); return ret; @@ -502,7 +590,7 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m GELOGE(GRAPH_PARAM_INVALID, "input model is illegal"); return GRAPH_PARAM_INVALID; } - return FileSaver::SaveToFile((output_file + ".om"), reinterpret_cast(model.data.get()), + return FileSaver::SaveToFile((output_file + ".om"), reinterpret_cast(model.data.get()), static_cast(model.length)); } @@ -517,7 +605,7 @@ graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &mod return GRAPH_PARAM_INVALID; } std::string str_output_file = output_file; - return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast(model.data.get()), + return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast(model.data.get()), static_cast(model.length)); } @@ -543,7 +631,7 @@ graphStatus aclgrphInferShapeAndType(ge::Graph &graph) { } auto ret = compute_graph->TopologicalSorting(); - if(ret != GRAPH_SUCCESS) { + if (ret != GRAPH_SUCCESS) { GELOGE(ret, "Acl topo logical sort failed."); return ret; } @@ -622,4 +710,52 @@ graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const siz return GRAPH_SUCCESS; } +graphStatus aclgrphGenerateForOp(const AscendString &op_type, const vector &inputs, + const vector &outputs, Graph &graph) { + auto op_type_str = std::string(op_type.GetString()); + auto op_name = op_type_str + "_" + std::to_string(ge::GetCurrentTimestamp()); + auto op_desc = ge::MakeShared(op_name, op_type_str); + GE_CHECK_NOTNULL(op_desc); + + // convert input tensordesc to getensor + std::vector input_tensors; + for (const auto &input : inputs) { + ge::GeTensorDesc tensor_desc(ge::GeShape(input.GetShape().GetDims()), input.GetFormat(), input.GetDataType()); + + tensor_desc.SetOriginFormat(input.GetFormat()); + ge::TensorUtils::SetRealDimCnt(tensor_desc, static_cast(input.GetShape().GetDims().size())); + ge::TensorUtils::SetInputTensor(tensor_desc, true); + ge::TensorUtils::SetOutputTensor(tensor_desc, false); + + if (op_desc->AddInputDesc(tensor_desc) != ge::GRAPH_SUCCESS) { + GELOGE(ge::FAILED, "AddInputDesc fail."); + return ge::FAILED; + } + input_tensors.emplace_back(tensor_desc); + } + + // convert output tensordesc to getensor + std::vector output_tensors; + for (const auto &output : outputs) { + ge::GeTensorDesc tensor_desc(ge::GeShape(output.GetShape().GetDims()), output.GetFormat(), output.GetDataType()); + + tensor_desc.SetOriginFormat(output.GetFormat()); + ge::TensorUtils::SetRealDimCnt(tensor_desc, static_cast(output.GetShape().GetDims().size())); + ge::TensorUtils::SetInputTensor(tensor_desc, false); + ge::TensorUtils::SetOutputTensor(tensor_desc, true); + + (void)op_desc->AddOutputDesc(tensor_desc); + output_tensors.emplace_back(tensor_desc); + } + + // call api to get graph + ge::GeGenerator generator; + std::string graph_name = ge::CurrentTimeInStr() + "_graph"; + if (generator.BuildSingleOpGraph(op_desc, input_tensors, output_tensors, graph_name, graph) != ge::SUCCESS) { + GELOGE(GRAPH_FAILED, "make graph fail."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + } // namespace ge diff --git a/ge/model/ge_model.cc b/ge/model/ge_model.cc old mode 100755 new mode 100644 diff --git a/ge/model/ge_model.h b/ge/model/ge_model.h old mode 100755 new mode 100644 diff --git a/ge/model/ge_root_model.h b/ge/model/ge_root_model.h old mode 100755 new mode 100644 index 53174064..aa5a4d47 --- a/ge/model/ge_root_model.h +++ b/ge/model/ge_root_model.h @@ -23,6 +23,7 @@ namespace ge { class GeRootModel { public: + GeRootModel() = default; explicit GeRootModel(ComputeGraphPtr &root_graph) : root_graph_(root_graph), model_id_(INVALID_MODEL_ID) {}; ~GeRootModel() = default; @@ -35,11 +36,11 @@ class GeRootModel { void SetModelId(uint32_t model_id) { model_id_ = model_id; } uint32_t GetModelId() const { return model_id_; } Status CheckIsUnknownShape(bool &is_dynamic_shape); - + void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; } private: - ComputeGraphPtr root_graph_; + ComputeGraphPtr root_graph_ = nullptr; std::map subgraph_instance_name_to_model_; - uint32_t model_id_; + uint32_t model_id_ = 0; }; } // namespace ge using GeRootModelPtr = std::shared_ptr; diff --git a/ge/module.mk b/ge/module.mk old mode 100755 new mode 100644 diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt index 49af37c0..af259ecb 100644 --- a/ge/offline/CMakeLists.txt +++ b/ge/offline/CMakeLists.txt @@ -11,13 +11,13 @@ set(SRC_LIST "main.cc" "single_op_parser.cc" "../session/omg.cc" - "../ir_build/atc_ir_common.cc" + "../ir_build/atc_ir_common.cc" ) ############ atc ############ add_executable(atc ${SRC_LIST} ${PROTO_HDRS}) -target_compile_options(atc PRIVATE +target_compile_options(atc PRIVATE -Werror -O2 -Wno-deprecated-declarations @@ -27,6 +27,7 @@ target_compile_definitions(atc PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 COMPILE_OMG_PACKAGE google=ascend_private + LOG_CPP ) target_include_directories(atc PRIVATE @@ -74,6 +75,138 @@ target_link_libraries(atc PRIVATE -ldl ) +############ atc_atc.bin ############ +add_executable(atc_atc.bin ${SRC_LIST} ${PROTO_HDRS}) + +target_compile_options(atc_atc.bin PRIVATE + -Werror + -O2 + -Wno-deprecated-declarations +) + +target_compile_definitions(atc_atc.bin PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + COMPILE_OMG_PACKAGE + google=ascend_private + LOG_CPP +) + +target_include_directories(atc_atc.bin PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/common/inc/external + ${GE_CODE_DIR}/common/inc/external/graph + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc/register + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/external/register + ${PARSER_DIR} + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + ${GE_CODE_DIR}/../inc/common + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc + ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain +) + +target_link_libraries(atc_atc.bin PRIVATE + $ + ascend_protobuf + ge_common + register + c_sec + graph + error_manager + ge_compiler + parser_common + gflags + json + runtime_compile + slog + static_mmpa + -lrt + -ldl +) + +set_target_properties(atc_atc.bin PROPERTIES + OUTPUT_NAME atc.bin + RUNTIME_OUTPUT_DIRECTORY atclib +) + +############ fwk_atc.bin ############ +add_executable(fwk_atc.bin ${SRC_LIST} ${PROTO_HDRS}) + +target_compile_options(fwk_atc.bin PRIVATE + -Werror + -O2 + -Wno-deprecated-declarations +) + +target_compile_definitions(fwk_atc.bin PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + COMPILE_OMG_PACKAGE + google=ascend_private + LOG_CPP +) + +target_include_directories(fwk_atc.bin PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/common/inc/external + ${GE_CODE_DIR}/common/inc/external/graph + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc/register + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/external/register + ${PARSER_DIR} + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + ${GE_CODE_DIR}/../inc/common + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc + ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain +) + +target_link_libraries(fwk_atc.bin PRIVATE + $ + ascend_protobuf + ge_common + register + c_sec + graph + error_manager + ge_runner + parser_common + gflags + json + runtime + slog + static_mmpa + -lrt + -ldl +) + +set_target_properties(fwk_atc.bin PROPERTIES + OUTPUT_NAME atc.bin + RUNTIME_OUTPUT_DIRECTORY fwkacl +) + ############ install ############ set(INSTALL_BASE_DIR "") set(INSTALL_LIBRARY_DIR lib) @@ -81,3 +214,11 @@ set(INSTALL_LIBRARY_DIR lib) install(TARGETS atc OPTIONAL LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} ) + +install(TARGETS atc_atc.bin OPTIONAL + RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/atclib +) + +install(TARGETS fwk_atc.bin OPTIONAL + RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/fwkacl +) diff --git a/ge/offline/atc b/ge/offline/atc new file mode 100644 index 00000000..05c65c26 --- /dev/null +++ b/ge/offline/atc @@ -0,0 +1,21 @@ +#!/bin/bash +#------------------------------------------------------------------- +# Purpose: +# Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. +#------------------------------------------------------------------- + +real_path=$(readlink "$0") +if [ $? -eq 0 ]; then + LOCAL_PATH=$(cd "$(dirname "$real_path")"; pwd) +else + LOCAL_PATH=$(cd "$(dirname "$0")"; pwd) +fi +PKG_PATH=$(cd ${LOCAL_PATH}/..; pwd) +LIB_P="/lib64" +PYTHON_P="/python/site-packages" +LIB64_PATH="${PKG_PATH}${LIB_P}" +PYTHON_PATH="${PKG_PATH}${PYTHON_P}" +export LD_LIBRARY_PATH="${LIB64_PATH}:${LD_LIBRARY_PATH}" +export PYTHONPATH="${PYTHON_PATH}:${PYTHONPATH}" + +${PKG_PATH}/bin/atc.bin "$@" diff --git a/ge/offline/main.cc b/ge/offline/main.cc old mode 100755 new mode 100644 index 76494c68..b7188a85 --- a/ge/offline/main.cc +++ b/ge/offline/main.cc @@ -68,7 +68,7 @@ const char *const kModeSupport = "only support 0(model to framework model), " const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow) 5(Onnx)"; // limit available mem size 2G -const long kMinAvailableMem = 2 * 1024 * 1024; +const long kMinAvailableMem = 2097152; // 2 * 1024 * 1024 DEFINE_string(model, "", "The model file."); DEFINE_string(output, "", "The output file path&name."); diff --git a/ge/offline/module.mk b/ge/offline/module.mk old mode 100755 new mode 100644 index 8859df29..5c7a919c --- a/ge/offline/module.mk +++ b/ge/offline/module.mk @@ -54,3 +54,108 @@ LOCAL_LDFLAGS := -lrt -ldl include $(BUILD_HOST_EXECUTABLE) +include $(CLEAR_VARS) + +LOCAL_MODULE := atclib/atc.bin + +LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private + +LOCAL_SRC_FILES := \ + main.cc \ + single_op_parser.cc \ + ../session/omg.cc \ + ../ir_build/atc_ir_common.cc \ + +LOCAL_C_INCLUDES := \ + $(LOCAL_PATH)/../ ./ \ + $(TOPDIR)inc \ + $(TOPDIR)metadef/inc \ + $(TOPDIR)graphengine/inc \ + $(TOPDIR)inc/external \ + $(TOPDIR)metadef/inc/external \ + $(TOPDIR)graphengine/inc/external \ + $(TOPDIR)metadef/inc/external/graph \ + $(TOPDIR)graphengine/inc/framework \ + $(TOPDIR)libc_sec/include \ + $(TOPDIR)metadef/inc/common/util \ + $(TOPDIR)parser \ + third_party/json/include \ + third_party/gflags/include \ + third_party/protobuf/include \ + proto/om.proto \ + proto/ge_ir.proto \ + proto/task.proto \ + proto/insert_op.proto \ + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libge_common \ + libascend_protobuf \ + libslog \ + libgraph \ + libregister \ + liberror_manager \ + libge_compiler \ + libruntime_compile \ + libparser_common \ + liberror_manager \ + +LOCAL_STATIC_LIBRARIES := libgflags + +LOCAL_LDFLAGS := -lrt -ldl + +include $(BUILD_HOST_EXECUTABLE) + +include $(CLEAR_VARS) + +LOCAL_MODULE := fwkacl/atc.bin + +LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private + +LOCAL_SRC_FILES := \ + main.cc \ + single_op_parser.cc \ + ../session/omg.cc \ + ../ir_build/atc_ir_common.cc \ + +LOCAL_C_INCLUDES := \ + $(LOCAL_PATH)/../ ./ \ + $(TOPDIR)inc \ + $(TOPDIR)metadef/inc \ + $(TOPDIR)graphengine/inc \ + $(TOPDIR)inc/external \ + $(TOPDIR)metadef/inc/external \ + $(TOPDIR)graphengine/inc/external \ + $(TOPDIR)metadef/inc/external/graph \ + $(TOPDIR)graphengine/inc/framework \ + $(TOPDIR)libc_sec/include \ + $(TOPDIR)metadef/inc/common/util \ + $(TOPDIR)parser \ + third_party/json/include \ + third_party/gflags/include \ + third_party/protobuf/include \ + proto/om.proto \ + proto/ge_ir.proto \ + proto/task.proto \ + proto/insert_op.proto \ + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libge_common \ + libascend_protobuf \ + libslog \ + libgraph \ + libregister \ + liberror_manager \ + libge_runner \ + libruntime \ + libparser_common \ + liberror_manager \ + +LOCAL_STATIC_LIBRARIES := libgflags + +LOCAL_LDFLAGS := -lrt -ldl + +include $(BUILD_HOST_EXECUTABLE) diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index d4b9c1c9..b1e0da6d 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -27,6 +27,7 @@ #include "common/ge_inner_error_codes.h" #include "framework/common/util.h" #include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/operator_factory_impl.h" @@ -176,6 +177,7 @@ T GetValue(const map &dict, string &key, T default_val) { } void from_json(const Json &j, SingleOpTensorDesc &desc) { + bool is_tensor_valid = true; desc.dims = j.at(kKeyShape).get>(); auto it = j.find(kKeyShapeRange); if (it != j.end()) { @@ -189,9 +191,12 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { string type_str = j.at(kKeyType).get(); desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); + is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(format_str); + is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsDataTypeValid(type_str); it = j.find(kKeyOriginFormat); if (it != j.end()) { string origin_format_str = j.at(kKeyOriginFormat).get(); + is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(origin_format_str); desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED); } auto tensor_name = j.find(kKeyName); @@ -202,6 +207,9 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { if (dynamic_input_name != j.end()) { desc.dynamic_input_name = dynamic_input_name->get(); } + if (!is_tensor_valid) { + desc.SetValidFlag(is_tensor_valid); + } } void from_json(const Json &j, SingleOpAttr &attr) { @@ -305,6 +313,12 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { int index = 0; for (auto &tensor_desc : op_desc.input_desc) { + if (!tensor_desc.GetValidFlag()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, + {"intput", "datatype or format", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index); + return false; + } if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) || (tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){ ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, @@ -317,6 +331,12 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { index = 0; for (auto &tensor_desc : op_desc.output_desc) { + if (!tensor_desc.GetValidFlag()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, + {"output", "datatype", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index); + return false; + } if (tensor_desc.type == DT_UNDEFINED) { ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, {"output", "datatype", std::to_string(index)}); diff --git a/ge/offline/single_op_parser.h b/ge/offline/single_op_parser.h index 19879a32..71aa58bb 100644 --- a/ge/offline/single_op_parser.h +++ b/ge/offline/single_op_parser.h @@ -28,6 +28,10 @@ namespace ge { struct SingleOpTensorDesc { +public: + bool GetValidFlag() const { return is_valid_; } + void SetValidFlag(bool is_valid) { is_valid_ = is_valid; } +public: std::string name; std::vector dims; std::vector ori_dims; @@ -36,6 +40,8 @@ struct SingleOpTensorDesc { ge::Format ori_format = ge::FORMAT_RESERVED; ge::DataType type = ge::DT_UNDEFINED; std::string dynamic_input_name; +private: + bool is_valid_ = true; }; struct SingleOpAttr { diff --git a/ge/omm/csa_interact.cc b/ge/omm/csa_interact.cc index 1599af94..1b33ddbd 100644 --- a/ge/omm/csa_interact.cc +++ b/ge/omm/csa_interact.cc @@ -202,7 +202,7 @@ Status CsaInteract::WriteFile(const std::string &file_name, const std::string &c } } - mmSsize_t ret = mmWrite(fd, (void *)content.c_str(), content.length()); + mmSsize_t ret = mmWrite(fd, reinterpret_cast(const_cast(content.c_str())), content.length()); if (ret == EN_ERROR) { GELOGE(INTERNAL_ERROR, "write file fail, errno is %d", errno); ret = mmClose(fd); diff --git a/ge/opskernel_manager/ops_kernel_builder_manager.cc b/ge/opskernel_manager/ops_kernel_builder_manager.cc index e0001fcd..37bdcf7a 100644 --- a/ge/opskernel_manager/ops_kernel_builder_manager.cc +++ b/ge/opskernel_manager/ops_kernel_builder_manager.cc @@ -167,4 +167,5 @@ Status OpsKernelBuilderManager::GenerateTask(const Node &node, GELOGD("Done invoking GenerateTask successfully"); return SUCCESS; } -} // namespace ge \ No newline at end of file + +} // namespace ge diff --git a/ge/opskernel_manager/ops_kernel_manager.cc b/ge/opskernel_manager/ops_kernel_manager.cc index 8134a463..30f39c0d 100644 --- a/ge/opskernel_manager/ops_kernel_manager.cc +++ b/ge/opskernel_manager/ops_kernel_manager.cc @@ -175,8 +175,8 @@ Status OpsKernelManager::ParsePluginOptions(const map &options, } else if (flag == 1) { enable_flag = true; } else { - GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.", plugin_name.c_str(), - iter->second.c_str()); + GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.", + plugin_name.c_str(), iter->second.c_str()); return GE_GRAPH_OPTIONS_INVALID; } } catch (std::invalid_argument &) { @@ -188,8 +188,8 @@ Status OpsKernelManager::ParsePluginOptions(const map &options, iter->second.c_str()); return GE_GRAPH_OPTIONS_INVALID; } catch (...) { - GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.", plugin_name.c_str(), - iter->second.c_str()); + GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.", + plugin_name.c_str(), iter->second.c_str()); return GE_GRAPH_OPTIONS_INVALID; } } else { diff --git a/ge/plugin/engine/dnnengines.cc b/ge/plugin/engine/dnnengines.cc old mode 100755 new mode 100644 diff --git a/ge/plugin/engine/module.mk b/ge/plugin/engine/module.mk old mode 100755 new mode 100644 diff --git a/ge/proto/fusion_model.proto b/ge/proto/fusion_model.proto old mode 100755 new mode 100644 diff --git a/ge/proto/ge_api.proto b/ge/proto/ge_api.proto old mode 100755 new mode 100644 diff --git a/ge/session/inner_session.cc b/ge/session/inner_session.cc old mode 100755 new mode 100644 diff --git a/ge/session/omg.cc b/ge/session/omg.cc old mode 100755 new mode 100644 index df837f99..7ff52e82 --- a/ge/session/omg.cc +++ b/ge/session/omg.cc @@ -68,6 +68,9 @@ const std::string kScopeIdAttr = "fusion_scope"; const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; +const size_t kNodeNameIndex = 0; +const size_t kIndexStrIndex = 1; +const size_t kDTValueIndex = 2; } // namespace // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored @@ -381,14 +384,14 @@ Status ParseOutputType(const std::string &output_type, std::map(model.model_data); model.model_data = nullptr; } return status; @@ -902,7 +906,7 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOmModelToJson(const char *model_file, con if (status != ge::GRAPH_SUCCESS) { GELOGE(ge::FAILED, "Get model part failed."); if (model.model_data != nullptr) { - delete[](char *) model.model_data; + delete[] reinterpret_cast(model.model_data); model.model_data = nullptr; } return status; @@ -928,7 +932,7 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertOmModelToJson(const char *model_file, con } if (model.model_data != nullptr) { - delete[](char *) model.model_data; + delete[] reinterpret_cast(model.model_data); model.model_data = nullptr; } return ret; diff --git a/ge/session/session_manager.cc b/ge/session/session_manager.cc old mode 100755 new mode 100644 diff --git a/ge/single_op/single_op.cc b/ge/single_op/single_op.cc old mode 100755 new mode 100644 index 371d7110..a2652b67 --- a/ge/single_op/single_op.cc +++ b/ge/single_op/single_op.cc @@ -17,6 +17,7 @@ #include "single_op/single_op.h" #include "common/fmk_types.h" +#include "common/ge_types.h" #include "common/math/math_util.h" #include "common/profiling/profiling_manager.h" #include "framework/common/debug/ge_log.h" @@ -24,19 +25,60 @@ #include "graph/load/new_model_manager/model_utils.h" #include "runtime/mem.h" #include "single_op/single_op_manager.h" +#include "single_op/task/build_task_utils.h" #include "graph/load/new_model_manager/model_manager.h" namespace ge { namespace { const size_t kDataMemAlignSize = 32; +const size_t kDataMemAlignUnit = 2; size_t GetAlignedSize(size_t size) { - size_t aligned_size = (size + 2 * kDataMemAlignSize - 1) / kDataMemAlignSize * kDataMemAlignSize; + size_t aligned_size = (size + kDataMemAlignUnit * kDataMemAlignSize - 1) / kDataMemAlignSize * kDataMemAlignSize; return aligned_size; } + +Status ProfilingTaskInfo(OpTask *op_task) { + if (!ProfilingManager::Instance().ProfilingModelExecuteOn()) { + return SUCCESS; + } + + string model_name; + string op_name; + uint32_t model_id; + uint32_t block_dim; + if (op_task->GetProfilingArgs(model_name, op_name, model_id, block_dim) != SUCCESS) { + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Get profiling data of task failed"); + return ACL_ERROR_GE_PARAM_INVALID; + } + GELOGD("ProfilingReport of op[%s] model[%s] start.", op_name.c_str(), model_name.c_str()); + std::vector task_desc_info; + uint32_t task_id = 0; + uint32_t stream_id = 0; + if (rtGetTaskIdAndStreamID(&task_id, &stream_id) != RT_ERROR_NONE) { + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Get task_id and stream_id failed."); + return ACL_ERROR_GE_PARAM_INVALID; + } + + TaskDescInfo tmp_task_desc_info; + tmp_task_desc_info.model_name = model_name; + tmp_task_desc_info.op_name = op_name; + tmp_task_desc_info.block_dim = block_dim; + tmp_task_desc_info.task_id = task_id; + tmp_task_desc_info.stream_id = stream_id; + GELOGD("GetTaskDescInfo of op [%s] end, task_id[%u], stream_id[%u]", op_name.c_str(), task_id, stream_id); + task_desc_info.emplace_back(tmp_task_desc_info); + + std::vector compute_graph_info; + + auto &profiling_manager = ProfilingManager::Instance(); + profiling_manager.ReportProfilingData(model_id, task_desc_info, compute_graph_info); + return SUCCESS; +} } // namespace -SingleOp::SingleOp(std::mutex *stream_mutex, rtStream_t stream) : stream_mutex_(stream_mutex), stream_(stream) { +SingleOp::SingleOp(StreamResource *stream_resource, std::mutex *stream_mutex, rtStream_t stream) + : stream_resource_(stream_resource), stream_mutex_(stream_mutex), stream_(stream) { } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY SingleOp::~SingleOp() { @@ -68,7 +110,8 @@ Status SingleOp::ValidateArgs(const std::vector &inputs, const std:: auto num_outputs = outputs.size(); if (num_outputs != output_sizes_.size()) { - GELOGE(ACL_ERROR_GE_PARAM_INVALID, "output num mismatch. model expect %zu, but given %zu", output_sizes_.size(), outputs.size()); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "output num mismatch. model expect %zu, but given %zu", + output_sizes_.size(), outputs.size()); return ACL_ERROR_GE_PARAM_INVALID; } @@ -117,37 +160,6 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve *arg_addr = args_[i]; } } - // update aicpu_TF or aicpu_CC args - for (auto &task : tasks_) { - size_t io_addr_num = args_.size(); - if (task->GetOpTaskType() == OP_TASK_AICPU) { - GELOGD("Update aicpu_TF task args"); - task->SetIoAddrsForDump(args_); - auto *dst_io_addr = const_cast(reinterpret_cast(task->GetIOAddr())); - GE_CHECK_NOTNULL(dst_io_addr); - auto rt_ret = rtMemcpyAsync(dst_io_addr, - sizeof(uint64_t) * args_.size(), - &args_[0], - sizeof(uint64_t) * args_.size(), - RT_MEMCPY_HOST_TO_DEVICE_EX, - stream_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "rtMemcpyAsync addresses failed, ret = %d", rt_ret); - return rt_ret; - } - } else if (task->GetOpTaskType() == OP_TASK_AICPUCC) { - GELOGD("Update aicpu_CC task args"); - const uintptr_t *task_io_addr = reinterpret_cast(task->GetIOAddr()); - GE_CHECK_NOTNULL(task_io_addr); - auto io_addr = reinterpret_cast(const_cast(task_io_addr)); - for (size_t i = 0; i < io_addr_num; ++i) { - io_addr[i] = static_cast(args_[i]); - } - } else { - GELOGW("Only TF_kernel aicpu and aicpu_CC are supported, but got %u", task->GetOpTaskType()); - continue; - } - } return SUCCESS; } @@ -158,7 +170,19 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c return ret; } + GE_CHECK_NOTNULL(stream_resource_); std::lock_guard lk(*stream_mutex_); + auto current_mem_base = stream_resource_->GetMemoryBase(); + if (running_param_->mem_base != current_mem_base) { + running_param_->mem_base = const_cast(current_mem_base); + GELOGD("Memory base changed, new memory base = %p", current_mem_base); + for (auto &task : tasks_) { + auto new_address = BuildTaskUtils::GetAddresses(task->GetOpdesc(), *running_param_); + GE_CHK_STATUS_RET(task->UpdateArgTable(*running_param_), + "[%s] Failed to update arg table", + task->GetOpdesc()->GetName().c_str()); + } + } ret = UpdateArgs(inputs, outputs); if (ret != SUCCESS) { return ret; @@ -169,6 +193,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c if (ret != SUCCESS) { return ret; } + GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(task)); } return ret; @@ -182,9 +207,6 @@ DynamicSingleOp::DynamicSingleOp(uintptr_t resource_id, std::mutex *stream_mutex : resource_id_(resource_id), stream_mutex_(stream_mutex), stream_(stream) { } -DynamicSingleOp::~DynamicSingleOp() { -} - Status DynamicSingleOp::ValidateParams(const vector &input_desc, const std::vector &inputs, std::vector &output_desc, @@ -206,63 +228,24 @@ Status DynamicSingleOp::ValidateParams(const vector &input_desc, } if (input_desc.size() != num_inputs_) { - GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Input number mismatches. expect %zu, but given %zu", num_inputs_, input_desc.size()); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, + "Input number mismatches. expect %zu, but given %zu", + num_inputs_, + input_desc.size()); return ACL_ERROR_GE_PARAM_INVALID; } if (output_desc.size() != num_outputs_) { - GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Output number mismatches. expect %zu, but given %zu", num_outputs_, output_desc.size()); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, + "Output number mismatches. expect %zu, but given %zu", + num_outputs_, + output_desc.size()); return ACL_ERROR_GE_PARAM_INVALID; } return SUCCESS; } -Status DynamicSingleOp::AllocateWorkspaces(const std::vector &workspace_sizes, - std::vector &workspaces) { - static const std::string kPurpose("malloc workspace memory for dynamic op."); - if (workspace_sizes.empty()) { - GELOGD("No need to allocate workspace."); - return SUCCESS; - } - int64_t total_size = 0; - std::vector ws_offsets; - for (auto ws_size : workspace_sizes) { - // alignment and padding should be done in OpParaCalculate - GE_CHK_STATUS_RET_NOLOG(CheckInt64AddOverflow(total_size, ws_size)); - ws_offsets.emplace_back(total_size); - total_size += ws_size; - } - - GELOGD("Total workspace size is %ld", total_size); - StreamResource *stream_resource = SingleOpManager::GetInstance().GetResource(resource_id_, stream_); - GE_CHECK_NOTNULL(stream_resource); - auto ws_base = stream_resource->MallocMemory(kPurpose, static_cast(total_size)); - if (ws_base == nullptr) { - GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to allocate memory of size: %ld", total_size); - return ACL_ERROR_GE_MEMORY_ALLOCATION; - } - GELOGD("Done allocating workspace memory successfully."); - - for (auto ws_offset : ws_offsets) { - workspaces.emplace_back(ws_base + ws_offset); - } - - return SUCCESS; -} - -Status DynamicSingleOp::ExecuteTbeTask(const vector &input_desc, - const vector &inputs, - vector &output_desc, - vector &outputs) { - GE_CHK_STATUS_RET_NOLOG(op_task_->UpdateRunInfo(input_desc, output_desc)); - - std::vector workspace_buffers; - GE_CHK_STATUS_RET_NOLOG(AllocateWorkspaces(op_task_->GetWorkspaceSizes(), workspace_buffers)); - - return op_task_->LaunchKernel(inputs, outputs, workspace_buffers, stream_); -} - Status DynamicSingleOp::ExecuteAsync(const vector &input_desc, const vector &input_buffers, vector &output_desc, @@ -271,24 +254,8 @@ Status DynamicSingleOp::ExecuteAsync(const vector &input_desc, GE_CHK_STATUS_RET_NOLOG(ValidateParams(input_desc, input_buffers, output_desc, output_buffers)); std::lock_guard lk(*stream_mutex_); - std::vector inputs; - std::vector outputs; - for (auto &buffer : input_buffers) { - inputs.emplace_back(buffer.data); - } - for (auto &buffer : output_buffers) { - outputs.emplace_back(buffer.data); - } - - if (op_task_->GetOpTaskType() == OP_TASK_TBE) { - return ExecuteTbeTask(input_desc, inputs, output_desc, outputs); - } else if (op_task_->GetOpTaskType() == OP_TASK_AICPU || op_task_->GetOpTaskType() == OP_TASK_AICPUCC) { - return op_task_->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_); - } else { - GELOGE(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, - "Only TBE_Task, AI_CPU_Task and AI_CPUCC_Task are supported, but got %u", - op_task_->GetOpTaskType()); - return ACL_ERROR_GE_OP_TASK_TYPE_INVALID; - } + GE_CHK_STATUS_RET_NOLOG(op_task_->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_)); + GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(op_task_.get())); + return SUCCESS; } } // namespace ge diff --git a/ge/single_op/single_op.h b/ge/single_op/single_op.h old mode 100755 new mode 100644 index 14ef8ce1..d677f94a --- a/ge/single_op/single_op.h +++ b/ge/single_op/single_op.h @@ -30,9 +30,11 @@ #include "cce/aicpu_engine_struct.h" namespace ge { +class StreamResource; +struct SingleOpModelParam; class SingleOp { public: - SingleOp(std::mutex *stream_mutex, rtStream_t stream); + SingleOp(StreamResource *stream_resource, std::mutex *stream_mutex, rtStream_t stream); ~SingleOp(); Status ExecuteAsync(const std::vector &inputs, const std::vector &outputs); @@ -44,6 +46,7 @@ class SingleOp { Status GetArgs(const std::vector &inputs, const std::vector &outputs); friend class SingleOpModel; + StreamResource *stream_resource_; std::mutex *stream_mutex_; rtStream_t stream_ = nullptr; std::vector input_addr_list_; @@ -54,12 +57,13 @@ class SingleOp { std::vector tasks_; std::vector> arg_table_; + std::unique_ptr running_param_; }; class DynamicSingleOp { public: DynamicSingleOp(uintptr_t resource_id, std::mutex *stream_mutex_, rtStream_t stream); - ~DynamicSingleOp(); + ~DynamicSingleOp() = default; Status ExecuteAsync(const vector &input_desc, const std::vector &inputs, std::vector &output_desc, @@ -72,14 +76,6 @@ class DynamicSingleOp { std::vector &output_desc, std::vector &outputs) const; - Status AllocateWorkspaces(const std::vector &workspace_sizes, - std::vector &workspaces); - - Status ExecuteTbeTask(const vector &input_desc, - const vector &inputs, - vector &output_desc, - vector &outputs); - std::unique_ptr op_task_; uintptr_t resource_id_ = 0; std::mutex *stream_mutex_; diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc old mode 100755 new mode 100644 index 49968f4f..25bf6855 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -92,7 +92,8 @@ Status SingleOpModel::InitModelMem(StreamResource &res) { if (model_params_.memory_size > model_params_.zero_copy_mem_size) { const string purpose("malloc feature map memory on model execute."); GELOGI("total memory: %lu, zero_copy_mem: %lu", model_params_.memory_size, model_params_.zero_copy_mem_size); - model_params_.mem_base = res.MallocMemory(purpose, model_params_.memory_size - model_params_.zero_copy_mem_size); + model_params_.mem_base = + res.MallocMemory(purpose, model_params_.memory_size - model_params_.zero_copy_mem_size, false); if (model_params_.mem_base == nullptr) { return ACL_ERROR_GE_MEMORY_ALLOCATION; } @@ -157,6 +158,7 @@ Status SingleOpModel::LoadAllNodes() { auto ge_model = model_helper_.GetGeModel(); GE_CHECK_NOTNULL(ge_model); Graph graph = ge_model->GetGraph(); + model_id_ = ge_model->GetModelId(); auto compute_graph = GraphUtils::GetComputeGraph(graph); if (compute_graph == nullptr) { GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "[%s] compute_graph is null", model_name_.c_str()); @@ -225,9 +227,10 @@ Status SingleOpModel::SetInputsAndOutputs(SingleOp &single_op) { return SUCCESS; } -Status SingleOpModel::BuildTaskList(SingleOp &single_op) { +Status SingleOpModel::BuildTaskList(StreamResource *stream_resource, SingleOp &single_op) { auto ge_model = model_helper_.GetGeModel(); GE_CHECK_NOTNULL(ge_model); + single_op.arg_table_.resize(single_op.input_sizes_.size() + single_op.output_sizes_.size()); auto tasks = ge_model->GetModelTaskDefPtr()->task(); for (int i = 0; i < tasks.size(); ++i) { const TaskDef &task_def = tasks[i]; @@ -237,8 +240,8 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { if (task_type == RT_MODEL_TASK_KERNEL) { const domi::KernelDef &kernel_def = task_def.kernel(); const auto &context = kernel_def.context(); - auto kernel_type = static_cast(context.kernel_type()); - if (kernel_type == cce::ccKernelType::TE) { + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type == ccKernelType::TE) { GELOGD("Building TBE task"); TbeOpTask *tbe_task = nullptr; auto ret = BuildKernelTask(task_def.kernel(), &tbe_task); @@ -246,10 +249,13 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { return ret; } - single_op.arg_table_.resize(single_op.input_sizes_.size() + single_op.output_sizes_.size()); ParseArgTable(tbe_task, single_op); + tbe_task->SetModelArgs(model_name_, model_id_); + if (tbe_task->tiling_buffer_ != nullptr) { + tbe_task->stream_resource_ = stream_resource; + } single_op.tasks_.emplace_back(tbe_task); - } else if (kernel_type == cce::ccKernelType::AI_CPU || kernel_type == cce::ccKernelType::CUST_AI_CPU) { + } else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) { GELOGD("Building AICPU_CC task"); OpTask *task = nullptr; uint64_t singleop_kernel_id = aicpu_kernel_id++; @@ -258,9 +264,12 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { if (ret != SUCCESS) { return ret; } + task->SetModelArgs(model_name_, model_id_); + ParseArgTable(task, single_op); single_op.tasks_.emplace_back(task); } else { - GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, "Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", context.kernel_type()); + GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, + "Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", context.kernel_type()); return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID; } } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { @@ -273,6 +282,8 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { if (ret != SUCCESS) { return ret; } + aicpu_task->SetModelArgs(model_name_, model_id_); + ParseArgTable(aicpu_task, single_op); single_op.tasks_.emplace_back(aicpu_task); } else { // skip @@ -282,21 +293,23 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { return SUCCESS; } -void SingleOpModel::ParseArgTable(TbeOpTask *task, SingleOp &op) { +void SingleOpModel::ParseArgTable(OpTask *task, SingleOp &op) { if (task == nullptr) { GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "tbe op task is nullptr"); return; } + // args: addr1, addr2, addr3 ... - auto *args = const_cast(reinterpret_cast(task->GetArgs())); - size_t arg_size = task->GetArgSize(); - for (size_t i = 0; i < arg_size / sizeof(void *); ++i) { - uintptr_t *ptr_to_addr = args + i; + uintptr_t *arg_base = nullptr; + size_t arg_num = 0; + task->GetIoAddr(arg_base, arg_num); + for (size_t i = 0; i < arg_num; ++i) { + uintptr_t *ptr_to_addr = arg_base + i; uintptr_t addr = *ptr_to_addr; auto iter = model_params_.addr_mapping_.find(addr); if (iter != model_params_.addr_mapping_.end()) { int arg_index = iter->second; - GELOGI("%s args[%zu] mapped to user designated args[%d]", task->GetStubName().c_str(), i, arg_index); + GELOGI("%s args[%zu] mapped to user designated args[%d]", task->GetOpdesc()->GetName().c_str(), i, arg_index); op.arg_table_[iter->second].emplace_back(ptr_to_addr); } } @@ -368,7 +381,7 @@ Status SingleOpModel::BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTa } auto builder = AiCpuCCTaskBuilder(iter->second->GetOpDesc(), kernel_def); - auto ret = builder.BuildTask(*aicpucc_task, kernel_id); + auto ret = builder.BuildTask(*aicpucc_task, kernel_id, model_params_); if (ret != SUCCESS) { GELOGE(ret, "build aicpu_CC op task failed"); return ret; @@ -381,25 +394,29 @@ Status SingleOpModel::BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTa Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { GE_CHK_STATUS_RET_NOLOG(ParseInputsAndOutputs()); GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource)); + single_op.running_param_.reset(new (std::nothrow)SingleOpModelParam(model_params_)); + GE_CHECK_NOTNULL(single_op.running_param_); GE_CHK_STATUS_RET_NOLOG(SetInputsAndOutputs(single_op)); - return BuildTaskList(single_op); + return BuildTaskList(&resource, single_op); } Status SingleOpModel::BuildModelTaskKernel(const TaskDef &task_def, DynamicSingleOp &single_op) { const domi::KernelDef &kernel_def = task_def.kernel(); const auto &context = kernel_def.context(); - auto kernel_type = static_cast(context.kernel_type()); - if (kernel_type == cce::ccKernelType::TE) { + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type == ccKernelType::TE) { GELOGD("Building TBE task"); TbeOpTask *tbe_task = nullptr; GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def.kernel(), &tbe_task)); + tbe_task->SetModelArgs(model_name_, model_id_); single_op.op_task_.reset(tbe_task); - } else if (kernel_type == cce::ccKernelType::AI_CPU || kernel_type == cce::ccKernelType::CUST_AI_CPU) { + } else if (kernel_type == ccKernelType::AI_CPU || kernel_type == ccKernelType::CUST_AI_CPU) { GELOGD("Building AICPU_CC task"); OpTask *task = nullptr; uint64_t dynamic_singleop_kernel_id = aicpu_kernel_id++; GELOGI("Build dynamic singleOp CCTask, kernel_id = %lu", dynamic_singleop_kernel_id); GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task, dynamic_singleop_kernel_id)); + task->SetModelArgs(model_name_, model_id_); single_op.op_task_.reset(task); } else { GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, @@ -446,6 +463,7 @@ Status SingleOpModel::BuildTaskListForDynamicOp(DynamicSingleOp &single_op) { const TaskDef ©_task_def = tasks[i]; GE_CHK_STATUS_RET_NOLOG(aicpu_task->SetMemCopyTask(copy_task_def.kernel_ex())); } + aicpu_task->SetModelArgs(model_name_, model_id_); single_op.op_task_.reset(aicpu_task); } else { // skip @@ -455,10 +473,10 @@ Status SingleOpModel::BuildTaskListForDynamicOp(DynamicSingleOp &single_op) { return SUCCESS; } -Status SingleOpModel::BuildDynamicOp(DynamicSingleOp &single_op) { +Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp &single_op) { single_op.num_inputs_ = data_ops_.size(); single_op.num_outputs_ = netoutput_op_->GetAllInputsSize(); - ParseOpModelParams(model_helper_, model_params_); + GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource)); return BuildTaskListForDynamicOp(single_op); } } // namespace ge diff --git a/ge/single_op/single_op_model.h b/ge/single_op/single_op_model.h old mode 100755 new mode 100644 index 50aeb7ab..6d0109fe --- a/ge/single_op/single_op_model.h +++ b/ge/single_op/single_op_model.h @@ -52,7 +52,7 @@ class SingleOpModel { Status Init(); Status BuildOp(StreamResource &resource, SingleOp &single_op); - Status BuildDynamicOp(DynamicSingleOp &single_op); + Status BuildDynamicOp(StreamResource &resource, DynamicSingleOp &single_op); private: Status InitModel(); @@ -65,7 +65,7 @@ class SingleOpModel { Status ParseInputNode(const OpDescPtr &op_desc); void ParseOutputNode(const OpDescPtr &op_desc); - Status BuildTaskList(SingleOp &single_op); + Status BuildTaskList(StreamResource *stream_resource, SingleOp &single_op); Status BuildTaskListForDynamicOp(DynamicSingleOp &dynamic_single_op); Status BuildKernelTask(const domi::KernelDef &kernel_def, TbeOpTask **task); Status BuildKernelExTask(const domi::KernelExDef &kernel_def, AiCpuTask **task, @@ -74,9 +74,10 @@ class SingleOpModel { Status BuildModelTaskKernel(const domi::TaskDef &task_def, DynamicSingleOp &single_op); static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam ¶m); - void ParseArgTable(TbeOpTask *task, SingleOp &op); + void ParseArgTable(OpTask *task, SingleOp &op); std::string model_name_; + uint32_t model_id_ = 0; const void *ori_model_data_; uint32_t ori_model_size_; diff --git a/ge/single_op/stream_resource.cc b/ge/single_op/stream_resource.cc old mode 100755 new mode 100644 index f545b6c8..db6b7c47 --- a/ge/single_op/stream_resource.cc +++ b/ge/single_op/stream_resource.cc @@ -69,11 +69,25 @@ uint8_t *StreamResource::DoMallocMemory(const std::string &purpose, size_t size, size_t &max_allocated, std::vector &allocated) { + if (size == 0) { + GELOGD("Mem size == 0"); + return nullptr; + } + if (size <= max_allocated && !allocated.empty()) { GELOGD("reuse last memory"); return allocated.back(); } + if (!allocated.empty()) { + uint8_t *current_buffer = allocated.back(); + allocated.pop_back(); + if (rtStreamSynchronize(stream_) != RT_ERROR_NONE) { + GELOGW("Failed to invoke rtStreamSynchronize"); + } + (void) rtFree(current_buffer); + } + uint8_t *buffer = nullptr; auto ret = rtMalloc(reinterpret_cast(&buffer), size, RT_MEMORY_HBM); if (ret != RT_ERROR_NONE) { @@ -96,10 +110,14 @@ uint8_t *StreamResource::DoMallocMemory(const std::string &purpose, return buffer; } -uint8_t *StreamResource::MallocMemory(const std::string &purpose, size_t size) { +uint8_t *StreamResource::MallocMemory(const std::string &purpose, size_t size, bool holding_lock) { GELOGD("To Malloc memory, size = %zu", size); - uint8_t *buffer = DoMallocMemory(purpose, size, max_memory_size_, memory_list_); - return buffer; + if (holding_lock) { + return DoMallocMemory(purpose, size, max_memory_size_, memory_list_); + } else { + std::lock_guard lk(stream_mu_); + return DoMallocMemory(purpose, size, max_memory_size_, memory_list_); + } } uint8_t *StreamResource::MallocWeight(const std::string &purpose, size_t size) { @@ -137,7 +155,8 @@ Status StreamResource::BuildDynamicOperator(const string &model_name, GE_CHECK_NOTNULL(new_op); GELOGI("To build operator: %s", model_name.c_str()); - GE_CHK_STATUS_RET(model.BuildDynamicOp(*new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); + GE_CHK_STATUS_RET(model.BuildDynamicOp(*this, *new_op), + "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); *single_op = new_op.get(); dynamic_op_map_[model_data.model_data] = std::move(new_op); return SUCCESS; @@ -158,7 +177,7 @@ Status StreamResource::BuildOperator(const string &model_name, const ModelData & return ret; } - auto new_op = std::unique_ptr(new(std::nothrow) SingleOp(&stream_mu_, stream_)); + auto new_op = std::unique_ptr(new(std::nothrow) SingleOp(this, &stream_mu_, stream_)); if (new_op == nullptr) { GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "new SingleOp failed"); return ACL_ERROR_GE_MEMORY_ALLOCATION; @@ -171,4 +190,12 @@ Status StreamResource::BuildOperator(const string &model_name, const ModelData & op_map_[model_data.model_data] = std::move(new_op); return SUCCESS; } + +const uint8_t *StreamResource::GetMemoryBase() const { + if (memory_list_.empty()) { + return nullptr; + } + + return memory_list_.back(); +} } // namespace ge diff --git a/ge/single_op/stream_resource.h b/ge/single_op/stream_resource.h old mode 100755 new mode 100644 index 39f08ebe..d5bc941a --- a/ge/single_op/stream_resource.h +++ b/ge/single_op/stream_resource.h @@ -45,8 +45,9 @@ class StreamResource { Status BuildOperator(const std::string &model_name, const ModelData &model_data, SingleOp **single_op); Status BuildDynamicOperator(const std::string &model_name, const ModelData &model_data, DynamicSingleOp **single_op); - uint8_t *MallocMemory(const std::string &purpose, size_t size); + uint8_t *MallocMemory(const std::string &purpose, size_t size, bool holding_lock = true); uint8_t *MallocWeight(const std::string &purpose, size_t size); + const uint8_t *GetMemoryBase() const; private: uint8_t *DoMallocMemory(const std::string &purpose, diff --git a/ge/single_op/task/aicpu_kernel_task_builder.cc b/ge/single_op/task/aicpu_kernel_task_builder.cc old mode 100755 new mode 100644 index 26f6a166..34f1ba7b --- a/ge/single_op/task/aicpu_kernel_task_builder.cc +++ b/ge/single_op/task/aicpu_kernel_task_builder.cc @@ -15,19 +15,24 @@ */ #include "single_op/task/aicpu_kernel_task_builder.h" -#include "cce/taskdown_common.hpp" +#include "framework/common/taskdown_common.h" #include "graph/load/new_model_manager/model_manager.h" +#include "build_task_utils.h" namespace ge { AiCpuCCTaskBuilder::AiCpuCCTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def) : op_desc_(op_desc), kernel_def_(kernel_def) {} -Status AiCpuCCTaskBuilder::SetKernelArgs(AiCpuCCTask &task) { +Status AiCpuCCTaskBuilder::SetKernelArgs(AiCpuCCTask &task, const SingleOpModelParam ¶m) { size_t aicpu_arg_size = kernel_def_.args_size(); - if (aicpu_arg_size <= 0) { + if (aicpu_arg_size <= sizeof(aicpu::AicpuParamHead)) { GELOGE(ACL_ERROR_GE_PARAM_INVALID, "aicpu_arg_size is invalid, value = %zu", aicpu_arg_size); return ACL_ERROR_GE_PARAM_INVALID; } + + task.io_addr_num_ = op_desc_->GetInputsSize() + op_desc_->GetOutputsSize(); + GE_CHECK_GE(aicpu_arg_size - sizeof(aicpu::AicpuParamHead), task.io_addr_num_ * sizeof(void *)); + std::unique_ptr aicpu_args; aicpu_args.reset(new(std::nothrow) uint8_t[aicpu_arg_size]()); if (aicpu_args == nullptr) { @@ -41,13 +46,19 @@ Status AiCpuCCTaskBuilder::SetKernelArgs(AiCpuCCTask &task) { return ACL_ERROR_GE_INTERNAL_ERROR; } - task.SetIoAddr(aicpu_args.get() + sizeof(aicpu::AicpuParamHead)); + task.SetIoAddr(reinterpret_cast(aicpu_args.get() + sizeof(aicpu::AicpuParamHead))); task.SetKernelArgs(std::move(aicpu_args), aicpu_arg_size); + + auto addresses = BuildTaskUtils::GetKernelArgs(op_desc_, param); + GE_CHECK_GE(addresses.size(), task.io_addr_num_); + for (size_t i = 0; i < task.io_addr_num_; ++i) { + task.io_addr_[i] = reinterpret_cast(addresses[i]); + } return SUCCESS; } -Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task, uint64_t kernel_id) { - auto ret = SetKernelArgs(task); +Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task, uint64_t kernel_id, const SingleOpModelParam ¶m) { + auto ret = SetKernelArgs(task, param); if (ret != SUCCESS) { return ret; } @@ -55,15 +66,20 @@ Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task, uint64_t kernel_id) { const std::string &kernel_name = kernel_def_.kernel_name(); task.SetSoName(so_name); task.SetkernelName(kernel_name); + GE_CHECK_NOTNULL(op_desc_); task.op_desc_ = op_desc_; const auto &context = kernel_def_.context(); - auto kernel_type = static_cast(context.kernel_type()); - if (kernel_type == cce::ccKernelType::CUST_AI_CPU) { + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type == ccKernelType::CUST_AI_CPU) { task.is_custom_ = true; task.dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; - GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc_, so_name), "launch cust aicpu so failed"); - GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "launch cust aicpu so failed."); + bool loaded = false; + GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc_, so_name, loaded), + "launch cust aicpu so failed"); + if (!loaded) { + GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "launch cust aicpu so failed."); + } } task.num_inputs_ = op_desc_->GetInputsSize(); @@ -81,7 +97,12 @@ Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task, uint64_t kernel_id) { GELOGE(ret, "Init ext info failed."); return ret; } + GE_CHK_STATUS_RET(task.SetInputConst(), "AiCpuCCTask set input_const failed."); + if (task.GetUnknownType() == DEPEND_COMPUTE) { + GELOGE(FAILED, "AiCpuCCTask unknown type is depend compute, it's not supported now."); + return FAILED; + } auto aicpu_param_head = reinterpret_cast(task.args_.get()); if (task.ext_info_addr_dev_ != nullptr) { aicpu_param_head->extInfoLength = kernel_ext_info.size(); diff --git a/ge/single_op/task/aicpu_kernel_task_builder.h b/ge/single_op/task/aicpu_kernel_task_builder.h old mode 100755 new mode 100644 index e77e3c10..85d5034d --- a/ge/single_op/task/aicpu_kernel_task_builder.h +++ b/ge/single_op/task/aicpu_kernel_task_builder.h @@ -30,10 +30,10 @@ class AiCpuCCTaskBuilder { explicit AiCpuCCTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def); ~AiCpuCCTaskBuilder() = default; - Status BuildTask(AiCpuCCTask &task, uint64_t kernel_id); + Status BuildTask(AiCpuCCTask &task, uint64_t kernel_id, const SingleOpModelParam ¶m); private: - Status SetKernelArgs(AiCpuCCTask &task); + Status SetKernelArgs(AiCpuCCTask &task, const SingleOpModelParam ¶m); const OpDescPtr op_desc_; const domi::KernelDef &kernel_def_; }; diff --git a/ge/single_op/task/aicpu_task_builder.cc b/ge/single_op/task/aicpu_task_builder.cc old mode 100755 new mode 100644 index 8f28ffda..5fd4879e --- a/ge/single_op/task/aicpu_task_builder.cc +++ b/ge/single_op/task/aicpu_task_builder.cc @@ -26,26 +26,6 @@ namespace ge { AiCpuTaskBuilder::AiCpuTaskBuilder(const OpDescPtr &op_desc, const domi::KernelExDef &kernel_def) : op_desc_(op_desc), kernel_def_(kernel_def) {} - Status AiCpuTaskBuilder::SetInputOutputAddr(void **io_addr, const std::vector &addresses) { - size_t arg_size = kernel_def_.args_size(); - auto rt_ret = rtMalloc(io_addr, arg_size, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "rtMalloc failed, size = %zu, ret = %d", arg_size, rt_ret); - return rt_ret; - } - - const void *src_addr = reinterpret_cast(addresses.data()); - uint64_t src_len = sizeof(void *) * addresses.size(); - rt_ret = rtMemcpy(*io_addr, arg_size, src_addr, src_len, RT_MEMCPY_HOST_TO_DEVICE); - if (rt_ret != RT_ERROR_NONE) { - (void)rtFree(*io_addr); - GELOGE(rt_ret, "rtMemcpy addresses failed, ret = %d", rt_ret); - return rt_ret; - } - - return SUCCESS; - } - Status AiCpuTaskBuilder::SetFmkOpKernel(void *io_addr, void *ws_addr, STR_FWK_OP_KERNEL &fwk_op_kernel) { auto sec_ret = memcpy_s(&fwk_op_kernel, sizeof(STR_FWK_OP_KERNEL), kernel_def_.args().data(), kernel_def_.args().size()); @@ -80,39 +60,27 @@ namespace ge { return SUCCESS; } - Status AiCpuTaskBuilder::InitWorkspaceAndIO(void **io_addr, void **kernel_workspace, - const SingleOpModelParam ¶m, bool dynamic_flag) { + Status AiCpuTaskBuilder::InitWorkspaceAndIO(AiCpuTask &task, const SingleOpModelParam ¶m, bool dynamic_flag) { if (kernel_def_.args_size() > sizeof(STR_FWK_OP_KERNEL)) { GELOGE(ACL_ERROR_GE_PARAM_INVALID, "sizeof STR_FWK_OP_KERNEL is: %lu, but args_size is: %d", sizeof(STR_FWK_OP_KERNEL), kernel_def_.args_size()); return ACL_ERROR_GE_PARAM_INVALID; } - auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param); - auto ws_addr_vec = addresses.at(BuildTaskUtils::kAddressIndexWorkspace); - - if (dynamic_flag) { - GE_CHK_RT_RET(rtMalloc(kernel_workspace, kernel_def_.task_info_size(), RT_MEMORY_HBM)); - } else { - if (ws_addr_vec.empty()) { - GELOGE(ACL_ERROR_GE_PARAM_INVALID, "workspace Data Address is empty."); - return ACL_ERROR_GE_PARAM_INVALID; - } - *kernel_workspace = ws_addr_vec[0]; - } - GE_CHK_RT_RET(rtMemcpy(*kernel_workspace, kernel_def_.task_info_size(), + GE_CHK_RT_RET(rtMalloc(&task.workspace_addr_, kernel_def_.task_info_size(), RT_MEMORY_HBM)); + GE_CHK_RT_RET(rtMemcpy(task.workspace_addr_, kernel_def_.task_info_size(), kernel_def_.task_info().data(), kernel_def_.task_info_size(), RT_MEMCPY_HOST_TO_DEVICE)); - auto ret = SetInputOutputAddr(io_addr, BuildTaskUtils::JoinAddresses(addresses)); - if (ret != SUCCESS) { - return ret; - } + auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param, false); + task.io_addr_host_ = BuildTaskUtils::JoinAddresses(addresses); + task.io_addr_size_ = task.io_addr_host_.size() * sizeof(void *); + GE_CHK_RT_RET(rtMalloc(&task.io_addr_, task.io_addr_size_, RT_MEMORY_HBM)); return SUCCESS; } Status AiCpuTaskBuilder::BuildTask(ge::AiCpuTask &task, const SingleOpModelParam ¶m, bool dynamic_flag, uint64_t kernel_id) { - GE_CHK_STATUS_RET_NOLOG(InitWorkspaceAndIO(&task.io_addr_, &task.workspace_addr_, param, dynamic_flag)); + GE_CHK_STATUS_RET_NOLOG(InitWorkspaceAndIO(task, param, dynamic_flag)); STR_FWK_OP_KERNEL fwk_op_kernel = {0}; auto ret = SetFmkOpKernel(task.io_addr_, task.workspace_addr_, fwk_op_kernel); @@ -120,6 +88,7 @@ namespace ge { return ret; } + GE_CHECK_NOTNULL(op_desc_); task.op_desc_ = op_desc_; task.num_inputs_ = op_desc_->GetInputsSize(); task.num_outputs_ = op_desc_->GetOutputsSize(); @@ -136,6 +105,7 @@ namespace ge { fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast(task.ext_info_addr_dev_); fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoLen = kernel_ext_info_size; } + GE_CHK_STATUS_RET(task.SetInputConst(), "AiCpuTask set input_const failed."); GE_CHK_STATUS_RET(task.InitForSummaryAndCopy(), "AiCpuTask init for summary and copy task failed."); fwk_op_kernel.fwkKernelBase.fwk_kernel.sessionID = ULLONG_MAX; diff --git a/ge/single_op/task/aicpu_task_builder.h b/ge/single_op/task/aicpu_task_builder.h old mode 100755 new mode 100644 index 4669e118..fe9c9bc2 --- a/ge/single_op/task/aicpu_task_builder.h +++ b/ge/single_op/task/aicpu_task_builder.h @@ -33,10 +33,8 @@ namespace ge { private: static Status SetKernelArgs(void **args, STR_FWK_OP_KERNEL &kernel); - Status SetInputOutputAddr(void **io_addr, const std::vector &addresses); Status SetFmkOpKernel(void *io_addr, void *ws_addr, STR_FWK_OP_KERNEL &kernel); - Status InitWorkspaceAndIO(void **io_addr, void **kernel_workspace, - const SingleOpModelParam ¶m, bool dynamic_flag); + Status InitWorkspaceAndIO(AiCpuTask &task, const SingleOpModelParam ¶m, bool dynamic_flag); const OpDescPtr op_desc_; const domi::KernelExDef &kernel_def_; diff --git a/ge/single_op/task/build_task_utils.cc b/ge/single_op/task/build_task_utils.cc index 29f1657b..071e514b 100644 --- a/ge/single_op/task/build_task_utils.cc +++ b/ge/single_op/task/build_task_utils.cc @@ -32,7 +32,8 @@ const uint64_t kVarSize = 0; } std::vector> BuildTaskUtils::GetAddresses(const OpDescPtr &op_desc, - const SingleOpModelParam ¶m) { + const SingleOpModelParam ¶m, + bool keep_workspace) { std::vector> ret; RuntimeParam runtime_para; runtime_para.mem_size = param.memory_size; @@ -49,7 +50,9 @@ std::vector> BuildTaskUtils::GetAddresses(const OpDescPtr &o ret.emplace_back(ModelUtils::GetInputDataAddrs(runtime_para, op_desc)); ret.emplace_back(ModelUtils::GetOutputDataAddrs(runtime_para, op_desc)); - ret.emplace_back(ModelUtils::GetWorkspaceDataAddrs(runtime_para, op_desc)); + if (keep_workspace) { + ret.emplace_back(ModelUtils::GetWorkspaceDataAddrs(runtime_para, op_desc)); + } return ret; } diff --git a/ge/single_op/task/build_task_utils.h b/ge/single_op/task/build_task_utils.h index cddc7a2b..7a2369e4 100644 --- a/ge/single_op/task/build_task_utils.h +++ b/ge/single_op/task/build_task_utils.h @@ -27,15 +27,17 @@ namespace ge { class BuildTaskUtils { public: + static constexpr int kAddressIndexOutput = 1; static constexpr int kAddressIndexWorkspace = 2; - static std::vector> GetAddresses(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); + static std::vector> GetAddresses(const OpDescPtr &op_desc, + const SingleOpModelParam ¶m, + bool keep_workspace = true); static std::vector JoinAddresses(const std::vector> &addresses); static std::vector GetKernelArgs(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); static std::string GetTaskInfo(const OpDescPtr &op_desc); template - static std::string VectorToString(const std::vector &values) - { + static std::string VectorToString(const std::vector &values) { std::stringstream ss; ss << '['; auto size = values.size(); diff --git a/ge/single_op/task/op_task.cc b/ge/single_op/task/op_task.cc old mode 100755 new mode 100644 index c3c4e5bb..4f64251c --- a/ge/single_op/task/op_task.cc +++ b/ge/single_op/task/op_task.cc @@ -24,9 +24,11 @@ #include "common/dump/dump_manager.h" #include "common/dump/dump_op.h" #include "common/formats/formats.h" +#include "common/math/math_util.h" #include "framework/common/debug/log.h" #include "register/op_tiling.h" #include "runtime/rt.h" +#include "build_task_utils.h" namespace ge { namespace { @@ -48,18 +50,22 @@ Status OpTask::OpenDump(rtStream_t stream) { std::vector output_adds; auto input_size = op_desc_->GetInputsSize(); auto output_size = op_desc_->GetOutputsSize(); - auto all_size = io_addrs_for_dump_.size(); - if (input_size + output_size != all_size) { - GELOGE(FAILED, "io_addrs_for_dump_ size %zu is not equal input and output size %zu", all_size, + uintptr_t *arg_base = nullptr; + size_t arg_num = 0; + GetIoAddr(arg_base, arg_num); + if (arg_num < input_size + output_size) { + GELOGE(FAILED, "io_addrs_for_dump_ size %zu is not equal input and output size %zu", + arg_num, input_size + output_size); return FAILED; } + for (size_t i = 0; i < input_size; i++) { - uint64_t input_addr = io_addrs_for_dump_[i]; + uint64_t input_addr = arg_base[i]; input_addrs.emplace_back(input_addr); } for (size_t j = 0; j < output_size; j++) { - uint64_t output_addr = io_addrs_for_dump_[input_size + j]; + uint64_t output_addr = arg_base[input_size + j]; output_adds.emplace_back(output_addr); } dump_op_.SetDumpInfo(DumpManager::GetInstance().GetDumpProperties(), op_desc_, input_addrs, output_adds, stream); @@ -89,9 +95,50 @@ void TbeOpTask::SetKernelArgs(std::unique_ptr &&args, size_t arg_size void TbeOpTask::SetSmDesc(void *sm_desc) { sm_desc_ = sm_desc; } -const vector &OpTask::GetWorkspaceSizes() const { return workspace_sizes_; } +void OpTask::SetModelArgs(std::string model_name, uint32_t model_id) { + model_name_ = model_name; + model_id_ = model_id; +} + +Status OpTask::GetProfilingArgs(std::string &model_name, std::string &op_name, uint32_t &model_id, + uint32_t &block_dim) { + model_name = model_name_; + model_id = model_id_; + block_dim = block_dim_; + GE_CHECK_NOTNULL(op_desc_); + op_name = op_desc_->GetName(); + return SUCCESS; +} +Status OpTask::UpdateRunInfo(const vector &input_desc, const vector &output_desc) { + return UNSUPPORTED; +} +Status OpTask::UpdateArgTable(const SingleOpModelParam ¶m) { + auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param); + auto all_addresses = BuildTaskUtils::JoinAddresses(addresses); + uintptr_t *arg_base = nullptr; + size_t arg_num = 0; + GetIoAddr(arg_base, arg_num); + if (arg_num != all_addresses.size()) { + GELOGE(INTERNAL_ERROR, "[%s] arg number mismatches, expect = %zu, but got = %zu", + op_desc_->GetName().c_str(), + arg_num, + all_addresses.size()); + return INTERNAL_ERROR; + } -void OpTask::SetWorkspaceSizes(const vector &workspace_sizes) { workspace_sizes_ = workspace_sizes; } + for (void *addr : all_addresses) { + *arg_base++ = reinterpret_cast(addr); + } + return SUCCESS; +} + +Status OpTask::LaunchKernel(const vector &input_desc, + const vector &input_buffers, + vector &output_desc, + vector &output_buffers, + rtStream_t stream) { + return UNSUPPORTED; +} TbeOpTask::~TbeOpTask() { if (sm_desc_ != nullptr) { @@ -126,12 +173,6 @@ Status TbeOpTask::LaunchKernel(rtStream_t stream) { return RT_FAILED; } GELOGI("[TASK_INFO] %s", this->stub_name_.c_str()); - - size_t input_size = op_desc_->GetInputsSize(); - size_t output_size = op_desc_->GetOutputsSize(); - uint64_t *io_addr = reinterpret_cast(args_.get()); - std::vector io_addrs(io_addr, io_addr + input_size + output_size); - SetIoAddrsForDump(io_addrs); auto status = OpenDump(stream); if (status != SUCCESS) { GELOGE(status, "Open dump failed in the tbe single op %s", this->stub_name_.c_str()); @@ -152,11 +193,12 @@ Status TbeOpTask::UpdateRunInfo(const vector &input_desc, const ve GELOGE(FAILED, "Failed to invoke OpParaCalculate. ret = %u", ret); return FAILED; } - SetWorkspaceSizes(run_info.workspaces); block_dim_ = run_info.block_dim; tiling_data_ = run_info.tiling_data.str(); GELOGD("Done invoking OpParaCalculate successfully. block_dim = %u, tiling size = %zu", block_dim_, tiling_data_.size()); + + GE_CHK_STATUS_RET(AllocateWorkspaces(run_info.workspaces), "Failed to allocate workspaces"); return SUCCESS; } @@ -212,13 +254,54 @@ void TbeOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, s max_tiling_size_ = max_tiling_size; } -Status TbeOpTask::LaunchKernel(const vector &inputs, const vector &outputs, - const vector &workspaces, rtStream_t stream) { +Status TbeOpTask::AllocateWorkspaces(const vector &workspace_sizes) { + static const std::string kPurpose("malloc workspace memory for dynamic op."); + if (workspace_sizes.empty()) { + GELOGD("No need to allocate workspace."); + return SUCCESS; + } + int64_t total_size = 0; + std::vector ws_offsets; + for (auto ws_size : workspace_sizes) { + // alignment and padding should be done in OpParaCalculate + GE_CHK_STATUS_RET_NOLOG(CheckInt64AddOverflow(total_size, ws_size)); + ws_offsets.emplace_back(total_size); + total_size += ws_size; + } + + GELOGD("Total workspace size is %ld", total_size); + GE_CHECK_NOTNULL(stream_resource_); + auto ws_base = stream_resource_->MallocMemory(kPurpose, static_cast(total_size)); + if (ws_base == nullptr) { + GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to allocate memory of size: %ld", total_size); + return ACL_ERROR_GE_MEMORY_ALLOCATION; + } + GELOGD("Done allocating workspace memory successfully."); + + for (auto ws_offset : ws_offsets) { + workspaces_.emplace_back(ws_base + ws_offset); + } + + return SUCCESS; +} + +Status TbeOpTask::LaunchKernel(const vector &input_desc, + const vector &input_buffers, + vector &output_desc, + vector &output_buffers, + rtStream_t stream) { + GE_CHK_STATUS_RET_NOLOG(UpdateRunInfo(input_desc, output_desc)); GELOGD("[%s] Start to launch kernel", node_->GetName().c_str()); std::vector args; - args.insert(args.end(), inputs.begin(), inputs.end()); - args.insert(args.end(), outputs.begin(), outputs.end()); - args.insert(args.end(), workspaces.begin(), workspaces.end()); + for (auto &buffer : input_buffers) { + args.emplace_back(buffer.data); + } + for (auto &buffer : output_buffers) { + args.emplace_back(buffer.data); + } + for (auto &buffer : workspaces_) { + args.emplace_back(buffer); + } if (tiling_buffer_ != nullptr) { GELOGD("[%s] Start to copy tiling info. size = %zu", node_->GetName().c_str(), tiling_data_.size()); @@ -239,6 +322,14 @@ Status TbeOpTask::LaunchKernel(const vector &inputs, const vector(args_.get()); + arg_count = arg_size_ / sizeof(void *); + if (tiling_buffer_ != nullptr) { + --arg_count; + } +} + AiCpuBaseTask::~AiCpuBaseTask() { if (ext_info_addr_dev_ != nullptr) { (void)rtFree(ext_info_addr_dev_); @@ -278,6 +369,25 @@ Status AiCpuBaseTask::SetExtInfoAndType(const std::string &kernel_ext_info, uint return SUCCESS; } +Status AiCpuBaseTask::SetInputConst() { + input_is_const_.clear(); + const vector v_is_input_const = op_desc_->GetIsInputConst(); + for (size_t i = 0; i < op_desc_->GetAllInputsSize(); ++i) { + const GeTensorDescPtr tensor_desc = op_desc_->MutableInputDesc(static_cast(i)); + if (tensor_desc == nullptr) { + GELOGD("SingleOp: %s, Index: %zu, has no input", op_desc_->GetName().c_str(), i); + continue; + } + if (i < v_is_input_const.size() && v_is_input_const[i]) { + GELOGD("SingleOp: %s, Index: %zu, input is const", op_desc_->GetName().c_str(), i); + input_is_const_.push_back(true); + continue; + } + input_is_const_.push_back(false); + } + return SUCCESS; +} + Status AiCpuBaseTask::UpdateExtInfo(const std::vector &input_desc, std::vector &output_desc, rtStream_t stream) { @@ -288,9 +398,23 @@ Status AiCpuBaseTask::UpdateExtInfo(const std::vector &input_desc, } GE_CHECK_NOTNULL(aicpu_ext_handle_); - for (size_t i = 0; i < num_inputs_; ++i) { - GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateInputShapeAndType(i, input_desc[i]), - "Input[%zu] update input shape failed.", i); + + size_t non_const_index = 0; + for (size_t input_index = 0; input_index < num_inputs_; input_index++) { + if (input_index < input_is_const_.size() && input_is_const_[input_index]) { + // get input_desc from op_desc if const input, num_inputs_ is op_desc_ input_size + auto const_input_desc = op_desc_->MutableInputDesc(static_cast(input_index)); + GE_CHECK_NOTNULL(const_input_desc); + GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateInputShapeAndType(input_index, *const_input_desc), + "Input[%zu] update input shape failed.", input_index); + continue; + } + GE_CHK_BOOL_RET_STATUS(non_const_index < input_desc.size(), PARAM_INVALID, + "Input_desc size is %zu, but get non_const_index is %zu", + input_desc.size(), non_const_index); + GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateInputShapeAndType(input_index, input_desc[non_const_index]), + "Input[%zu] update input shape failed.", input_index); + non_const_index++; } if (unknown_type_ != DEPEND_COMPUTE) { @@ -363,6 +487,41 @@ Status AiCpuBaseTask::UpdateShapeToOutputDesc(const GeShape &shape_new, GeTensor return SUCCESS; } +Status AiCpuBaseTask::UpdateIoAddr(const vector &inputs, const vector &outputs) { + uintptr_t *arg_base = nullptr; + size_t arg_num = 0; + GetIoAddr(arg_base, arg_num); + + // input number and output number was check in ValidateParams + size_t non_const_index = 0; + for (size_t input_index = 0; input_index < num_inputs_; input_index++) { + if (input_index < input_is_const_.size() && input_is_const_[input_index]) { + // const input no need update addr + GE_CHECK_NOTNULL(arg_base); + GELOGD("AICpuTask input[%zu] addr = %u", input_index, *arg_base); + arg_base++; + continue; + } + GE_CHK_BOOL_RET_STATUS(non_const_index < inputs.size(), PARAM_INVALID, + "Input size is %zu, but get non_const_index is %zu", + inputs.size(), non_const_index); + auto addr = inputs[non_const_index].data; + GE_CHECK_NOTNULL(addr); + GELOGD("AICpuTask input[%zu] addr = %p", input_index, addr); + *arg_base++ = reinterpret_cast(addr); + non_const_index++; + } + + for (size_t i = 0; i < outputs.size(); ++i) { + auto addr = outputs[i].data; + GE_CHECK_NOTNULL(addr); + GELOGD("AICpuTask output[%zu] addr = %p", i, addr); + *arg_base++ = reinterpret_cast(addr); + } + + return SUCCESS; +} + AiCpuTask::~AiCpuTask() { FreeHbm(args_); FreeHbm(io_addr_); @@ -384,12 +543,14 @@ AiCpuTask::~AiCpuTask() { } } -const void *AiCpuTask::GetIOAddr() const { return io_addr_; } - Status AiCpuTask::LaunchKernel(rtStream_t stream) { GELOGD("Start to launch kernel. task = %s", this->op_type_.c_str()); - auto ret = rtMemcpyAsync(workspace_addr_, task_info_.size(), task_info_.data(), task_info_.size(), - RT_MEMCPY_HOST_TO_DEVICE_EX, stream); + auto ret = rtMemcpyAsync(io_addr_, + io_addr_size_, + io_addr_host_.data(), + io_addr_host_.size() * sizeof(void *), + RT_MEMCPY_HOST_TO_DEVICE_EX, + stream); if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "rtMemcpyAsync workspace data failed. ret = %d, task = %s", ret, this->op_type_.c_str()); return RT_FAILED; @@ -538,40 +699,6 @@ Status AiCpuTask::UpdateShapeAndDataByResultSummary(vector &output return SUCCESS; } -Status AiCpuTask::SetIO(const vector &inputs, vector &outputs) { - vector io_addrs; - io_addrs.reserve(num_inputs_ + num_outputs_); - for (size_t i = 0; i < num_inputs_; ++i) { - GE_CHECK_NOTNULL(inputs[i]); - GELOGD("AiCpuTask input[%zu] addr = %p", i, inputs[i]); - io_addrs.emplace_back(reinterpret_cast(inputs[i])); - } - - if (unknown_type_ != DEPEND_COMPUTE) { - for (size_t i = 0; i < num_outputs_; ++i) { - GE_CHECK_NOTNULL(outputs[i]); - GELOGD("AiCpuTask output[%zu] addr = %p", i, outputs[i]); - io_addrs.emplace_back(reinterpret_cast(outputs[i])); - } - } else { - for (size_t i = 0; i < num_outputs_; ++i) { - void *summary_addr = output_summary_[i]; - io_addrs.emplace_back(reinterpret_cast(summary_addr)); - } - } - - if (!io_addrs.empty()) { - auto *dst_io_addr = const_cast(reinterpret_cast(io_addr_)); - GE_CHK_RT_RET(rtMemcpy(dst_io_addr, - sizeof(uint64_t) * io_addrs.size(), - &io_addrs[0], - sizeof(uint64_t) * io_addrs.size(), - RT_MEMCPY_HOST_TO_DEVICE)); - GE_CHECK_NOTNULL(dst_io_addr); - }; - return SUCCESS; -} - Status AiCpuTask::InitForSummaryAndCopy() { if (unknown_type_ != DEPEND_COMPUTE || num_outputs_ == 0) { GELOGI("Unknown_type is %d, output num is %d.", unknown_type_, num_outputs_); @@ -643,17 +770,17 @@ Status AiCpuTask::LaunchKernel(const std::vector &input_desc, std::vector &output_buffers, rtStream_t stream) { GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc, stream)); - std::vector inputs; - std::vector outputs; - for (auto &buffer : input_buffers) { - inputs.emplace_back(buffer.data); - } - for (auto &buffer : output_buffers) { - outputs.emplace_back(buffer.data); + if (unknown_type_ == DEPEND_COMPUTE) { + std::vector summary_buffers; + for (size_t i = 0; i < num_outputs_; ++i) { + summary_buffers.emplace_back(output_summary_[i], sizeof(aicpu::FWKAdapter::ResultSummary), false); + } + GE_CHK_STATUS_RET_NOLOG(UpdateIoAddr(input_buffers, summary_buffers)); + } else { + GE_CHK_STATUS_RET_NOLOG(UpdateIoAddr(input_buffers, output_buffers)); } - GE_CHK_STATUS_RET_NOLOG(SetIO(inputs, outputs)); - GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); + GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); if (unknown_type_ == DEPEND_SHAPE_RANGE) { GE_CHK_RT_RET(rtStreamSynchronize(stream)); GE_CHK_STATUS_RET_NOLOG(UpdateOutputShape(output_desc)); @@ -665,6 +792,17 @@ Status AiCpuTask::LaunchKernel(const std::vector &input_desc, return SUCCESS; } +Status AiCpuTask::UpdateArgTable(const SingleOpModelParam ¶m) { + auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param, false); + io_addr_host_ = BuildTaskUtils::JoinAddresses(addresses); + return SUCCESS; +} + +void AiCpuTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) { + arg_base = reinterpret_cast(io_addr_host_.data()); + arg_count = io_addr_host_.size(); +} + void AiCpuCCTask::SetKernelArgs(std::unique_ptr args, size_t arg_size) { args_ = std::move(args); arg_size_ = arg_size; @@ -676,9 +814,7 @@ void AiCpuCCTask::SetSoName(const std::string &so_name) { so_name_ = so_name; } void AiCpuCCTask::SetkernelName(const std::string &kernel_Name) { kernel_name_ = kernel_Name; } -void AiCpuCCTask::SetIoAddr(void *io_addr) { io_addr_ = io_addr; } - -const void *AiCpuCCTask::GetIOAddr() const { return io_addr_; } +void AiCpuCCTask::SetIoAddr(uintptr_t *io_addr) { io_addr_ = io_addr; } const void *AiCpuCCTask::GetArgs() const { return args_.get(); } @@ -701,12 +837,6 @@ Status AiCpuCCTask::LaunchKernel(rtStream_t stream) { return ret; } GELOGD("Invoke rtCpuKernelLaunch succeeded"); - - size_t input_size = op_desc_->GetInputsSize(); - size_t output_size = op_desc_->GetOutputsSize(); - uint64_t *io_addr = reinterpret_cast(io_addr_); - std::vector io_addrs (io_addr, io_addr + input_size + output_size); - SetIoAddrsForDump(io_addrs); auto status = OpenDump(stream); if (status != SUCCESS) { GELOGE(status, "Open dump failed in the aicpucc single op %s", this->kernel_name_.c_str()); @@ -721,24 +851,9 @@ Status AiCpuCCTask::LaunchKernel(const std::vector &input_desc, std::vector &output_desc, std::vector &output_buffers, rtStream_t stream) { - GE_CHK_BOOL_RET_STATUS(unknown_type_ != DEPEND_COMPUTE, FAILED, - "AiCpuCCTask unknown type[%d] is depend compute, it's not supported now.", - unknown_type_); - GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc, stream)); - - size_t arg_index = 0; - auto *task_io_addr = reinterpret_cast(io_addr_); - GE_CHECK_NOTNULL(task_io_addr); - for (auto &input : input_buffers) { - task_io_addr[arg_index++] = reinterpret_cast(input.data); - } - for (auto &output : output_buffers) { - task_io_addr[arg_index++] = reinterpret_cast(output.data); - } - + GE_CHK_STATUS_RET_NOLOG(UpdateIoAddr(input_buffers, output_buffers)); GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); - if (unknown_type_ == DEPEND_SHAPE_RANGE) { GE_CHK_RT_RET(rtStreamSynchronize(stream)); GE_CHK_STATUS_RET_NOLOG(UpdateOutputShape(output_desc)); @@ -746,4 +861,9 @@ Status AiCpuCCTask::LaunchKernel(const std::vector &input_desc, return SUCCESS; } + +void AiCpuCCTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) { + arg_base = io_addr_; + arg_count = io_addr_num_; +} } // namespace ge diff --git a/ge/single_op/task/op_task.h b/ge/single_op/task/op_task.h index 65c77800..761697cb 100644 --- a/ge/single_op/task/op_task.h +++ b/ge/single_op/task/op_task.h @@ -32,64 +32,46 @@ #include "init/gelib.h" namespace ge { -enum OpTaskType { - OP_TASK_TBE = 0, - OP_TASK_AICPU, - OP_TASK_AICPUCC, - OP_TASK_INVALID, -}; - +class StreamResource; +struct SingleOpModelParam; class OpTask { public: OpTask() = default; virtual ~OpTask() = default; virtual Status LaunchKernel(rtStream_t stream) = 0; virtual Status UpdateRunInfo(const vector &input_desc, - const vector &output_desc) { - return UNSUPPORTED; - } - virtual Status LaunchKernel(const std::vector &inputs, - const std::vector &outputs, - const std::vector &workspaces, - rtStream_t stream) { - return UNSUPPORTED; - } - virtual OpTaskType GetOpTaskType() = 0; - virtual const void *GetIOAddr() const = 0; - const vector &GetWorkspaceSizes() const; - void SetWorkspaceSizes(const vector &workspace_sizes); + const vector &output_desc); + virtual Status UpdateArgTable(const SingleOpModelParam ¶m); + void SetModelArgs(std::string model_name, uint32_t model_id); + Status GetProfilingArgs(std::string &model_name, std::string &op_name, uint32_t &model_id, uint32_t &block_dim); const OpDescPtr &GetOpdesc() const {return op_desc_;} Status OpenDump(rtStream_t stream); - void SetIoAddrsForDump(const vector &io_addrs_for_dump) { - io_addrs_for_dump_ = io_addrs_for_dump; - } + virtual void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) = 0; virtual Status LaunchKernel(const std::vector &input_desc, const std::vector &input_buffers, std::vector &output_desc, std::vector &output_buffers, - rtStream_t stream) { - return UNSUPPORTED; - } + rtStream_t stream); - private: - std::vector workspace_sizes_; protected: DumpProperties dump_properties_; DumpOp dump_op_; OpDescPtr op_desc_; - std::vector io_addrs_for_dump_; + std::string model_name_; + uint32_t model_id_ = 0; + uint32_t block_dim_ = 1; }; class TbeOpTask : public OpTask { public: ~TbeOpTask() override; Status LaunchKernel(rtStream_t stream) override; - OpTaskType GetOpTaskType() override { - return OP_TASK_TBE; - } - const void *GetIOAddr() const override { - return nullptr; - } + Status LaunchKernel(const std::vector &input_desc, + const std::vector &input_buffers, + std::vector &output_desc, + std::vector &output_buffers, + rtStream_t stream) override; + void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; void SetSmDesc(void *sm_desc); void SetStubFunc(const std::string &name, const void *stub_func); void SetKernelArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim, const OpDescPtr &op_desc); @@ -97,31 +79,29 @@ class TbeOpTask : public OpTask { Status UpdateRunInfo(const vector &input_desc, const vector &output_desc) override; - Status LaunchKernel(const vector &inputs, - const vector &outputs, - const vector &workspaces, - rtStream_t stream) override; - const void *GetArgs() const; size_t GetArgSize() const; const std::string &GetStubName() const; void EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, size_t max_tiling_size); private: + friend class SingleOpModel; static Status UpdateTensorDesc(const GeTensorDesc &src_tensor, GeTensorDesc &dst_tensor); Status UpdateNodeByShape(const vector &input_desc, const vector &output_desc); + Status AllocateWorkspaces(const std::vector &workspace_sizes); const void *stub_func_ = nullptr; std::unique_ptr args_; size_t arg_size_ = 0; - uint32_t block_dim_ = 1; void *sm_desc_ = nullptr; std::string stub_name_; + StreamResource *stream_resource_ = nullptr; void *tiling_buffer_ = nullptr; uint32_t max_tiling_size_ = 0; std::string tiling_data_; + std::vector workspaces_; NodePtr node_; }; @@ -129,9 +109,11 @@ class AiCpuBaseTask : public OpTask { public: AiCpuBaseTask() = default; ~AiCpuBaseTask() override; - const UnknowShapeOpType GetUnknownType() const { return unknown_type_; } + UnknowShapeOpType GetUnknownType() const { return unknown_type_; } protected: + Status UpdateIoAddr(const std::vector &inputs, const std::vector &outputs); + Status SetInputConst(); Status SetExtInfoAndType(const std::string &kernel_ext_info, uint64_t kernel_id); Status UpdateExtInfo(const std::vector &input_desc, @@ -146,6 +128,7 @@ class AiCpuBaseTask : public OpTask { UnknowShapeOpType unknown_type_ = DEPEND_IN_SHAPE; std::unique_ptr aicpu_ext_handle_; void *ext_info_addr_dev_ = nullptr; + vector input_is_const_; }; class AiCpuTask : public AiCpuBaseTask { @@ -154,10 +137,8 @@ class AiCpuTask : public AiCpuBaseTask { ~AiCpuTask() override; Status LaunchKernel(rtStream_t stream) override; - OpTaskType GetOpTaskType() override { - return OP_TASK_AICPU; - } - const void *GetIOAddr() const override; + Status UpdateArgTable(const SingleOpModelParam ¶m) override; + void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; Status LaunchKernel(const std::vector &input_desc, const std::vector &input_buffers, @@ -167,8 +148,6 @@ class AiCpuTask : public AiCpuBaseTask { Status SetMemCopyTask(const domi::KernelExDef &kernel_def); private: - Status SetIO(const vector &inputs, vector &outputs); - // for copy task. Status InitForSummaryAndCopy(); Status UpdateShapeAndDataByResultSummary(vector &output_desc, @@ -184,27 +163,31 @@ class AiCpuTask : public AiCpuBaseTask { friend class AiCpuTaskBuilder; void *workspace_addr_ = nullptr; std::string task_info_; - // device addr + // device addr void *args_ = nullptr; size_t arg_size_ = 0; std::string op_type_; // device addr void *io_addr_ = nullptr; + size_t io_addr_size_ = 0; + + // host addr + std::vector io_addr_host_; bool dynamic_flag_ = false; // for copy task - void *copy_task_args_buf_; - void *copy_workspace_buf_; + void *copy_task_args_buf_ = nullptr; + void *copy_workspace_buf_ = nullptr; std::vector output_summary_; std::vector output_summary_host_; - void *copy_ioaddr_dev_; + void *copy_ioaddr_dev_ = nullptr; - void *copy_input_release_flag_dev_; - void *copy_input_data_size_dev_; - void *copy_input_src_dev_; - void *copy_input_dst_dev_; + void *copy_input_release_flag_dev_ = nullptr; + void *copy_input_data_size_dev_ = nullptr; + void *copy_input_src_dev_ = nullptr; + void *copy_input_dst_dev_ = nullptr; vector out_shape_hbm_; uint64_t kernel_id_ = 0; @@ -218,13 +201,12 @@ class AiCpuCCTask : public AiCpuBaseTask { AiCpuCCTask &operator=(const AiCpuCCTask &) = delete; Status LaunchKernel(rtStream_t stream) override; - OpTaskType GetOpTaskType() override { return OP_TASK_AICPUCC; } - const void *GetIOAddr() const override; + void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; const void *GetArgs() const; void SetKernelArgs(std::unique_ptr args, size_t arg_size); void SetSoName(const std::string &so_name); void SetkernelName(const std::string &kernel_Name); - void SetIoAddr(void *io_addr); + void SetIoAddr(uintptr_t *io_addr); size_t GetArgSize() const; Status LaunchKernel(const std::vector &input_desc, @@ -239,9 +221,9 @@ private: std::string kernel_name_; std::unique_ptr args_; size_t arg_size_ = 0; - uint32_t block_dim_ = 1; void *sm_desc_ = nullptr; - void *io_addr_ = nullptr; + uintptr_t *io_addr_ = nullptr; + size_t io_addr_num_ = 0; bool is_custom_ = false; uint32_t dump_flag_ = RT_KERNEL_DEFAULT; }; diff --git a/ge/single_op/task/tbe_task_builder.cc b/ge/single_op/task/tbe_task_builder.cc index e06a08c6..594352aa 100644 --- a/ge/single_op/task/tbe_task_builder.cc +++ b/ge/single_op/task/tbe_task_builder.cc @@ -173,7 +173,8 @@ Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam auto tbe_kernel = GetTbeKernel(op_desc_); if (tbe_kernel == nullptr) { - GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "OP EXT ATTR NAME TBE_KERNEL not found. op = %s", op_desc_->GetName().c_str()); + GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "OP EXT ATTR NAME TBE_KERNEL not found. op = %s", + op_desc_->GetName().c_str()); return ACL_ERROR_GE_INTERNAL_ERROR; } diff --git a/ge/single_op/task/tbe_task_builder.h b/ge/single_op/task/tbe_task_builder.h old mode 100755 new mode 100644 diff --git a/ge/stub/gen_stubapi.py b/ge/stub/gen_stubapi.py index f2a6a287..1476d505 100644 --- a/ge/stub/gen_stubapi.py +++ b/ge/stub/gen_stubapi.py @@ -1,3 +1,10 @@ +#!/usr/bin/python3.7 +# -*- coding: UTF-8 -*- +#------------------------------------------------------------------- +# Purpose: +# Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. +#------------------------------------------------------------------- + import os import re import sys @@ -64,7 +71,7 @@ max_code_len_per_line = 100 when DEBUG on """ white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h", - "ge_ir_build.h", "ge_api.h", "ge_prof.h", "tensorflow_parser.h", "caffe_parser.h"] + "ge_ir_build.h", "ge_api.h", "tensorflow_parser.h", "caffe_parser.h"] include_dir_key_words = ["ge", "graph", "parser"] DEBUG = True diff --git a/inc/external/acl/acl_base.h b/inc/external/acl/acl_base.h index debadcfd..c1341d59 100644 --- a/inc/external/acl/acl_base.h +++ b/inc/external/acl/acl_base.h @@ -225,6 +225,29 @@ ACL_FUNC_VISIBILITY aclError aclDestroyDataBuffer(const aclDataBuffer *dataBuffe /** * @ingroup AscendCL + * @brief update new data of aclDataBuffer + * + * @param dataBuffer [OUT] pointer to aclDataBuffer + * @li The old data need to be released by the user, otherwise it may occur memory leak leakage + * call aclGetDataBufferAddr interface to get old data address + * call aclrtFree interface to release memory + * + * @param data [IN] pointer to new data + * @li Need to be managed by the user, + * call aclrtMalloc interface to apply for memory, + * call aclrtFree interface to release memory + * + * @param size [IN] size of data in bytes + * + * @retval ACL_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + * + * @see aclrtMalloc | aclrtFree | aclGetDataBufferAddr + */ +ACL_FUNC_VISIBILITY aclError aclUpdateDataBuffer(aclDataBuffer *dataBuffer, void *data, size_t size); + +/** + * @ingroup AscendCL * @brief get data address from aclDataBuffer * * @param dataBuffer [IN] pointer to the data of aclDataBuffer @@ -549,6 +572,19 @@ ACL_FUNC_VISIBILITY aclError aclSetTensorDynamicInput(aclTensorDesc *desc, const /** * @ingroup AscendCL + * @brief Set const data specified by the tensor description + * + * @param desc [OUT] pointer to the instance of aclTensorDesc + * @param dataBuffer [IN] pointer to the const databuffer + * @param length [IN] the length of const databuffer + * + * @retval ACL_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ACL_FUNC_VISIBILITY aclError aclSetTensorConst(aclTensorDesc *desc, void *dataBuffer, size_t length); + +/** + * @ingroup AscendCL * @brief an interface for users to output APP logs * * @param logLevel [IN] the level of current log diff --git a/inc/external/acl/acl_prof.h b/inc/external/acl/acl_prof.h index bfb8a68b..97a81b53 100644 --- a/inc/external/acl/acl_prof.h +++ b/inc/external/acl/acl_prof.h @@ -32,12 +32,11 @@ extern "C" { #define ACL_PROF_MAX_OP_TYPE_LEN 65 typedef enum { - ACL_AICORE_ARITHMATIC_THROUGHPUT = 0, - ACL_AICORE_PIPELINE = 1, - ACL_AICORE_SYNCHRONIZATION = 2, - ACL_AICORE_MEMORY = 3, - ACL_AICORE_INTERNAL_MEMORY = 4, - ACL_AICORE_STALL = 5, + ACL_AICORE_ARITHMETIC_UTILIZATION = 0, + ACL_AICORE_PIPE_UTILIZATION = 1, + ACL_AICORE_MEMORY_BANDWIDTH = 2, + ACL_AICORE_L0B_AND_WIDTH = 3, + ACL_AICORE_RESOURCE_CONFLICT_RATIO = 4, ACL_AICORE_NONE = 0xFF } aclprofAicoreMetrics; @@ -290,6 +289,32 @@ ACL_FUNC_VISIBILITY uint64_t aclprofGetOpDuration(const void *opInfo, size_t opI */ ACL_FUNC_VISIBILITY size_t aclprofGetModelId(const void *opInfo, size_t opInfoLen, uint32_t index); +/** + * @ingroup AscendCL + * @brief get cube ops from subscription data + * + * @param opInfo [IN] pointer to subscription data + * @param opInfoLen [IN] memory size of subscription data + * @param index [IN] index of op array in opInfo + * + * @retval cube ops of subscription data + * @retval 0 for failed + */ +ACL_FUNC_VISIBILITY uint64_t aclprofGetOpCubeOps(const void *opInfo, size_t opInfoLen, uint32_t index); + +/** + * @ingroup AscendCL + * @brief get vector ops from subscription data + * + * @param opInfo [IN] pointer to subscription data + * @param opInfoLen [IN] memory size of subscription data + * @param index [IN] index of op array in opInfo + * + * @retval vector ops of subscription data + * @retval 0 for failed + */ +ACL_FUNC_VISIBILITY uint64_t aclprofGetOpVectorOps(const void *opInfo, size_t opInfoLen, uint32_t index); + #ifdef __cplusplus } #endif diff --git a/inc/external/acl/error_codes/rt_error_codes.h b/inc/external/acl/error_codes/rt_error_codes.h index 2dd2c70c..e3ec3a3c 100644 --- a/inc/external/acl/error_codes/rt_error_codes.h +++ b/inc/external/acl/error_codes/rt_error_codes.h @@ -1,91 +1,101 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __INC_EXTERNEL_RT_ERROR_CODES_H__ -#define __INC_EXTERNEL_RT_ERROR_CODES_H__ - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -static const int32_t ACL_RT_SUCCESS = 0; // success - -static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid -static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id -static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null -static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context -static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context -static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model -static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid -static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal -static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned -static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed -static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed -static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream -static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread -static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set -static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create -static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream -static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type - -static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPROT = 207000; // feature not support -static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error -static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error - -static const int32_t ACL_ERROR_RT_INTERNEL_ERROR = 507000; // runtime internel error -static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error -static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream -static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream -static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete -static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence -static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete -static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error -static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error -static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support -static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat -static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed -static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout -static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error -static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout -static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception -static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception -static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout -static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception -static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error -static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error -static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error -static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error -static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal -static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering -static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init -static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data -static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error -static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate -static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed -static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed -static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context -static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out -static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error - -static const int32_t ACL_ERROR_RT_DRV_INTERNEL_ERROR = 507899; // drv internel error - -#ifdef __cplusplus -} -#endif - -#endif // __INC_EXTERNEL_RT_ERROR_CODES_H__ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __INC_EXTERNEL_RT_ERROR_CODES_H__ +#define __INC_EXTERNEL_RT_ERROR_CODES_H__ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +static const int32_t ACL_RT_SUCCESS = 0; // success + +static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid +static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id +static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null +static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context +static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context +static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model +static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid +static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal +static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned +static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed +static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed +static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream +static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread +static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set +static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create +static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream +static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type +static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle +static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type + +static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support +static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error +static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error +static const int32_t ACL_ERROR_RT_AICORE_OVER_FLOW = 207003; // aicore over flow +static const int32_t ACL_ERROR_RT_NO_DEVICE = 207004; // no device +static const int32_t ACL_ERROR_RT_RESOURCE_ALLOC_FAIL = 207005; // resource alloc fail +static const int32_t ACL_ERROR_RT_NO_PERMISSION = 207006; // no permission +static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no event resource +static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource +static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource +static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource + +static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error +static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error +static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream +static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream +static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete +static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence +static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete +static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error +static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error +static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support +static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat +static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed +static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout +static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error +static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout +static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception +static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception +static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout +static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception +static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error +static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error +static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error +static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error +static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal +static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering +static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init +static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data +static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error +static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate +static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed +static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed +static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context +static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out +static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error + +static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error + +#ifdef __cplusplus +} +#endif + +#endif // __INC_EXTERNEL_RT_ERROR_CODES_H__ diff --git a/inc/external/acl/ops/acl_dvpp.h b/inc/external/acl/ops/acl_dvpp.h index 32a21e91..1a0f582d 100644 --- a/inc/external/acl/ops/acl_dvpp.h +++ b/inc/external/acl/ops/acl_dvpp.h @@ -130,6 +130,23 @@ enum acldvppChannelMode { DVPP_CHNMODE_VPC = 1, DVPP_CHNMODE_JPEGD = 2, DVPP_CHN // Supported Border Type enum acldvppBorderType { BORDER_CONSTANT = 0, BORDER_REPLICATE, BORDER_REFLECT, BORDER_REFLECT_101 }; +// Venc parameter type +enum aclvencChannelDescParamType { + ACL_VENC_THREAD_ID_UINT64 = 0, + ACL_VENC_CALLBACK_PTR, + ACL_VENC_PIXEL_FORMAT_UINT32, + ACL_VENC_ENCODE_TYPE_UINT32, + ACL_VENC_PIC_WIDTH_UINT32, + ACL_VENC_PIC_HEIGHT_UINT32, + ACL_VENC_KEY_FRAME_INTERVAL_UINT32, + ACL_VENC_BUF_ADDR_PTR, + ACL_VENC_BUF_SIZE_UINT32, + ACL_VENC_RC_MODE_UINT32, + ACL_VENC_SRC_RATE_UINT32, + ACL_VENC_MAX_BITRATE_UINT32, + ACL_VENC_MAX_IP_PROP_UINT32 +}; + /** * @ingroup AscendCL * @brief alloc device memory for dvpp. @@ -1039,6 +1056,21 @@ ACL_FUNC_VISIBILITY aclError aclvencSetChannelDescMaxBitRate(aclvencChannelDesc /** * @ingroup AscendCL + * @brief Set venc parameter for venc channel desc. + * + * @param channelDesc [OUT] venc channel desc + * @param paramType [IN] parameter type + * @param length [IN] parameter length + * @param param [IN] pointer to parameter value + * + * @retval ACL_SUCCESS for success, other for failure + */ +ACL_FUNC_VISIBILITY aclError aclvencSetChannelDescParam(aclvencChannelDesc *channelDesc, + aclvencChannelDescParamType paramType, size_t length, + const void *param); + +/** + * @ingroup AscendCL * @brief Get output buffer address for venc channel desc. * * @param channelDesc[IN] venc channel desc @@ -1172,6 +1204,23 @@ ACL_FUNC_VISIBILITY uint32_t aclvencGetChannelDescMaxBitRate(const aclvencChanne /** * @ingroup AscendCL + * + * @brief Get venc parameter for venc channel desc. + * + * @param channelDesc [IN] venc channel desc + * @param paramType [IN] parameter type + * @param length [IN] parameter length + * @param paramRetSize [OUT] pointer to parameter real length + * @param param [OUT] pointer to parameter value + * + * @retval ACL_SUCCESS for success, other for failure + */ +ACL_FUNC_VISIBILITY aclError aclvencGetChannelDescParam(const aclvencChannelDesc *channelDesc, + aclvencChannelDescParamType paramType, size_t length, + size_t *paramRetSize, void *param); + +/** + * @ingroup AscendCL * @brief get forced restart of I-frame interval from config * * @param config [IN] venc frame config diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index 374a816a..cce17f93 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -369,6 +369,7 @@ static const char *const OP_BANK_PATH = ge::OP_BANK_PATH_FLAG.c_str(); static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); // for interface: aclgrphBuildModel +#ifdef __GNUC__ const std::set ir_builder_suppported_options = {INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, @@ -424,6 +425,7 @@ const std::set global_options = {CORE_TYPE, DEBUG_DIR, OP_COMPILER_CACHE_DIR, OP_COMPILER_CACHE_MODE}; +#endif } // namespace ir_option } // namespace ge diff --git a/inc/external/ge/ge_ir_build.h b/inc/external/ge/ge_ir_build.h index 778ec21d..182c0444 100644 --- a/inc/external/ge/ge_ir_build.h +++ b/inc/external/ge/ge_ir_build.h @@ -24,9 +24,9 @@ #include "graph/ge_error_codes.h" namespace { -#define IR_MAJOR_VERSION (int(1)) -#define IR_MINOR_VERSION (int(0)) -#define IR_PATCH_VERSION (int(0)) +const int IR_MAJOR_VERSION = 1; +const int IR_MINOR_VERSION = 0; +const int IR_PATCH_VERSION = 0; } // namespace namespace ge { @@ -121,5 +121,20 @@ graphStatus aclgrphInferShapeAndType(ge::Graph &graph); * @retval OtherValues Failure */ graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len); + +/** + * @ingroup AscendCL + * @brief create single op graph + * + * @param op_type[IN] the op_type + * @param inputs[IN] the inputdesc + * @param outputs[IN] the outputdesc + * @param graph[OUT] the graph + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector &inputs, + const std::vector &outputs, Graph &graph); + }; // namespace ge #endif // INC_EXTERNAL_GE_IR_BUILD_H_ diff --git a/inc/external/ge/ge_prof.h b/inc/external/ge/ge_prof.h deleted file mode 100644 index 658cea76..00000000 --- a/inc/external/ge/ge_prof.h +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GE_GE_PROF_H_ -#define INC_EXTERNAL_GE_GE_PROF_H_ - -#include -#include -#include - -#include "ge/ge_api_error_codes.h" - -namespace ge { -enum ProfDataTypeConfig { - kProfTaskTime = 0x0002, - kProfAiCoreMetrics = 0x0004, - kProfAicpuTrace = 0x0008, - kProfTrainingTrace = 0x0800, - kProfHcclTrace = 0x1000 -}; - -enum ProfilingAicoreMetrics { - kAicoreArithmaticThroughput = 0, - kAicorePipeline = 1, - kAicoreSynchronization = 2, - kAicoreMemory = 3, - kAicoreInternalMemory = 4, - kAicoreStall = 5 -}; - -typedef struct ProfAicoreEvents ProfAicoreEvents; -typedef struct aclgrphProfConfig aclgrphProfConfig; - -/// -/// @ingroup AscendCL -/// @brief Initialize the profiling and set profiling configuration path -/// @param [in] profiler_path: configuration path of profiling -/// @param [in] length: length of configuration path -/// @return Status result of function -/// -Status aclgrphProfInit(const char *profiler_path, uint32_t length); - -/// -/// @ingroup AscendCL -/// @brief Finalize profiling -/// @return Status result of function -/// -Status aclgrphProfFinalize(); - -/// -/// @ingroup AscendCL -/// @brief Create data of type aclgrphProfConfig -/// @param [in] deviceid_list: device id list -/// @param [in] device_nums: device numbers -/// @param [in] aicore_metrics: type of aicore metrics -/// @param [in] aicore_events: pointer to aicore events be reserved, only support NULL now -/// @param [in] data_type_config: modules need profiling -/// @return Status result of function -/// -aclgrphProfConfig *aclgrphProfCreateConfig(uint32_t *deviceid_list, uint32_t device_nums, - ProfilingAicoreMetrics aicore_metrics, ProfAicoreEvents *aicore_events, - uint64_t data_type_config); - -/// -/// @ingroup AscendCL -/// @brief Destroy data of type aclgrphProfConfig -/// @param [in] profiler_config: config of profiling -/// @return Status result of function -/// -Status aclgrphProfDestroyConfig(aclgrphProfConfig *profiler_config); - -/// -/// @ingroup AscendCL -/// @brief Start profiling of modules which is configured by profiler config -/// @param [in] profiler_config: config of profiling -/// @return Status result of function -/// -Status aclgrphProfStart(aclgrphProfConfig *profiler_config); - -/// -/// @ingroup AscendCL -/// @brief Stop profiling of modules which is configured by profiler config -/// @param [in] profiler_config: config of profiling -/// @return Status result of function -/// -Status aclgrphProfStop(aclgrphProfConfig *profiler_config); -} // namespace ge - -#endif // INC_EXTERNAL_GE_GE_PROF_H_ diff --git a/inc/external/runtime/rt_error_codes.h b/inc/external/runtime/rt_error_codes.h index 2dd2c70c..e3ec3a3c 100644 --- a/inc/external/runtime/rt_error_codes.h +++ b/inc/external/runtime/rt_error_codes.h @@ -1,91 +1,101 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __INC_EXTERNEL_RT_ERROR_CODES_H__ -#define __INC_EXTERNEL_RT_ERROR_CODES_H__ - -#include - -#ifdef __cplusplus -extern "C" { -#endif - -static const int32_t ACL_RT_SUCCESS = 0; // success - -static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid -static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id -static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null -static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context -static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context -static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model -static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid -static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal -static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned -static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed -static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed -static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream -static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread -static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set -static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create -static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream -static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type - -static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPROT = 207000; // feature not support -static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error -static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error - -static const int32_t ACL_ERROR_RT_INTERNEL_ERROR = 507000; // runtime internel error -static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error -static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream -static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream -static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete -static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence -static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete -static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error -static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error -static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support -static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat -static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed -static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout -static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error -static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout -static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception -static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception -static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout -static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception -static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error -static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error -static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error -static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error -static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal -static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering -static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init -static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data -static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error -static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate -static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed -static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed -static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context -static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out -static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error - -static const int32_t ACL_ERROR_RT_DRV_INTERNEL_ERROR = 507899; // drv internel error - -#ifdef __cplusplus -} -#endif - -#endif // __INC_EXTERNEL_RT_ERROR_CODES_H__ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __INC_EXTERNEL_RT_ERROR_CODES_H__ +#define __INC_EXTERNEL_RT_ERROR_CODES_H__ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +static const int32_t ACL_RT_SUCCESS = 0; // success + +static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid +static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id +static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null +static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context +static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context +static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model +static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid +static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal +static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned +static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed +static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed +static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream +static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread +static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set +static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create +static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream +static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type +static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle +static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type + +static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support +static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error +static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error +static const int32_t ACL_ERROR_RT_AICORE_OVER_FLOW = 207003; // aicore over flow +static const int32_t ACL_ERROR_RT_NO_DEVICE = 207004; // no device +static const int32_t ACL_ERROR_RT_RESOURCE_ALLOC_FAIL = 207005; // resource alloc fail +static const int32_t ACL_ERROR_RT_NO_PERMISSION = 207006; // no permission +static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no event resource +static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource +static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource +static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource + +static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error +static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error +static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream +static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream +static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete +static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence +static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete +static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error +static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error +static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support +static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat +static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed +static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout +static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error +static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout +static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception +static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception +static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout +static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception +static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error +static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error +static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error +static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error +static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal +static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering +static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init +static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data +static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error +static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate +static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed +static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed +static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context +static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out +static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error + +static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error + +#ifdef __cplusplus +} +#endif + +#endif // __INC_EXTERNEL_RT_ERROR_CODES_H__ diff --git a/inc/framework/common/fmk_error_codes.h b/inc/framework/common/fmk_error_codes.h index ec1f26d0..358fca04 100644 --- a/inc/framework/common/fmk_error_codes.h +++ b/inc/framework/common/fmk_error_codes.h @@ -23,10 +23,6 @@ #include "framework/common/fmk_types.h" #include "register/register_error_codes.h" -#define MODID_OMG 1 // OMG module ID -#define MODID_OME 2 // OME module ID -#define MODID_CALIBRATION 3 // Calibration module ID - // Each module uses the following four macros to define error codes: #define DECLARE_ERRORNO_OMG(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OMG, name, value) #define DECLARE_ERRORNO_OME(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OME, name, value) @@ -37,6 +33,10 @@ // Interface for Obtaining Error Code Description #define GET_ERRORNO_STR(value) domi::StatusFactory::Instance()->GetErrDesc(value) +const int MODID_OMG = 1; // OMG module ID +const int MODID_OME = 2; // OME module ID +const int MODID_CALIBRATION = 3; // Calibration module ID + namespace domi { class StatusFactory { public: diff --git a/inc/framework/common/helper/model_helper.h b/inc/framework/common/helper/model_helper.h index 949d8b4c..7867e63d 100644 --- a/inc/framework/common/helper/model_helper.h +++ b/inc/framework/common/helper/model_helper.h @@ -25,6 +25,7 @@ #include "common/types.h" #include "graph/model.h" #include "model/ge_model.h" +#include "model/ge_root_model.h" namespace ge { class ModelHelper { @@ -32,17 +33,22 @@ class ModelHelper { ModelHelper() = default; ~ModelHelper(); - Status SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, - const std::string &output_file, ge::ModelBufferData &model); - Status SaveOriginalGraphToOmModel(const ge::Graph& graph, const std::string& output_file); + Status SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, const std::string &output_file, + ge::ModelBufferData &model); + Status SaveToOmRootModel(const GeRootModelPtr &ge_root_model, const SaveParam &save_param, const string &output_file, + ModelBufferData &model, bool is_unknown_shape); + Status SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file); Status LoadModel(const ge::ModelData &model_data); - Status GetModelBufferData(ge::ModelBufferData& model); + Status LoadRootModel(const ge::ModelData &model_data); + Status GetModelBufferData(ge::ModelBufferData &model); - const ModelFileHeader* GetFileHeader() const { return file_header_; } + const ModelFileHeader *GetFileHeader() const { return file_header_; } GeModelPtr GetGeModel(); + GeRootModelPtr GetGeRootModel(); void SetSaveMode(bool val) { is_offline_ = val; } bool GetSaveMode(void) const { return is_offline_; } + bool GetModelType() const { return is_unknown_shape_model_; }; Status GetBaseNameFromFileName(const std::string &file_name, std::string &base_name); Status GetModelNameFromMergedGraphName(const std::string &graph_name, std::string &model_name); @@ -50,24 +56,46 @@ class ModelHelper { private: bool is_assign_model_ = false; bool is_offline_ = true; - ModelFileHeader* file_header_ = nullptr; + bool is_unknown_shape_model_ = false; + ModelFileHeader *file_header_ = nullptr; // Encrypted model need delete temp model and unencrypted model need not delete model uint8_t *model_addr_tmp_ = nullptr; uint32_t model_len_tmp_ = 0; GeModelPtr model_; + GeRootModelPtr root_model_; - ModelHelper(const ModelHelper&); - ModelHelper& operator=(const ModelHelper&); - Status GenerateGeModel(OmFileLoadHelper& om_load_helper); - Status LoadModelData(OmFileLoadHelper& om_load_helper); - void SetModelToGeModel(ge::Model& model); - Status LoadWeights(OmFileLoadHelper& om_load_helper); - Status LoadTask(OmFileLoadHelper& om_load_helper); - Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); - Status LoadCustAICPUKernelStore(OmFileLoadHelper& om_load_helper); + ModelHelper(const ModelHelper &); + ModelHelper &operator=(const ModelHelper &); + Status GenerateGeModel(OmFileLoadHelper &om_load_helper); + Status GenerateGeRootModel(OmFileLoadHelper &om_load_helper); + Status LoadModelData(OmFileLoadHelper &om_load_helper); + void SetModelToGeModel(ge::Model &model); + Status LoadModelData(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index); + Status LoadWeights(OmFileLoadHelper &om_load_helper); + Status LoadWeights(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index); + Status LoadTask(OmFileLoadHelper &om_load_helper); + Status LoadTask(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index); + Status LoadTBEKernelStore(OmFileLoadHelper &om_load_helper); + Status LoadTBEKernelStore(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index); + Status LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper); + Status LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index); Status ReleaseLocalModelData() noexcept; - Status SaveModelPartition(std::shared_ptr& om_file_save_helper, - ModelPartitionType type, const uint8_t* data, size_t size); + Status SaveModelPartition(std::shared_ptr &om_file_save_helper, ModelPartitionType type, + const uint8_t *data, size_t size, size_t model_index); + Status SaveModelDef(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + Buffer &model_buffer, size_t model_index = 0); + Status SaveModelWeights(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + size_t model_index = 0); + Status SaveModelTbeKernel(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + size_t model_index = 0); + Status SaveModelCustAICPU(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + size_t model_index = 0); + Status SaveModelTaskDef(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + Buffer &task_buffer, size_t model_index = 0); + Status SaveModelHeader(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + size_t model_num = 1); + Status SaveAllModelPartiton(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + Buffer &model_buffer, Buffer &task_buffer, size_t model_index = 0); }; } // namespace ge #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ diff --git a/inc/framework/common/helper/om_file_helper.h b/inc/framework/common/helper/om_file_helper.h index 4ca54b50..98ad55d7 100644 --- a/inc/framework/common/helper/om_file_helper.h +++ b/inc/framework/common/helper/om_file_helper.h @@ -32,14 +32,14 @@ using std::vector; namespace ge { struct ModelPartition { ModelPartitionType type; - uint8_t* data = 0; + uint8_t *data = 0; uint32_t size = 0; }; struct OmFileContext { std::vector partition_datas_; std::vector partition_table_; - uint32_t model_data_len_; + uint32_t model_data_len_ = 0; }; struct SaveParam { @@ -57,15 +57,23 @@ class OmFileLoadHelper { Status Init(uint8_t *model_data, const uint32_t model_data_size); + Status Init(uint8_t *model_data, const uint32_t model_data_size, uint32_t model_num); + Status GetModelPartition(ModelPartitionType type, ModelPartition &partition); + Status GetModelPartition(ModelPartitionType type, ModelPartition &partition, size_t model_index); + OmFileContext context_; + vector model_contexts_; + private: Status CheckModelValid(const ge::ModelData &model) const; Status LoadModelPartitionTable(uint8_t *model_data, const uint32_t model_data_size); + Status LoadModelPartitionTable(uint8_t *model_data, const uint32_t model_data_size, uint32_t model_num); + bool is_inited_{false}; }; @@ -79,15 +87,23 @@ class OmFileSaveHelper { Status AddPartition(ModelPartition &partition); + Status AddPartition(ModelPartition &partition, size_t cur_index); + const std::vector &GetModelPartitions() const; - Status SaveModel(const SaveParam &save_param, const char *target_file, - ge::ModelBufferData& model, bool is_offline = true); + Status SaveModel(const SaveParam &save_param, const char *target_file, ge::ModelBufferData &model, + bool is_offline = true); Status SaveModelToFile(const char *output_file, ge::ModelBufferData &model, bool is_offline = true); + vector model_contexts_; + ModelFileHeader model_header_; OmFileContext context_; + + ModelPartitionTable *GetPartitionTable(size_t cur_ctx_index); + + Status SaveRootModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model, bool is_offline); }; } // namespace ge #endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ diff --git a/inc/framework/common/op/ge_op_utils.h b/inc/framework/common/op/ge_op_utils.h index 4718b180..5c97b4c0 100644 --- a/inc/framework/common/op/ge_op_utils.h +++ b/inc/framework/common/op/ge_op_utils.h @@ -17,7 +17,6 @@ #ifndef INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ #define INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ -#include #include #include @@ -32,7 +31,6 @@ #include "proto/insert_op.pb.h" namespace ge { -using namespace cce; using domi::Status; // Add Sub Mul @@ -76,18 +74,7 @@ class OpUtils { static inline bool CheckEnumValid(int32_t check_value, int32_t min_enum_value, int32_t max_enum_value) { return check_value < min_enum_value ? false : (check_value >= max_enum_value ? false : true); } - /// - /// @ingroup domi_omg - /// @brief Convert the dimension of array according to different format - /// @param [in] src_format src_shape format - /// @param [in] src Dimension array to be converted - /// @param [in] dst_format Target format after conversion - /// @param [out] dst Dimension array after conversion - /// @return SUCCESS success - /// @return FAILED fail - /// - static bool ConvertDim(ccTensorFormat_t src_format, const std::vector &src, ccTensorFormat_t dst_format, - std::vector &dst); + /// /// @ingroup domi_omg /// @brief Determine whether to manually calculate the tensor size based on the values of format and dim @@ -97,73 +84,6 @@ class OpUtils { /// @return false skip /// static bool IsComputDimsSize(const int32_t format, const uint32_t real_dim_cnt); - /// - /// @ingroup domi_ome - /// @brief Initialize the tensor description, which is used for input and output. - /// @param [in] model_tensor Tensor information defined by the offline model - /// @param [out] cc_tensor Tensor definition used by CC - /// @return SUCCESS success - /// @return FAILED fail - /// - static Status InitTensorDescriptor(const ge::GeTensorDesc &model_tensor, ccTensorDescriptor_t &cc_tensor); - /// - /// @ingroup domi_ome - /// @brief Initialize the tensor description, which is used for input and output. - /// @param [in] model_tensor Tensor information defined by the offline model - /// @param [in] dst_data_type data_type of the target cc_tensor - /// @param [out] cc_tensor Tensor definition used by CC - /// @return SUCCESS success - /// @return FAILED fail - /// - static Status InitTensorDescriptor(const ge::GeTensorDesc &model_tensor, int32_t dst_data_type, - ccTensorDescriptor_t &cc_tensor); - /// - /// @ingroup domi_ome - /// @brief Initialize the tensor description for bias. - /// @param [in] model_tensor Tensor information defined by the offline model - /// @param [out] cc_tensor Tensor definition used by CC - /// @return SUCCESS success - /// @return FAILED fail - /// - /// - static Status InitTensorDescriptor(const ge::GeTensor &model_tensor, ccTensorDescriptor_t &cc_tensor); - /// - /// @ingroup domi_ome - /// @brief Initialize the tensor description for bias. - /// @param [in] model_tensor Tensor information defined by the offline model - /// @param [in] dst_data_type data_type of the target cc_tensor - /// @param [out] cc_tensor Tensor definition used by CC - /// @return SUCCESS success - /// @return FAILED fail - /// - static Status InitTensorDescriptor(const ge::GeTensor &model_tensor, int32_t dst_data_type, - ccTensorDescriptor_t &cc_tensor); - - static Status InitTensorDescriptor(int32_t format, int32_t data_type, const std::vector &dim, - ccTensorDescriptor_t &cc_tensor, uint32_t real_dim_cnt = 4); - /// - /// @ingroup domi_ome - /// @brief Destroys a tensor - /// @param [inout] cc_tensor Tensor definition used by CC - /// - static void DestroyTensorDescriptor(ccTensorDescriptor_t &cc_tensor) noexcept; - - /// - /// @ingroup domi_ome - /// @brief Destroys a tensor - /// @param [inout] cc_filter cc_filter Definition of the filter used by CC - /// - static void DestroyFilterDescriptor(ccFilterDescriptor_t &cc_filter); - - /// - /// @ingroup domi_ome - /// @brief Initializing Filter Description - /// @param [in] model_filter Filter information defined in the offline model - /// @param [out] cc_filter Definition of the filter used by CC - /// @return SUCCESS success - /// @return FAILED fail - /// - static Status InitFilterDescriptor(const ge::GeTensor &model_filter, ccFilterDescriptor_t &cc_filter); /// /// @brief Extract AIPP parameters from AttrDefMap and splice them @@ -209,16 +129,7 @@ class OpUtils { /// @param [out] output Data pointer after conversion. The format is HWCK /// static void TransDataKCHW2HWCK(const void *input, int64_t K, int64_t C, int64_t H, int64_t W, void *output); - /// - /// @ingroup domi_omg - /// @brief Initialize the input and output description of the data node which is applied to filter weight in the - /// training network - /// @param [in] model_tensor input and output tensor information - /// @param [out] cc_tensor Tensor in CCE format after conversion - /// - static Status InitFilterTensorDescriptor(const ge::GeTensorDesc &model_tensor, ccFilterDescriptor_t &cc_tensor); - static void SetTensorDescriptorAllOffsetQuantizeInfo(const GeTensorDesc &tensor, ccTensorDescriptor_t cc_tensor); static vector GetWeights(const ge::Node &node); static vector GetWeights(ge::ConstNodePtr node); static vector MutableWeights(const ge::Node &node); @@ -228,69 +139,7 @@ class OpUtils { static Status GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, DataType type, std::vector &dims); private: - friend class CceTensorDescriptor; static uint32_t GetRealDimCnt(const GeTensorDesc &tensor_desc); }; - -class CceTensorDescriptor; - -using CceTensorDescriptorPtr = std::shared_ptr; - -class CceTensorDescriptor { - public: - explicit CceTensorDescriptor(ccTensorDescriptor_t cc_tensor); - CceTensorDescriptor(const CceTensorDescriptor &) = delete; - CceTensorDescriptor &operator=(const CceTensorDescriptor &) = delete; - - ~CceTensorDescriptor(); - - ccTensorDescriptor_t GetPtr() { return cc_tensor_; } - - /// - /// @brief Initializes the tensor based on shape information. - /// @param[in] format data permutation format - /// @param[in] data_type Data Type - /// @param[in] dim dim information - /// @return return code - /// - Status InitTensor(int32_t format, int32_t data_type, const std::vector &dims); - - Status InitTensor(int32_t format, int32_t data_type, const ge::GeShape &shape); - - /// - /// @brief get format of tensor - /// @param[out] format format of the tensor - /// @return return code - /// - Status GetFormat(ccTensorFormat_t *format); - - /// - /// @brief Obtains the size of the tensor. - /// @param[out] size size of Tensor - /// @return return code - /// - Status GetTensorSizeInBytes(uint32_t *size); - - /// - /// @brief transform tensor between 4d(NCHW) and 5d(NC1HWC0) - /// @param [in] xDesc descriptor of input tensor - /// @param [in] x point to input data in host memory - /// @param [in] dataTypeTransmode mode of data type transform - /// @param [in] yDesc descriptor of output tensor - /// @param [in|out] y point to output data in host memory - /// @param [in] ySizeInBytes size of outputData - /// @return return code - /// - static Status TransTensor(const ccTensorDescriptor_t xDesc, const void *x, const CceTensorDescriptorPtr &yDesc, - void *y, uint32_t ySizeInBytes); - - /// - /// @brief CceTensorDescriptor Static Constructor - /// @return CceTensorDescriptor smart pointer - /// - static CceTensorDescriptorPtr Create(); - - ccTensorDescriptor_t cc_tensor_ = nullptr; -}; } // namespace ge #endif // INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ diff --git a/inc/framework/common/op/op_parser_util.h b/inc/framework/common/op/op_parser_util.h index 49b4350a..43254ca9 100644 --- a/inc/framework/common/op/op_parser_util.h +++ b/inc/framework/common/op/op_parser_util.h @@ -17,7 +17,6 @@ #ifndef INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ #define INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ -#include #include #include #include @@ -31,10 +30,7 @@ const uint32_t NORMAL_OUTPUT_NUM = 1; const uint32_t NORMAL_WORKSPACE_NUM = 0; const int32_t NORMAL_1D_DIM_NUM = 1; const int32_t NORMAL_SCALE_DIM_NUM = 0; -const int NORMAL_TENSOR_FORMAT = static_cast(cce::CC_TENSOR_NC1HWC0); const int NORMAL_TENSOR_SIZE = 4; -const int NORMAL_DEVICE_DATA_TYPE = static_cast(cce::CC_DATA_HALF); -const int DEFAULT_POOLING_MODE = static_cast(cce::CC_POOLING_MAX); const uint32_t DEFAULT_REAL_DIM_CNT = 4; // const @@ -183,7 +179,6 @@ const int32_t SSD_DETECTIONOUTPUT_BACKGROUND_LABEL_ID_DEFAULT_VALUE = 0; const float SSD_DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3; const int32_t SSD_DETECTIONOUTPUT_TOP_K_DEFAULT_VALUE = 200; const float SSD_DETECTIONOUTPUT_ETA_DEFAULT_VALUE = 1.0; -const int SSD_DETECTIONOUTPUT_CODE_TYPE_DEFAULT_VALUE = static_cast(cce::CC_BOX_CENTER_SIZE); const int32_t SSD_DETECTIONOUTPUT_KEEP_TOP_K_DEFAULT_VALUE = 200; const bool SSD_DETECTIONOUTPUT_VARIANCE_ENCODED_IN_TARGET_DEFAULT_VALUE = false; const float SSD_DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.1; @@ -200,7 +195,6 @@ const float REFINEDET_DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3; const int32_t REFINEDET_DETECTIONOUTPUT_TOP_K_DEFAULT_VALUE = 200; const float REFINEDET_DETECTIONOUTPUT_ETA_DEFAULT_VALUE = 1.0; const bool REFINEDET_DETECTIONOUTPUT_VARIANCE_ENCODED_IN_TARGET_DEFAULT_VALUE = false; -const int REFINEDET_DETECTIONOUTPUT_CODE_TYPE_DEFAULT_VALUE = static_cast(cce::CC_BOX_CENTER_SIZE); const int32_t REFINEDET_DETECTIONOUTPUT_KEEP_TOP_K_DEFAULT_VALUE = 200; const float REFINEDET_DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.1; const float REFINEDET_DETECTIONOUTPUT_OBJECTNESS_SCORE_DEFAULT_VALUE = 0; diff --git a/inc/framework/common/profiling/ge_profiling.h b/inc/framework/common/profiling/ge_profiling.h new file mode 100644 index 00000000..c51f837f --- /dev/null +++ b/inc/framework/common/profiling/ge_profiling.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_GE_PROFILING_H_ +#define INC_FRAMEWORK_COMMON_GE_PROFILING_H_ + +#include "ge/ge_api_error_codes.h" +#include "toolchain/prof_callback.h" + +#define MAX_DEV_NUM (64) +enum ProfCommandHandleType { + kProfCommandhandleInit = 0, + kProfCommandhandleStart, + kProfCommandhandleStop, + kProfCommandhandleFinalize, + kProfCommandhandleModelSubscribe, + kProfCommandhandleModelUnsubscribe +}; + +struct ProfCommandHandleData { + uint64_t profSwitch; + uint32_t devNums; // length of device id list + uint32_t devIdList[MAX_DEV_NUM]; + uint32_t modelId; +}; + +ge::Status RegProfCtrlCallback(MsprofCtrlCallback func); +ge::Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func); +ge::Status RegProfReporterCallback(MsprofReporterCallback func); +ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len); + +#endif // INC_FRAMEWORK_COMMON_GE_PROFILING_H_ diff --git a/inc/framework/common/profiling/ge_runner_profiling.h b/inc/framework/common/profiling/ge_runner_profiling.h new file mode 100644 index 00000000..d2eff767 --- /dev/null +++ b/inc/framework/common/profiling/ge_runner_profiling.h @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_GE_RUNNER_PROFILING_H_ +#define INC_FRAMEWORK_COMMON_GE_RUNNER_PROFILING_H_ + +#include "profiling/ge_profiling.h" + +bool IsInitialize(); + +#endif // INC_FRAMEWORK_COMMON_GE_RUNNER_PROFILING_H_ diff --git a/inc/framework/common/taskdown_common.h b/inc/framework/common/taskdown_common.h new file mode 100644 index 00000000..090e7e26 --- /dev/null +++ b/inc/framework/common/taskdown_common.h @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_ +#define INC_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_ + +#include "runtime/rt.h" + +namespace ge { + +const int CC_FUSION_OP_MAX = 32; + +typedef enum tagCcStatus { + CC_STATUS_SUCCESS = 0, /**< succ */ + CC_STATUS_NOT_INITIALIZED = 1, /**< not init */ + CC_STATUS_ALLOC_FAILED = 2, /**< alloc mem failed */ + CC_STATUS_BAD_PARAM = 3, /**< para check failed */ + CC_STATUS_INTERNAL_ERROR = 4, /**< internal error */ + CC_STATUS_KERNEL_ERROR = 5, /**< kernel error */ + CC_STATUS_RUNTIME_ERROR = 6, /**< runtime error */ + CC_STATUS_NOT_SUPPORTED = 7, /**< unsupport error */ + CC_STATUS_INVALID_VALUE = 7, /**< invalid value error for blas*/ + CC_STATUS_RESERVED /**< just for check */ +} ccStatus_t; + +typedef enum tagccKernelType { + CCE_AI_CORE = 0, /* cce aicore */ + CCE_AI_CPU = 1, /* cce aicpu */ + TE = 2, /* te operator*/ + CUSTOMIZED = 3, /* customized operator */ + TE_AI_CORE = 4, /* te aicore operator*/ + TE_AI_CPU = 5, /* te aicpu operator */ + AI_CPU = 6, /* aicpu */ + CUST_AI_CPU = 7, /* custom aicpu*/ + INVALID = 8, /* unknown kernel type */ +} ccKernelType; + +typedef struct tagOpContext { + ccKernelType kernelType; + uint32_t opId; + uint32_t kernelFuncId; + uint32_t opIndex; + uint32_t opCount; + uint32_t opIndex2[CC_FUSION_OP_MAX]; + bool isFlowtable; + uint16_t *argsOffset; + uint32_t argsCount; + uint64_t genDataBaseAddr; + uint64_t genDataBaseSize; + uint64_t genWeightBaseAddr; + uint64_t genWeightBaseSize; + uint64_t genVariableBaseAddr; + uint64_t genVariableBaseSize; + uint64_t l2ctrlSize; +} ccOpContext; +} // namespace ge + +#endif // INC_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_ diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h index 441d0757..99c2ea03 100644 --- a/inc/framework/common/types.h +++ b/inc/framework/common/types.h @@ -529,7 +529,7 @@ REGISTER_OPTYPE_DECLARE(HVDWAIT, "HorovodWait"); // aicpu op for online_infer dynamic_dims REGISTER_OPTYPE_DECLARE(GETDYNAMICDIMS, "GetDynamicDims"); -enum InputMode { INPUT = 0, CONST_INPUT}; +enum InputMode { INPUT = 0, CONST_INPUT }; // Definition of the processing status enum of the process module enum ModelProcessState { @@ -605,7 +605,7 @@ static constexpr uint32_t MODEL_FILE_CHECKSUM_LENGTH = 64; /// /// @brief length of the reserved field in the model file header /// -static constexpr uint32_t MODEL_FILE_RESERVED_LENGTH = 79; +static constexpr uint32_t MODEL_FILE_RESERVED_LENGTH = 75; /// /// @ingroup domi_omg @@ -843,9 +843,10 @@ struct ModelFileHeader { uint32_t ops = 0; // Computing power (Kops) uint8_t userdefineinfo[USER_DEFINE_INFO_LENGTH] = {0}; // User-defined information. The value contains 32 characters uint32_t om_ir_version = 0; + uint32_t model_num = 0; uint8_t platform_version[PLATFORM_VERSION_LEN] = {0}; uint8_t platform_type = {0}; - uint8_t reserved[MODEL_FILE_RESERVED_LENGTH] = {0}; // Reserved field 79 + uint8_t reserved[MODEL_FILE_RESERVED_LENGTH] = {0}; // Reserved field 75 }; static constexpr uint8_t TARGET_TYPE_LTTE_8BIT = 0; diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index c446b983..e0904965 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -74,11 +74,22 @@ class GeGenerator { /// @param [in] op_desc: the OP description. /// @param [in] inputs: input tensors. /// @param [in] outputs: output tensors. - /// @param [in] engine_type: specific engine. - /// @param [out] model_buff: model buff of single op. + /// @param [in] engine_type: engine type. + /// @param [out] model_buff: model buff of op. /// @return SUCCESS or FAILED Status BuildSingleOpModel(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, OpEngineType engine_type, ModelBufferData &model_buff); + /// + /// @ingroup ge + /// @brief: Build single Op into model buff. + /// @param [in] op_desc: the OP description. + /// @param [in] inputs: input tensors. + /// @param [in] outputs: output tensors. + /// @param [in] graph_name: graph name. + /// @param [out] graph: graph of single op. + /// @return SUCCESS or FAILED + Status BuildSingleOpGraph(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, + std::string graph_name, Graph &graph); private: Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector &inputs, diff --git a/metadef b/metadef deleted file mode 160000 index 4176fab0..00000000 --- a/metadef +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 4176fab0cb2fd4f8794061916878983afb75c8da diff --git a/metadef/CMakeLists.txt b/metadef/CMakeLists.txt new file mode 100755 index 00000000..5ad573a6 --- /dev/null +++ b/metadef/CMakeLists.txt @@ -0,0 +1,59 @@ +cmake_minimum_required(VERSION 3.14) +project (MetaDef[CXX]) + +set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}) + +if (DEFINED ENV{D_PKG_SERVER}) + set(METADEF_PB_PKG $ENV{D_PKG_SERVER}) +endif() + +option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) + +if (ENABLE_OPEN_SRC) + set(HI_PYTHON python3) + + include(cmake/external_libs/protobuf_shared.cmake) + include(cmake/external_libs/protoc.cmake) + include(cmake/external_libs/securec.cmake) + include(cmake/external_libs/json.cmake) + include(cmake/FindModule.cmake) + include(cmake/intf_pub_linux.cmake) + + if(DEFINED ENV{D_LINK_PATH}) + # D_LINK_PATH is set + set(GE_LIB_PATH $ENV{D_LINK_PATH}) + set(GE_SYS_ARCH "") + if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") + # x86 ubuntu + set(GE_SYS_ARCH "x86_64") + elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") + # arm euleros + set(GE_SYS_ARCH "aarch64") + else() + message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") + endif() + set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) + find_module(slog libslog.so ${GE_LIB_PATH}) + find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) + find_module(error_manager liberror_manager.so ${GE_LIB_PATH}) + else() + if(DEFINED ENV{ASCEND_CUSTOM_PATH}) + set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) + else() + set(ASCEND_DIR /usr/local/Ascend) + endif() + + set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) + + find_module(slog libslog.so ${ASCEND_ATC_DIR}) + find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) + find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) + endif() + +endif() + +add_subdirectory(graph) +if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) + add_subdirectory(register) +endif () + diff --git a/metadef/LICENSE b/metadef/LICENSE new file mode 100644 index 00000000..29f81d81 --- /dev/null +++ b/metadef/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/metadef/NOTICE b/metadef/NOTICE new file mode 100644 index 00000000..9450f845 --- /dev/null +++ b/metadef/NOTICE @@ -0,0 +1,2 @@ +Ascend Metadata Definition +Copyright 2020 Huawei Technologies Co., Ltd diff --git a/metadef/OWNERS b/metadef/OWNERS new file mode 100644 index 00000000..8c42c84f --- /dev/null +++ b/metadef/OWNERS @@ -0,0 +1,7 @@ +approvers: +- ji_chen +- wqtshg +- ljl0711 +reviewers: +- xchu42 +- sheng-nan diff --git a/metadef/README.en.md b/metadef/README.en.md new file mode 100644 index 00000000..e65dd29c --- /dev/null +++ b/metadef/README.en.md @@ -0,0 +1,9 @@ +# metadef + +## Introduction + +Ascend metadata definition. + +## License + +[Apache License 2.0](LICENSE) diff --git a/metadef/README.md b/metadef/README.md new file mode 100644 index 00000000..276936d0 --- /dev/null +++ b/metadef/README.md @@ -0,0 +1,9 @@ +# metadef + +## 介绍 + +昇腾元数据定义 + +## 许可证 + +[Apache License 2.0](LICENSE) diff --git a/metadef/build.sh b/metadef/build.sh new file mode 100644 index 00000000..d35d67d9 --- /dev/null +++ b/metadef/build.sh @@ -0,0 +1,210 @@ +#!/bin/bash +# Copyright 2019-2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +set -e +BASEPATH=$(cd "$(dirname $0)"; pwd) +OUTPUT_PATH="${BASEPATH}/output" +export BUILD_PATH="${BASEPATH}/build/" + +# print usage message +usage() +{ + echo "Usage:" + echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c] [-S on|off]" + echo "" + echo "Options:" + echo " -h Print usage" + echo " -u Only compile ut, not execute" + echo " -s Build st" + echo " -j[n] Set the number of threads used for building Metadef, default is 8" + echo " -t Build and execute ut" + echo " -c Build ut with coverage tag" + echo " -v Display build command" + echo " -S Enable enable download cmake compile dependency from gitee , default off" + echo "to be continued ..." +} + +# check value of input is 'on' or 'off' +# usage: check_on_off arg_value arg_name +check_on_off() +{ + if [[ "X$1" != "Xon" && "X$1" != "Xoff" ]]; then + echo "Invalid value $1 for option -$2" + usage + exit 1 + fi +} + +# parse and set options +checkopts() +{ + VERBOSE="" + THREAD_NUM=8 + # ENABLE_METADEF_UT_ONLY_COMPILE="off" + ENABLE_GE_UT="off" + ENABLE_GE_ST="off" + ENABLE_GE_COV="off" + GE_ONLY="on" + ENABLE_GITEE="off" + # Process the options + while getopts 'ustchj:vS:' opt + do + OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') + case "${opt}" in + u) + # ENABLE_GE_UT_ONLY_COMPILE="on" + ENABLE_GE_UT="on" + GE_ONLY="off" + ;; + s) + ENABLE_GE_ST="on" + ;; + t) + ENABLE_GE_UT="on" + GE_ONLY="off" + ;; + c) + ENABLE_GE_COV="on" + GE_ONLY="off" + ;; + h) + usage + exit 0 + ;; + j) + THREAD_NUM=$OPTARG + ;; + v) + VERBOSE="VERBOSE=1" + ;; + S) + check_on_off $OPTARG S + ENABLE_GITEE="$OPTARG" + echo "enable download from gitee" + ;; + *) + echo "Undefined option: ${opt}" + usage + exit 1 + esac + done +} +checkopts "$@" + +mk_dir() { + local create_dir="$1" # the target to make + + mkdir -pv "${create_dir}" + echo "created ${create_dir}" +} + +# Meatdef build start +echo "---------------- Metadef build start ----------------" + +# create build path +build_metadef() +{ + echo "create build directory and build Metadef"; + mk_dir "${BUILD_PATH}" + cd "${BUILD_PATH}" + CMAKE_ARGS="-DBUILD_PATH=$BUILD_PATH -DGE_ONLY=$GE_ONLY" + + if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_COV=ON" + fi + + if [[ "X$ENABLE_GE_UT" = "Xon" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_UT=ON" + fi + + + if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON" + fi + + if [[ "X$ENABLE_GITEE" = "Xon" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GITEE=ON" + fi + + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_OPEN_SRC=True -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH}" + echo "${CMAKE_ARGS}" + cmake ${CMAKE_ARGS} .. + if [ 0 -ne $? ] + then + echo "execute command: cmake ${CMAKE_ARGS} .. failed." + return 1 + fi + make ${VERBOSE} -j${THREAD_NUM} && make install + if [ 0 -ne $? ] + then + echo "execute command: make ${VERBOSE} -j${THREAD_NUM} && make install failed." + return 1 + fi + echo "Metadef build success!" +} + +g++ -v +mk_dir ${OUTPUT_PATH} +build_metadef || { echo "Metadef build failed."; return; } +echo "---------------- Metadef build finished ----------------" +rm -f ${OUTPUT_PATH}/libgmock*.so +rm -f ${OUTPUT_PATH}/libgtest*.so +rm -f ${OUTPUT_PATH}/lib*_stub.so + +chmod -R 750 ${OUTPUT_PATH} +find ${OUTPUT_PATH} -name "*.so*" -print0 | xargs -0 chmod 500 + +echo "---------------- Metadef output generated ----------------" + +# generate output package in tar form, including ut/st libraries/executables +generate_package() +{ + cd "${BASEPATH}" + + METADEF_LIB_PATH="lib" + ACL_PATH="acllib/lib64" + FWK_PATH="fwkacllib/lib64" + ATC_PATH="atc/lib64" + + COMMON_LIB=("libgraph.so" "libregister.so") + + rm -rf ${OUTPUT_PATH:?}/${FWK_PATH}/ + rm -rf ${OUTPUT_PATH:?}/${ACL_PATH}/ + rm -rf ${OUTPUT_PATH:?}/${ATC_PATH}/ + + mk_dir "${OUTPUT_PATH}/${FWK_PATH}" + mk_dir "${OUTPUT_PATH}/${ATC_PATH}" + mk_dir "${OUTPUT_PATH}/${ACL_PATH}" + + find output/ -name metadef_lib.tar -exec rm {} \; + + cd "${OUTPUT_PATH}" + + for lib in "${COMMON_LIB[@]}"; + do + find ${OUTPUT_PATH}/${METADEF_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${FWK_PATH} \; + find ${OUTPUT_PATH}/${METADEF_LIB_PATH} -maxdepth 1 -name "$lib" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; + done + + find ${OUTPUT_PATH}/${METADEF_LIB_PATH} -maxdepth 1 -name "libc_sec.so" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; + + tar -cf metadef_lib.tar fwkacllib atc +} + +if [[ "X$ENABLE_GE_UT" = "Xoff" ]]; then + generate_package +fi +echo "---------------- Metadef package archive generated ----------------" diff --git a/metadef/cmake/FindModule.cmake b/metadef/cmake/FindModule.cmake new file mode 100644 index 00000000..11af1598 --- /dev/null +++ b/metadef/cmake/FindModule.cmake @@ -0,0 +1,22 @@ +#[[ + module - the name of export imported target + name - find the library name + path - find the library path +#]] +function(find_module module name path) + if (TARGET ${module}) + return() + endif() + find_library(${module}_LIBRARY_DIR NAMES ${name} NAMES_PER_DIR PATHS ${path} + PATH_SUFFIXES lib + ) + + message(STATUS "find ${name} location ${${module}_LIBRARY_DIR}") + if ("${${module}_LIBRARY_DIR}" STREQUAL "${module}_LIBRARY_DIR-NOTFOUND") + message(FATAL_ERROR "${name} not found in ${path}") + endif() + add_library(${module} SHARED IMPORTED) + set_target_properties(${module} PROPERTIES + IMPORTED_LOCATION ${${module}_LIBRARY_DIR} + ) +endfunction() diff --git a/metadef/cmake/external_libs/gflags.cmake b/metadef/cmake/external_libs/gflags.cmake new file mode 100755 index 00000000..34493e24 --- /dev/null +++ b/metadef/cmake/external_libs/gflags.cmake @@ -0,0 +1,47 @@ +if (HAVE_GFLAGS) + return() +endif() + +include(ExternalProject) + +if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR + (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) + set(CMAKE_INSTALL_PREFIX ${METADEF_DIR}/output CACHE STRING "path for install()" FORCE) + message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") +endif() + +if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") + set(MD5 "") +else() + set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") + set(MD5 "") +endif () + +set (gflags_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=ascend_private") +ExternalProject_Add(gflags_build + URL ${REQ_URL} + #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz + #SOURCE_DIR ${METADEF_DIR}/../../third_party/gflags/src/gflags-2.2.2 + CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gflags_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags + BUILD_COMMAND $(MAKE) + INSTALL_COMMAND $(MAKE) install + EXCLUDE_FROM_ALL TRUE +) + +set(GFLAGS_PKG_DIR ${CMAKE_INSTALL_PREFIX}/gflags) + +add_library(gflags_static STATIC IMPORTED) + +set_target_properties(gflags_static PROPERTIES + IMPORTED_LOCATION ${GFLAGS_PKG_DIR}/lib/libgflags.a +) + +add_library(gflags INTERFACE) +target_include_directories(gflags INTERFACE ${GFLAGS_PKG_DIR}/include) +target_link_libraries(gflags INTERFACE gflags_static) + +add_dependencies(gflags gflags_build) + +#set(HAVE_GFLAGS TRUE CACHE BOOL "gflags build add") +set(HAVE_GFLAGS TRUE) diff --git a/metadef/cmake/external_libs/json.cmake b/metadef/cmake/external_libs/json.cmake new file mode 100755 index 00000000..c4a52843 --- /dev/null +++ b/metadef/cmake/external_libs/json.cmake @@ -0,0 +1,33 @@ +if (HAVE_JSON) + return() +endif() + +include(ExternalProject) + +set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) +#if (ENABLE_GITEE) +# set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") +# set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") +# set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include") +#else() +set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") +set(MD5 "0dc903888211db3a0f170304cd9f3a89") +set(JSON_INCLUDE_DIR ${JSON_SRC_DIR}) +#endif () +ExternalProject_Add(json_build + URL ${REQ_URL} + #URL /home/txd/workspace/cloud_code/pkg/include.zip + SOURCE_DIR ${JSON_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + EXCLUDE_FROM_ALL TRUE +) + + +add_library(json INTERFACE) +target_include_directories(json INTERFACE ${JSON_INCLUDE_DIR}) +add_dependencies(json json_build) + +#set(HAVE_JSON TRUE CACHE BOOL "json build add") +set(HAVE_JSON TRUE) diff --git a/metadef/cmake/external_libs/onnx.cmake b/metadef/cmake/external_libs/onnx.cmake new file mode 100755 index 00000000..9dadb544 --- /dev/null +++ b/metadef/cmake/external_libs/onnx.cmake @@ -0,0 +1,37 @@ +include(ExternalProject) + +#set(ONNX_SRC_DIR /home/txd/workspace/cloud_code/graphengine/build/graphengine/open_source/onnx) +#set(ONNX_PROTO ${ONNX_SRC_DIR}/onnx/onnx.proto) +set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx) +set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) +file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) + +if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") + set(MD5 "1bdbcecdd68ea8392630467646776e02") +else() + set(REQ_URL "https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz") + set(MD5 "512f2779d6215d4a36f366b6b9acdf1e") +endif () + +ExternalProject_Add(onnx + URL ${REQ_URL} + #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz + #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 + #SOURCE_DIR ${ONNX_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + #INSTALL_COMMAND "" + INSTALL_COMMAND ${CMAKE_COMMAND} -E copy /onnx/onnx.proto ${ONNX_PROTO_FILE} + #BUILD_ALWAYS TRUE + EXCLUDE_FROM_ALL TRUE +) + +macro(onnx_protobuf_generate comp c_var h_var) + add_custom_command(OUTPUT ${ONNX_PROTO_FILE} + DEPENDS onnx + ) + ge_protobuf_generate(${comp} ${c_var} ${h_var} ${ONNX_PROTO_FILE}) +endmacro() + + diff --git a/metadef/cmake/external_libs/protobuf_shared.cmake b/metadef/cmake/external_libs/protobuf_shared.cmake new file mode 100755 index 00000000..d21c686c --- /dev/null +++ b/metadef/cmake/external_libs/protobuf_shared.cmake @@ -0,0 +1,68 @@ +if (HAVE_PROTOBUF) + return() +endif() + +include(ExternalProject) +include(GNUInstallDirs) + +if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR + (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) + set(CMAKE_INSTALL_PREFIX ${METADEF_DIR}/output CACHE STRING "path for install()" FORCE) + message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") +endif() +if (METADEF_PB_PKG) + set(REQ_URL "${METADEF_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") +else() + if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") + set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") + else() + set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") + set(MD5 "3d9e32700639618a4d2d342c99d4507a") + endif () +endif() + +set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2 -Dgoogle=ascend_private") +set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") +ExternalProject_Add(protobuf_build + URL ${REQ_URL} + CONFIGURE_COMMAND ${CMAKE_COMMAND} + -Dprotobuf_WITH_ZLIB=OFF + -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_LINKER=${CMAKE_LINKER} + -DCMAKE_AR=${CMAKE_AR} + -DCMAKE_RANLIB=${CMAKE_RANLIB} + -DLIB_PREFIX=ascend_ + -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protobuf /cmake + BUILD_COMMAND $(MAKE) + INSTALL_COMMAND $(MAKE) install + EXCLUDE_FROM_ALL TRUE +) +include(GNUInstallDirs) + +set(PROTOBUF_SHARED_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf) + +add_library(ascend_protobuf SHARED IMPORTED) + +file(MAKE_DIRECTORY ${PROTOBUF_SHARED_PKG_DIR}/include) + +set_target_properties(ascend_protobuf PROPERTIES + IMPORTED_LOCATION ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/libascend_protobuf.so +) + +target_include_directories(ascend_protobuf INTERFACE ${PROTOBUF_SHARED_PKG_DIR}/include) + +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so.3.8.0.0 OPTIONAL + DESTINATION ${INSTALL_LIBRARY_DIR}) +install(FILES ${PROTOBUF_SHARED_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/ascend_protobuf.so OPTIONAL + DESTINATION ${INSTALL_LIBRARY_DIR}) + +add_dependencies(ascend_protobuf protobuf_build) + +#set(HAVE_PROTOBUF TRUE CACHE BOOL "protobuf build add") +set(HAVE_PROTOBUF TRUE) diff --git a/metadef/cmake/external_libs/protobuf_static.cmake b/metadef/cmake/external_libs/protobuf_static.cmake new file mode 100755 index 00000000..8f0fc8a5 --- /dev/null +++ b/metadef/cmake/external_libs/protobuf_static.cmake @@ -0,0 +1,55 @@ +include(ExternalProject) +include(GNUInstallDirs) +#set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output) + +if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR + (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) + set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) + message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") +endif() + +if (METADEF_PB_PKG) + set(REQ_URL "${METADEF_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") +else() + if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") + set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") + else() + set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") + set(MD5 "3d9e32700639618a4d2d342c99d4507a") + endif () +endif() + +set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2 -Dgoogle=ascend_private") +set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") +set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) +ExternalProject_Add(protobuf_static_build + URL ${REQ_URL} + #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz + #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 + CONFIGURE_COMMAND ${CMAKE_COMMAND} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} + -DCMAKE_LINKER=${CMAKE_LINKER} + -DCMAKE_AR=${CMAKE_AR} + -DCMAKE_RANLIB=${CMAKE_RANLIB} + -Dprotobuf_WITH_ZLIB=OFF + -Dprotobuf_BUILD_TESTS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${PROTOBUF_STATIC_PKG_DIR} /cmake + BUILD_COMMAND $(MAKE) + INSTALL_COMMAND $(MAKE) install + EXCLUDE_FROM_ALL TRUE +) +include(GNUInstallDirs) + +add_library(ascend_protobuf_static_lib STATIC IMPORTED) + +set_target_properties(ascend_protobuf_static_lib PROPERTIES + IMPORTED_LOCATION ${PROTOBUF_STATIC_PKG_DIR}/${CMAKE_INSTALL_LIBDIR}/libascend_protobuf.a +) + +add_library(ascend_protobuf_static INTERFACE) +target_include_directories(ascend_protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include) +target_link_libraries(ascend_protobuf_static INTERFACE ascend_protobuf_static_lib) + +add_dependencies(ascend_protobuf_static protobuf_static_build) diff --git a/metadef/cmake/external_libs/protoc.cmake b/metadef/cmake/external_libs/protoc.cmake new file mode 100755 index 00000000..d38b3c93 --- /dev/null +++ b/metadef/cmake/external_libs/protoc.cmake @@ -0,0 +1,117 @@ +if (HAVE_PROTOC) + return() +endif() + +include(ExternalProject) +include(GNUInstallDirs) +#set(CMAKE_INSTALL_PREFIX ${METADEF_DIR}/output) + +if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR + (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) + set(CMAKE_INSTALL_PREFIX ${METADEF_DIR}/output CACHE STRING "path for install()" FORCE) + message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") +endif() + + +if (METADEF_PB_PKG) + set(REQ_URL "${METADEF_PB_PKG}/libs/protobuf/v3.8.0.tar.gz") +else() + if (ENABLE_GITEE) + set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") + set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") + else() + set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") + set(MD5 "3d9e32700639618a4d2d342c99d4507a") + endif() +endif() + +set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") +set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") +ExternalProject_Add(protoc_build + URL ${REQ_URL} + #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz + #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 + CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc /cmake + BUILD_COMMAND $(MAKE) + INSTALL_COMMAND $(MAKE) install + EXCLUDE_FROM_ALL TRUE +) + +set(PROTOC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protoc) + +set(protoc_EXECUTABLE ${PROTOC_PKG_DIR}/${CMAKE_INSTALL_BINDIR}/protoc) + +function(protobuf_generate comp c_var h_var) + if(NOT ARGN) + message(SEND_ERROR "Error: protobuf_generate() called without any proto files") + return() + endif() + set(${c_var}) + set(${h_var}) + + foreach(file ${ARGN}) + get_filename_component(abs_file ${file} ABSOLUTE) + get_filename_component(file_name ${file} NAME_WE) + get_filename_component(file_dir ${abs_file} PATH) + get_filename_component(parent_subdir ${file_dir} NAME) + + if("${parent_subdir}" STREQUAL "proto") + set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto) + else() + set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir}) + endif() + list(APPEND ${c_var} "${proto_output_path}/${file_name}.pb.cc") + list(APPEND ${h_var} "${proto_output_path}/${file_name}.pb.h") + + add_custom_command( + OUTPUT "${proto_output_path}/${file_name}.pb.cc" "${proto_output_path}/${file_name}.pb.h" + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}" + COMMAND ${protoc_EXECUTABLE} -I${file_dir} --cpp_out=${proto_output_path} ${abs_file} + DEPENDS protoc_build ${abs_file} + COMMENT "Running C++ protocol buffer compiler on ${file}" VERBATIM ) + endforeach() + + set_source_files_properties(${${c_var}} ${${h_var}} PROPERTIES GENERATED TRUE) + set(${c_var} ${${c_var}} PARENT_SCOPE) + set(${h_var} ${${h_var}} PARENT_SCOPE) + +endfunction() + +function(protobuf_generate_py comp py_var) + if(NOT ARGN) + message(SEND_ERROR "Error: protobuf_generate_py() called without any proto files") + return() + endif() + set(${py_var}) + + foreach(file ${ARGN}) + get_filename_component(abs_file ${file} ABSOLUTE) + get_filename_component(file_name ${file} NAME_WE) + get_filename_component(file_dir ${abs_file} PATH) + get_filename_component(parent_subdir ${file_dir} NAME) + + if("${parent_subdir}" STREQUAL "proto") + set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto) + else() + set(proto_output_path ${CMAKE_BINARY_DIR}/proto/${comp}/proto/${parent_subdir}) + endif() + list(APPEND ${py_var} "${proto_output_path}/${file_name}_pb2.py") + + add_custom_command( + OUTPUT "${proto_output_path}/${file_name}_pb2.py" + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + COMMAND ${CMAKE_COMMAND} -E make_directory "${proto_output_path}" + COMMAND ${protoc_EXECUTABLE} -I${file_dir} --python_out=${proto_output_path} ${abs_file} + DEPENDS protoc_build ${abs_file} + COMMENT "Running PYTHON protocol buffer compiler on ${file}" VERBATIM ) + endforeach() + + set_source_files_properties(${${py_var}} PROPERTIES GENERATED TRUE) + set(${py_var} ${${py_var}} PARENT_SCOPE) + +endfunction() + + +#set(HAVE_PROTOC TRUE CACHE BOOL "protoc build add") +set(HAVE_PROTOC TRUE) diff --git a/metadef/cmake/external_libs/securec.cmake b/metadef/cmake/external_libs/securec.cmake new file mode 100755 index 00000000..a5bbbf80 --- /dev/null +++ b/metadef/cmake/external_libs/securec.cmake @@ -0,0 +1,62 @@ +if (HAVE_C_SEC) + return() +endif() + +include(ExternalProject) + +if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR + (${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) + set(CMAKE_INSTALL_PREFIX ${METADEF_DIR}/output CACHE STRING "path for install()" FORCE) + message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") +endif() + +ExternalProject_Add(c_sec_build + URL https://gitee.com/openeuler/libboundscheck/repository/archive/v1.1.10.tar.gz + #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz + #SOURCE_DIR ${METADEF_DIR}/../libc_sec + PATCH_COMMAND patch -p1 < ${METADEF_DIR}/third_party/patch/securec/0001-add-securec-cmake-script.patch + CONFIGURE_COMMAND ${CMAKE_COMMAND} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_LINKER=${CMAKE_LINKER} + -DCMAKE_AR=${CMAKE_AR} + -DCMAKE_RANLIB=${CMAKE_RANLIB} + -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/c_sec + BUILD_COMMAND $(MAKE) + INSTALL_COMMAND $(MAKE) install + EXCLUDE_FROM_ALL TRUE +) + +set(C_SEC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/c_sec) + +add_library(c_sec SHARED IMPORTED) + +file(MAKE_DIRECTORY ${C_SEC_PKG_DIR}/include) + +set_target_properties(c_sec PROPERTIES + IMPORTED_LOCATION ${C_SEC_PKG_DIR}/lib/libc_sec.so +) + +target_include_directories(c_sec INTERFACE ${C_SEC_PKG_DIR}/include) + +add_dependencies(c_sec c_sec_build) + +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(FILES ${C_SEC_PKG_DIR}/lib/libc_sec.so OPTIONAL + DESTINATION ${INSTALL_LIBRARY_DIR}) + +add_library(c_sec_static_lib STATIC IMPORTED) +set_target_properties(c_sec_static_lib PROPERTIES + IMPORTED_LOCATION ${C_SEC_PKG_DIR}/lib/libc_sec.a +) + +add_library(c_sec_static INTERFACE) +target_include_directories(c_sec_static INTERFACE ${C_SEC_PKG_DIR}/include) +target_link_libraries(c_sec_static INTERFACE c_sec_static_lib) + +add_dependencies(c_sec_static c_sec_build) + +#set(HAVE_C_SEC TRUE CACHE BOOL "c_sec build add") +set(HAVE_C_SEC TRUE) diff --git a/metadef/cmake/intf_pub_android.cmake b/metadef/cmake/intf_pub_android.cmake new file mode 100755 index 00000000..153d5764 --- /dev/null +++ b/metadef/cmake/intf_pub_android.cmake @@ -0,0 +1,52 @@ + +add_library(intf_pub INTERFACE) + +target_compile_options(intf_pub INTERFACE + -Wall + -fPIC + -fstack-protector-strong +) +target_compile_definitions(intf_pub INTERFACE + $<$:_GLIBCXX_USE_CXX11_ABI=0> + $<$:CFG_BUILD_NDEBUG> + $<$:CFG_BUILD_DEBUG> + WIN64=1 + LINUX=0 +) +target_link_options(intf_pub INTERFACE + -Wl,-z,relro + -Wl,-z,now + -Wl,-z,noexecstack + $<$:-Wl,--build-id=none> +) +target_link_directories(intf_pub INTERFACE +) + +add_library(intf_ccec INTERFACE) +target_compile_options(intf_ccec INTERFACE + -mcpu=cortex-a73 + --target=aarch64-linux-android29 + --sysroot=${HCC_PATH}/../sysroot + -L${HCC_PATH}/../lib/gcc/aarch64-linux-android/4.9.x + -Wall + -fPIC + -fstack-protector-strong +) +target_compile_definitions(intf_ccec INTERFACE + $<$:_GLIBCXX_USE_CXX11_ABI=0> + $<$:CFG_BUILD_NDEBUG> + $<$:CFG_BUILD_DEBUG> +) + +target_link_options(intf_ccec INTERFACE + -mcpu=cortex-a73 + --target=aarch64-linux-android29 + --sysroot=${HCC_PATH}/../sysroot + -L${HCC_PATH}/../lib/gcc/aarch64-linux-android/4.9.x + -Wl,-cce-host-android + -Wl,-z,relro + -Wl,-z,now + -Wl,-z,noexecstack + $<$:-Wl,--build-id=none> +) + diff --git a/metadef/cmake/intf_pub_linux.cmake b/metadef/cmake/intf_pub_linux.cmake new file mode 100755 index 00000000..b8f346cf --- /dev/null +++ b/metadef/cmake/intf_pub_linux.cmake @@ -0,0 +1,33 @@ +if (HAVE_PUB) + return() +endif() + +add_library(intf_pub INTERFACE) + +target_compile_options(intf_pub INTERFACE + -Wall + -fPIC + $,-fstack-protector-all,-fstack-protector-strong> + $<$:-std=c++11> +) +target_compile_definitions(intf_pub INTERFACE + _GLIBCXX_USE_CXX11_ABI=0 + $<$:CFG_BUILD_NDEBUG> + $<$:CFG_BUILD_DEBUG> + WIN64=1 + LINUX=0 +) +target_link_options(intf_pub INTERFACE + -Wl,-z,relro + -Wl,-z,now + -Wl,-z,noexecstack + $<$:-Wl,--build-id=none> +) +target_link_directories(intf_pub INTERFACE +) +target_link_libraries(intf_pub INTERFACE + -lpthread +) + +#set(HAVE_PUB TRUE CACHE BOOL "pub add") +set(HAVE_PUB TRUE) diff --git a/metadef/cmake/intf_pub_windows.cmake b/metadef/cmake/intf_pub_windows.cmake new file mode 100755 index 00000000..19e37283 --- /dev/null +++ b/metadef/cmake/intf_pub_windows.cmake @@ -0,0 +1,24 @@ + +add_library(intf_pub INTERFACE) + +target_compile_options(intf_pub INTERFACE + -Wall + -fPIC + $,-fstack-protector-all,-fstack-protector-strong> + $<$:-std=c++11> +) +target_compile_definitions(intf_pub INTERFACE + $<$:_GLIBCXX_USE_CXX11_ABI=0> + OS_TYPE=WIN64 + WIN64=1 + LINUX=0 + $<$:CFG_BUILD_NDEBUG> + $<$:CFG_BUILD_DEBUG> +) +target_link_options(intf_pub INTERFACE + $<$:-Wl,--build-id=none> +) +target_link_directories(intf_pub INTERFACE +) +target_link_libraries(intf_pub INTERFACE +) diff --git a/metadef/graph/CMakeLists.txt b/metadef/graph/CMakeLists.txt new file mode 100644 index 00000000..a56becc6 --- /dev/null +++ b/metadef/graph/CMakeLists.txt @@ -0,0 +1,422 @@ +set(PROTO_LIST + "${METADEF_DIR}/proto/om.proto" + "${METADEF_DIR}/proto/ge_ir.proto" + "${METADEF_DIR}/proto/insert_op.proto" + "${METADEF_DIR}/proto/task.proto" + "${METADEF_DIR}/proto/dump_task.proto" + "${METADEF_DIR}/proto/fwk_adapter.proto" + "${METADEF_DIR}/proto/op_mapping_info.proto" + "${METADEF_DIR}/proto/proto_inner/ge_onnx.proto" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +set(SRC_LIST + "anchor.cc" + "ge_attr_value.cc" + "attr_value.cc" + "buffer.cc" + "compute_graph.cc" + "ascend_string.cc" + "gnode.cc" + "graph.cc" + "inference_context.cc" + "shape_refiner.cc" + "format_refiner.cc" + "ref_relation.cc" + "model.cc" + "model_serialize.cc" + "node.cc" + "op_desc.cc" + "operator.cc" + "operator_factory.cc" + "operator_factory_impl.cc" + "ge_attr_define.cc" + "ge_tensor.cc" + "detail/attributes_holder.cc" + "utils/anchor_utils.cc" + "utils/tuning_utils.cc" + "utils/graph_utils.cc" + "utils/ge_ir_utils.cc" + "utils/node_utils.cc" + "utils/op_desc_utils.cc" + "utils/type_utils.cc" + "utils/tensor_utils.cc" + "tensor.cc" + "debug/graph_debug.cc" + "opsproto/opsproto_manager.cc" + "../ops/op_imp.cpp" + "option/ge_context.cc" + "option/ge_local_context.cc" + "runtime_inference_context.cc" + "${METADEF_DIR}/third_party/transformer/src/axis_util.cpp" + "${METADEF_DIR}/third_party/transformer/src/transfer_shape_according_to_format.cpp" + "utils/transformer_utils.cc" +) + +if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) +######### libgraph.so ############# +add_library(graph SHARED ${SRC_LIST} ${PROTO_SRCS}) + +target_compile_options(graph PRIVATE + -O2 + $<$,$>:-fexceptions> + $<$,$>: -Wno-deprecated-declarations> +) + +target_compile_definitions(graph PRIVATE + $<$,$>:FMK_SUPPORT_DUMP> + google=ascend_private + $<$:ONLY_COMPILE_OPEN_SRC> +) + +target_include_directories(graph PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + ${CMAKE_BINARY_DIR}/proto/ge/proto + ${METADEF_DIR} + ${METADEF_DIR}/graph + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + #### yellow zone #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../ops/built-in/op_proto/inc + ${METADEF_DIR}/../cann/ops/built-in/op_proto/inc + ${METADEF_DIR}/../libc_sec/include + #### temp independent #### + ${METADEF_DIR}/../graphengine/inc + ${METADEF_DIR}/../graphengine/inc/framework + ${METADEF_DIR}/../graphengine/inc/external + #### temp in ge #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../inc/framework + ${METADEF_DIR}/../inc/external + ${METADEF_DIR}/../../ops/built-in/op_proto/inc + ${METADEF_DIR}/../../cann/ops/built-in/op_proto/inc + #### temp in ge #### + ${METADEF_DIR}/../../graphengine/inc + ${METADEF_DIR}/../../graphengine/inc/framework + ${METADEF_DIR}/../../graphengine/inc/external + ${METADEF_DIR}/../../inc + #### blue zone #### + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + ${METADEF_DIR}/../third_party/fwkacllib/inc/ops + ${METADEF_DIR}/../third_party/fwkacllib/inc + #### blue independent compile ##### + ${METADEF_DIR}/third_party/graphengine/inc + ${METADEF_DIR}/third_party/graphengine/ge/inc + ${METADEF_DIR}/third_party/graphengine/inc/external + ${METADEF_DIR}/third_party/fwkacllib/inc + ${METADEF_DIR}/third_party/fwkacllib/inc/ops + ${METADEF_DIR}/third_party + ${METADEF_DIR}/third_party/transformer/inc +) + +target_link_libraries(graph PRIVATE + $ + static_mmpa + -Wl,--no-as-needed + ascend_protobuf + c_sec + slog + error_manager + -Wl,--as-needed + $<$>:-lrt> + -ldl +) + +######### libgraph.a ############# +add_library(graph_static STATIC ${SRC_LIST} ${PROTO_SRCS}) + +target_compile_options(graph_static PRIVATE + $<$:-O2 -fPIC> + $<$,$>:-fexceptions> + $<$,$>: -Wno-deprecated-declarations> + $<$,$>:/MTd> + $<$,$>:/MT> +) + +target_compile_definitions(graph_static PRIVATE + $<$,$>:FMK_SUPPORT_DUMP> + google=ascend_private + $<$:ONLY_COMPILE_OPEN_SRC> + $,OS_TYPE=WIN,OS_TYPE=0> + $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> + LOG_CPP +) + +target_include_directories(graph_static PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + ${CMAKE_BINARY_DIR}/proto/ge/proto + ${METADEF_DIR} + ${METADEF_DIR}/graph + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + #### yellow zone #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../ops/built-in/op_proto/inc + ${METADEF_DIR}/../cann/ops/built-in/op_proto/inc + ${METADEF_DIR}/../libc_sec/include + #### temp independent #### + ${METADEF_DIR}/../graphengine/inc + ${METADEF_DIR}/../graphengine/inc/framework + ${METADEF_DIR}/../graphengine/inc/external + #### temp in ge #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../inc/framework + ${METADEF_DIR}/../inc/external + ${METADEF_DIR}/../../ops/built-in/op_proto/inc + ${METADEF_DIR}/../../cann/ops/built-in/op_proto/inc + #### temp in ge #### + ${METADEF_DIR}/../../graphengine/inc + ${METADEF_DIR}/../../graphengine/inc/framework + ${METADEF_DIR}/../../graphengine/inc/external + ${METADEF_DIR}/../../inc + #### blue zone #### + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + ${METADEF_DIR}/../third_party/fwkacllib/inc + ${METADEF_DIR}/../third_party/fwkacllib/inc/ops + #### blue independent compile ##### + ${METADEF_DIR}/third_party/graphengine/inc + ${METADEF_DIR}/third_party/graphengine/ge/inc + ${METADEF_DIR}/third_party/graphengine/inc/external + ${METADEF_DIR}/third_party/fwkacllib/inc + ${METADEF_DIR}/third_party/fwkacllib/inc/ops + ${METADEF_DIR}/third_party + ${METADEF_DIR}/third_party/transformer/inc +) + +target_link_libraries(graph_static PRIVATE + $ + ascend_protobuf + c_sec + $<$>:-lrt> + -ldl +) + +set_target_properties(graph_static PROPERTIES + WINDOWS_EXPORT_ALL_SYMBOLS TRUE + OUTPUT_NAME $,libgraph,graph> +) + + +############################################################## +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_attr_value.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_graph.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_operator.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_operator_factory.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_tensor.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_inference_context.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_ascend_string.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_gnode.cc + COMMAND echo "Generating stub files." + && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/stub/gen_stubapi.py ${METADEF_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} + && mv attr_value.cc stub_attr_value.cc + && mv graph.cc stub_graph.cc + && mv operator.cc stub_operator.cc + && mv operator_factory.cc stub_operator_factory.cc + && mv tensor.cc stub_tensor.cc + && mv inference_context.cc stub_inference_context.cc + && mv ascend_string.cc stub_ascend_string.cc + && mv gnode.cc stub_gnode.cc + && echo "Generating stub files end." +) + +add_custom_target(graph_stub + DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/stub_attr_value.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_graph.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_operator.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_operator_factory.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_tensor.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_inference_context.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_ascend_string.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_gnode.cc +) + +############################################################# + + +############ stub/libgraph.so ############ +add_library(atc_stub_graph SHARED + stub_graph.cc + stub_operator.cc + stub_operator_factory.cc + stub_tensor.cc + stub_attr_value.cc + stub_ascend_string.cc + stub_gnode.cc +) +add_dependencies(atc_stub_graph graph_stub) + +target_include_directories(atc_stub_graph PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${CMAKE_BINARY_DIR} + ${METADEF_DIR} + ${METADEF_DIR}/graph + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + #### yellow zone #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../ops/built-in/op_proto/inc + ${METADEF_DIR}/../cann/ops/built-in/op_proto/inc + ${METADEF_DIR}/../libc_sec/include + #### blue zone #### + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + #### temp independent #### + ${METADEF_DIR}/../graphengine/inc + ${METADEF_DIR}/../graphengine/inc/framework + ${METADEF_DIR}/../graphengine/inc/external + #### temp in ge #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../inc/framework + ${METADEF_DIR}/../inc/external + ${METADEF_DIR}/../../ops/built-in/op_proto/inc + ${METADEF_DIR}/../../cann/ops/built-in/op_proto/inc + ${METADEF_DIR}/third_party + ${METADEF_DIR}/third_party/transformer/inc +) + +target_link_libraries(atc_stub_graph PRIVATE + $ +) + +set_target_properties(atc_stub_graph PROPERTIES + OUTPUT_NAME graph + LIBRARY_OUTPUT_DIRECTORY atc_stub +) + +############ fwk_stub/libgraph.so ############ +add_library(fwk_stub_graph SHARED + stub_graph.cc + stub_operator.cc + stub_operator_factory.cc + stub_tensor.cc + stub_attr_value.cc + stub_inference_context.cc + stub_ascend_string.cc + stub_gnode.cc +) + +add_dependencies(fwk_stub_graph graph_stub) + +target_include_directories(fwk_stub_graph PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${CMAKE_BINARY_DIR} + ${METADEF_DIR} + ${METADEF_DIR}/graph + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + #### yellow zone #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../ops/built-in/op_proto/inc + ${METADEF_DIR}/../cann/ops/built-in/op_proto/inc + ${METADEF_DIR}/../libc_sec/include + #### blue zone #### + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + #### temp independent #### + ${METADEF_DIR}/../graphengine/inc + ${METADEF_DIR}/../graphengine/inc/framework + ${METADEF_DIR}/../graphengine/inc/external + #### temp in ge #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../inc/framework + ${METADEF_DIR}/../inc/external + ${METADEF_DIR}/../../ops/built-in/op_proto/inc + ${METADEF_DIR}/../../cann/ops/built-in/op_proto/inc + ${METADEF_DIR}/third_party + ${METADEF_DIR}/third_party/transformer/inc +) + +target_link_libraries(fwk_stub_graph PRIVATE + $ +) + +set_target_properties(fwk_stub_graph PROPERTIES + OUTPUT_NAME graph + LIBRARY_OUTPUT_DIRECTORY fwk_stub +) + +else () +######### libgraph.so w/static protobuf ############# +add_library(graph SHARED ${SRC_LIST} ${PROTO_SRCS}) + +target_compile_options(graph PRIVATE + -O2 + $<$:-fexceptions> + $<$,$>: -Wno-deprecated-declarations> + ) + +target_compile_definitions(graph PRIVATE + $<$:FMK_SUPPORT_DUMP> + $<$:ONLY_COMPILE_OPEN_SRC> + google=ascend_private + ) + +target_include_directories(graph PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + ${CMAKE_BINARY_DIR}/proto/ge/proto + ${METADEF_DIR} + ${METADEF_DIR}/graph + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/../third_party/fwkacllib/inc/ops + ${METADEF_DIR}/../third_party/fwkacllib/inc + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../inc/framework + ${METADEF_DIR}/../inc/external + ${METADEF_DIR}/../../ops/built-in/op_proto/inc + ${METADEF_DIR}/../../cann/ops/built-in/op_proto/inc + ${METADEF_DIR}/third_party + ${METADEF_DIR}/third_party/transformer/inc + ) + +target_link_libraries(graph PRIVATE + $ + ascend_protobuf_static + static_mmpa + -Wl,--no-as-needed + c_sec + slog + error_manager + -Wl,--as-needed + -lrt + -ldl + ) +endif () + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS graph OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) +if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) +install(TARGETS atc_stub_graph OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/stub +) + +install(TARGETS fwk_stub_graph OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/fwk_stub +) +endif () diff --git a/metadef/graph/anchor.cc b/metadef/graph/anchor.cc new file mode 100644 index 00000000..997bd741 --- /dev/null +++ b/metadef/graph/anchor.cc @@ -0,0 +1,373 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/anchor.h" +#include +#include +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/node.h" + +namespace ge { +Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), idx_(idx) {} + +bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf(), type) == 0; } + +size_t Anchor::GetPeerAnchorsSize() const { + return peer_anchors_.size(); +} + +Anchor::Vistor Anchor::GetPeerAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + ret.push_back(anchor.lock()); + } + return Anchor::Vistor(shared_from_this(), ret); +} + +AnchorPtr Anchor::GetFirstPeerAnchor() const { + if (peer_anchors_.empty()) { + return nullptr; + } else { + return Anchor::DynamicAnchorCast(peer_anchors_.begin()->lock()); + } +} + +NodePtr Anchor::GetOwnerNode() const { return owner_node_.lock(); } + +void Anchor::UnlinkAll() noexcept { + if (!peer_anchors_.empty()) { + do { + auto peer_anchor_ptr = peer_anchors_.begin()->lock(); + if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { + GELOGW("unlink peer_anchor_ptr failed."); + } + } while (!peer_anchors_.empty()); + } +} + +graphStatus Anchor::Unlink(const AnchorPtr &peer) { + if (peer == nullptr) { + GELOGE(GRAPH_FAILED, "peer anchor is invalid."); + return GRAPH_FAILED; + } + auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr &an) { + auto anchor = an.lock(); + return peer->Equal(anchor); + }); + + GE_IF_BOOL_EXEC(it == peer_anchors_.end(), GELOGW("this anchor is not connected to peer"); return GRAPH_FAILED); + + auto it_peer = + std::find_if(peer->peer_anchors_.begin(), peer->peer_anchors_.end(), [this](const std::weak_ptr &an) { + auto anchor = an.lock(); + return Equal(anchor); + }); + + GE_CHK_BOOL_RET_STATUS(it_peer != peer->peer_anchors_.end(), GRAPH_FAILED, "peer is not connected to this anchor"); + + (void)peer_anchors_.erase(it); + (void)peer->peer_anchors_.erase(it_peer); + return GRAPH_SUCCESS; +} + +graphStatus Anchor::ReplacePeer(const AnchorPtr &old_peer, const AnchorPtr &first_peer, const AnchorPtr &second_peer) { + GE_CHK_BOOL_RET_STATUS(old_peer != nullptr, GRAPH_FAILED, "this old peer anchor is nullptr"); + GE_CHK_BOOL_RET_STATUS(first_peer != nullptr, GRAPH_FAILED, "this first peer anchor is nullptr"); + GE_CHK_BOOL_RET_STATUS(second_peer != nullptr, GRAPH_FAILED, "this second peer anchor is nullptr"); + auto this_it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [old_peer](const std::weak_ptr &an) { + auto anchor = an.lock(); + return old_peer->Equal(anchor); + }); + + GE_CHK_BOOL_RET_STATUS(this_it != peer_anchors_.end(), GRAPH_FAILED, "this anchor is not connected to old_peer"); + + auto old_it = std::find_if(old_peer->peer_anchors_.begin(), old_peer->peer_anchors_.end(), + [this](const std::weak_ptr &an) { + auto anchor = an.lock(); + return Equal(anchor); + }); + + GE_CHK_BOOL_RET_STATUS(old_it != old_peer->peer_anchors_.end(), GRAPH_FAILED, + "old_peer is not connected to this anchor"); + *this_it = first_peer; + first_peer->peer_anchors_.push_back(shared_from_this()); + *old_it = second_peer; + second_peer->peer_anchors_.push_back(old_peer); + return GRAPH_SUCCESS; +} + +bool Anchor::IsLinkedWith(const AnchorPtr &peer) { + auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr &an) { + auto anchor = an.lock(); + GE_CHK_BOOL_RET_STATUS(peer != nullptr, false, "this old peer anchor is nullptr"); + return peer->Equal(anchor); + }); + return (it != peer_anchors_.end()); +} + +int Anchor::GetIdx() const { return idx_; } + +void Anchor::SetIdx(int index) { idx_ = index; } + +DataAnchor::DataAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {} + +bool DataAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return Anchor::IsTypeOf(type); +} + +InDataAnchor::InDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {} + +OutDataAnchorPtr InDataAnchor::GetPeerOutAnchor() const { + if (peer_anchors_.empty()) { + return nullptr; + } else { + return Anchor::DynamicAnchorCast(peer_anchors_.begin()->lock()); + } +} + +graphStatus InDataAnchor::LinkFrom(const OutDataAnchorPtr &src) { + // InDataAnchor must be only linkfrom once + if (src == nullptr || !peer_anchors_.empty()) { + GELOGE(GRAPH_FAILED, "src anchor is invalid or the peerAnchors is not empty."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(src); + src->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +bool InDataAnchor::Equal(AnchorPtr anchor) const { + auto in_data_anchor = Anchor::DynamicAnchorCast(anchor); + if (in_data_anchor != nullptr) { + if (GetOwnerNode() == in_data_anchor->GetOwnerNode() && GetIdx() == in_data_anchor->GetIdx()) { + return true; + } + } + return false; +} + +bool InDataAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return DataAnchor::IsTypeOf(type); +} + +OutDataAnchor::OutDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {} + +OutDataAnchor::Vistor OutDataAnchor::GetPeerInDataAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (in_data_anchor != nullptr) { + ret.push_back(in_data_anchor); + } + } + return OutDataAnchor::Vistor(shared_from_this(), ret); +} + +uint32_t OutDataAnchor::GetPeerInDataNodesSize() const { + uint32_t out_nums = 0; + for (const auto &anchor : peer_anchors_) { + auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (in_data_anchor != nullptr && in_data_anchor->GetOwnerNode() != nullptr) { + out_nums++; + } + } + return out_nums; +} + +OutDataAnchor::Vistor OutDataAnchor::GetPeerInControlAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto in_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (in_control_anchor != nullptr) { + ret.push_back(in_control_anchor); + } + } + return OutDataAnchor::Vistor(shared_from_this(), ret); +} + +graphStatus OutDataAnchor::LinkTo(const InDataAnchorPtr &dest) { + if (dest == nullptr || !dest->peer_anchors_.empty()) { + GELOGE(GRAPH_FAILED, "dest anchor is invalid or the peerAnchors is not empty."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(dest); + dest->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +graphStatus OutDataAnchor::LinkTo(const InControlAnchorPtr &dest) { + if (dest == nullptr) { + GELOGE(GRAPH_FAILED, "dest anchor is invalid."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(dest); + dest->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +graphStatus OutControlAnchor::LinkTo(const InDataAnchorPtr &dest) { + if (dest == nullptr) { + GELOGE(GRAPH_FAILED, "dest anchor is invalid."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(dest); + dest->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +bool OutDataAnchor::Equal(AnchorPtr anchor) const { + CHECK_FALSE_EXEC(anchor != nullptr, return false); + auto out_data_anchor = Anchor::DynamicAnchorCast(anchor); + if (out_data_anchor != nullptr) { + if (GetOwnerNode() == out_data_anchor->GetOwnerNode() && GetIdx() == out_data_anchor->GetIdx()) { + return true; + } + } + return false; +} + +bool OutDataAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return DataAnchor::IsTypeOf(type); +} + +ControlAnchor::ControlAnchor(const NodePtr &owner_node) : Anchor(owner_node, -1) {} + +ControlAnchor::ControlAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {} + +bool ControlAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return Anchor::IsTypeOf(type); +} + +InControlAnchor::InControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} + +InControlAnchor::InControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {} + +InControlAnchor::Vistor InControlAnchor::GetPeerOutControlAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto out_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (out_control_anchor != nullptr) { + ret.push_back(out_control_anchor); + } + } + return InControlAnchor::Vistor(shared_from_this(), ret); +} + +InControlAnchor::Vistor InControlAnchor::GetPeerOutDataAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto out_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (out_data_anchor != nullptr) { + ret.push_back(out_data_anchor); + } + } + return InControlAnchor::Vistor(shared_from_this(), ret); +} + +graphStatus InControlAnchor::LinkFrom(const OutControlAnchorPtr &src) { + if (src == nullptr) { + GELOGE(GRAPH_FAILED, "src anchor is invalid."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(src); + src->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +bool InControlAnchor::Equal(AnchorPtr anchor) const { + CHECK_FALSE_EXEC(anchor != nullptr, return false); + auto in_control_anchor = Anchor::DynamicAnchorCast(anchor); + if (in_control_anchor != nullptr) { + if (GetOwnerNode() == in_control_anchor->GetOwnerNode()) { + return true; + } + } + return false; +} + +bool InControlAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return ControlAnchor::IsTypeOf(type); +} + +OutControlAnchor::OutControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} + +OutControlAnchor::OutControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {} + +OutControlAnchor::Vistor OutControlAnchor::GetPeerInControlAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto in_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (in_control_anchor != nullptr) { + ret.push_back(in_control_anchor); + } + } + return OutControlAnchor::Vistor(shared_from_this(), ret); +} + +OutControlAnchor::Vistor OutControlAnchor::GetPeerInDataAnchors() const { + vector ret; + for (const auto &anchor : peer_anchors_) { + auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); + if (in_data_anchor != nullptr) { + ret.push_back(in_data_anchor); + } + } + return OutControlAnchor::Vistor(shared_from_this(), ret); +} + +graphStatus OutControlAnchor::LinkTo(const InControlAnchorPtr &dest) { + if (dest == nullptr) { + GELOGE(GRAPH_FAILED, "dest anchor is invalid."); + return GRAPH_FAILED; + } + peer_anchors_.push_back(dest); + dest->peer_anchors_.push_back(shared_from_this()); + return GRAPH_SUCCESS; +} + +bool OutControlAnchor::Equal(AnchorPtr anchor) const { + auto out_control_anchor = Anchor::DynamicAnchorCast(anchor); + if (out_control_anchor != nullptr) { + if (GetOwnerNode() == out_control_anchor->GetOwnerNode()) { + return true; + } + } + return false; +} + +bool OutControlAnchor::IsTypeOf(TYPE type) const { + if (strcmp(Anchor::TypeOf(), type) == 0) { + return true; + } + return ControlAnchor::IsTypeOf(type); +} +} // namespace ge diff --git a/metadef/graph/ascend_string.cc b/metadef/graph/ascend_string.cc new file mode 100644 index 00000000..21d46a8c --- /dev/null +++ b/metadef/graph/ascend_string.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "external/graph/ascend_string.h" +#include "debug/ge_log.h" + +namespace ge { +AscendString::AscendString(const char* name) { + if (name != nullptr) { + name_ = std::shared_ptr(new (std::nothrow) std::string(name)); //lint !e1524 + if (name_ == nullptr) { + GELOGE(FAILED, "AscendString[%s] make shared failed.", name); + } + } +} + +const char* AscendString::GetString() const { + if (name_ == nullptr) { + return nullptr; + } + + return (*name_).c_str(); +} + +bool AscendString::operator<(const AscendString& d) const { + if (name_ == nullptr && d.name_ == nullptr) { + return false; + } else if (name_ == nullptr) { + return true; + } else if (d.name_ == nullptr) { + return false; + } + return (*name_ < *(d.name_)); +} + +bool AscendString::operator>(const AscendString& d) const { + if (name_ == nullptr && d.name_ == nullptr) { + return false; + } else if (name_ == nullptr) { + return false; + } else if (d.name_ == nullptr) { + return true; + } + return(*name_ > *(d.name_)); +} + +bool AscendString::operator==(const AscendString& d) const { + if (name_ == nullptr && d.name_ == nullptr) { + return true; + } else if (name_ == nullptr) { + return false; + } else if (d.name_ == nullptr) { + return false; + } + return (*name_ == *(d.name_)); +} + +bool AscendString::operator<=(const AscendString& d) const { + if (name_ == nullptr) { + return true; + } else if (d.name_ == nullptr) { + return false; + } + return (*name_ <= *(d.name_)); +} + +bool AscendString::operator>=(const AscendString& d) const { + if (d.name_ == nullptr) { + return true; + } else if (name_ == nullptr) { + return false; + } + return (*name_ >= *(d.name_)); +} + +bool AscendString::operator!=(const AscendString& d) const { + if (name_ == nullptr && d.name_ == nullptr) { + return false; + } else if (name_ == nullptr) { + return true; + } else if (d.name_ == nullptr) { + return true; + } + return (*name_ != *(d.name_)); +} +} // namespace ge diff --git a/metadef/graph/attr_value.cc b/metadef/graph/attr_value.cc new file mode 100644 index 00000000..6e5743ba --- /dev/null +++ b/metadef/graph/attr_value.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "external/graph/attr_value.h" +#include "debug/ge_log.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_attr_value.h" + +namespace ge { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue::AttrValue() { + impl = ComGraphMakeShared(); +} + +#define ATTR_VALUE_SET_GET_IMP(type) \ + graphStatus AttrValue::GetValue(type &val) const { \ + if (impl != nullptr) { \ + GELOGW("GetValue failed."); \ + return impl->geAttrValue_.GetValue(val); \ + } \ + return GRAPH_FAILED; \ + } + +ATTR_VALUE_SET_GET_IMP(AttrValue::STR) +ATTR_VALUE_SET_GET_IMP(AttrValue::INT) +ATTR_VALUE_SET_GET_IMP(AttrValue::FLOAT) + +graphStatus AttrValue::GetValue(AscendString &val) { + std::string val_get; + auto status = GetValue(val_get); + if (status != GRAPH_SUCCESS) { + return status; + } + val = AscendString(val_get.c_str()); + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/metadef/graph/buffer.cc b/metadef/graph/buffer.cc new file mode 100644 index 00000000..fd3e174d --- /dev/null +++ b/metadef/graph/buffer.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/buffer.h" +#include "proto/ge_ir.pb.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +Buffer::Buffer() { + data_.InitDefault(); + if (data_.GetProtoMsg()) { + buffer_ = data_.GetProtoMsg()->mutable_bt(); + } +} + +Buffer::Buffer(const Buffer &other) { + // Share data + data_ = other.data_; + buffer_ = other.buffer_; +} + +Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default + auto proto_msg = data_.GetProtoMsg(); + if (proto_msg != nullptr) { + try { + proto_msg->set_bt(std::string(buffer_size, default_val)); + buffer_ = proto_msg->mutable_bt(); + } catch (std::bad_alloc &e) { + GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size); + buffer_ = nullptr; + } + } +} + +Buffer Buffer::CopyFrom(const std::uint8_t *data, std::size_t buffer_size) { + Buffer buffer; + auto proto_msg = buffer.data_.GetProtoMsg(); + if (proto_msg != nullptr && data != nullptr) { + try { + proto_msg->set_bt(data, buffer_size); + buffer.buffer_ = proto_msg->mutable_bt(); + } catch (std::bad_alloc &e) { + GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size); + buffer.buffer_ = nullptr; + } + } + return buffer; +} + +Buffer::Buffer(const std::shared_ptr &proto_owner, proto::AttrDef *buffer) + : data_(proto_owner, buffer) { + if (data_.GetProtoMsg() != nullptr) { + buffer_ = data_.GetProtoMsg()->mutable_bt(); + } +} + +Buffer::Buffer(const std::shared_ptr &proto_owner, std::string *buffer) + : data_(proto_owner, nullptr) { + buffer_ = buffer; +} + +Buffer &Buffer::operator=(const Buffer &other) { + if (&other != this) { + // Share data + data_ = other.data_; + buffer_ = other.buffer_; + } + return *this; +} + +const std::uint8_t *Buffer::GetData() const { + if (buffer_ != nullptr) { + return (const std::uint8_t *)buffer_->data(); + } + return nullptr; +} + +std::uint8_t *Buffer::GetData() { + if (buffer_ != nullptr && !buffer_->empty()) { + // Avoid copy on write + (void)(*buffer_)[0]; + return reinterpret_cast(const_cast(buffer_->data())); + } + return nullptr; +} + +std::size_t Buffer::GetSize() const { + if (buffer_ != nullptr) { + return buffer_->size(); + } + return 0; +} + +void Buffer::ClearBuffer() { + if (buffer_ != nullptr) { + buffer_->clear(); + } +} +} // namespace ge diff --git a/metadef/graph/compute_graph.cc b/metadef/graph/compute_graph.cc new file mode 100644 index 00000000..e85fc1af --- /dev/null +++ b/metadef/graph/compute_graph.cc @@ -0,0 +1,1304 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/compute_graph.h" +#include +#include "./format_refiner.h" +#include "./ge_context.h" +#include "debug/ge_attr_define.h" +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" +#include "ge/ge_api_types.h" +#include "graph/shape_refiner.h" +#include "proto/ge_ir.pb.h" +#include "utils/ge_ir_utils.h" +#include "utils/graph_utils.h" +#include "utils/node_utils.h" +#include "utils/op_desc_utils.h" +#include "utils/string_utils.h" +#include "utils/tensor_utils.h" + +namespace ge { +namespace { +const size_t OUTPUT_PARAM_SIZE = 2; +const std::string alias_name_attr = "_aliasName"; +bool IsUseBFS() { + string run_mode; + const int base = 10; + if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { + if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) { + return true; + } + } else { + GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); + } + return false; +} +} // namespace + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) + : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { + attrs_.InitDefault(); +} + +ComputeGraph::~ComputeGraph() {} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { + return GetAllNodes().size(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetAllNodes() const { + std::vector> subgraphs; + return AllGraphNodes(subgraphs); +} + +ComputeGraph::Vistor ComputeGraph::AllGraphNodes(std::vector> &subgraphs) const { + std::vector all_nodes; + std::deque candidates; + + candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); + while (!candidates.empty()) { + NodePtr node = candidates.front(); + all_nodes.emplace_back(node); + candidates.pop_front(); + + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + + const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); + for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { + auto subgraph = GetSubgraph(*name_iter); + if (subgraph != nullptr) { + subgraphs.emplace_back(subgraph); + candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); + } + } + } + + return Vistor(shared_from_this(), all_nodes); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +ComputeGraph::Vistor ComputeGraph::GetNodes(bool is_unknown_shape) const { + if (is_unknown_shape) { + return GetDirectNode(); + } else { + return GetAllNodes(); + } +} + + +size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetDirectNode() const { + return Vistor(shared_from_this(), nodes_); +} + +ComputeGraph::Vistor ComputeGraph::GetInputNodes() const { + return Vistor(shared_from_this(), input_nodes_); +} + +ComputeGraph::Vistor ComputeGraph::GetOutputNodes() const { + std::vector result; + for (auto iter = output_nodes_info_.begin(); iter != output_nodes_info_.end(); ++iter) { + result.push_back(iter->first); + } + return Vistor(shared_from_this(), result); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const { + for (const auto &node : nodes_) { + if (node == nullptr) { + continue; + } + if (node->GetName() == name) { + return node; + } + std::vector out_alias_name; + if (AttrUtils::GetListStr(node->GetOpDesc(), alias_name_attr, out_alias_name)) { + for (const auto &alias_name : out_alias_name) { + if (alias_name == name) { + return node; + } + } + } + } + return nullptr; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +NodePtr ComputeGraph::FindFirstNodeMatchType(const std::string &name) const { + for (const auto &node : nodes_) { + if (node == nullptr) { + continue; + } + if (node->GetType() == name) { + return node; + } + } + return nullptr; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreEqual( + const ComputeGraph &r_graph) const { + // ProtoMsgOwner <::google::protobuf::Message> is temporarily ignored + if ((this->attrs_.protoMsg_ != nullptr) && (r_graph.attrs_.protoMsg_ != nullptr)) { + const auto &proto_attr_map = *(this->attrs_.protoMsg_); + const auto &r_proto_attr_map = *(r_graph.attrs_.protoMsg_); + // 1.Verify graph's ProtoAttrMap size + if (proto_attr_map.size() != r_proto_attr_map.size()) { + GELOGE(GRAPH_FAILED, "Size of compute graph's ProtoAttrMap verify failed, graph name: %s.", + this->GetName().c_str()); + return false; + } + // 2.Verify graph's ProtoAttrMap key, verify values is temporarily not implemented + for (const auto &it : proto_attr_map) { + if (r_proto_attr_map.count(it.first) == 0) { + GELOGE(GRAPH_FAILED, "Key of compute graph's ProtoAttrMap verify failed, graph name: %s key name: %s.", + this->GetName().c_str(), it.first.c_str()); + return false; + } + } + return true; + } + return ((this->attrs_.protoMsg_ == nullptr) && (r_graph.attrs_.protoMsg_ == nullptr)); +} + +/// Since there may be different input nodes +/// chosen by user in the same graph, special judgment is needed +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNodePtrIsEqual( + const std::vector &left_nodes, const std::vector &right_nodes) const { + const auto left_nodes_size = left_nodes.size(); + const auto right_nodes_size = right_nodes.size(); + if (left_nodes_size != right_nodes_size) { + GELOGE(GRAPH_FAILED, + "Check failed with graph input_nodes_: " + "left inputNodes size %zu is different with right inputNodes size %zu .", + left_nodes_size, right_nodes_size); + return false; + } + for (size_t j = 0; j < left_nodes_size; j++) { + if (left_nodes.at(j) == nullptr || right_nodes.at(j) == nullptr) { + GELOGE(GRAPH_FAILED, "left_nodes.at(%zu) or right_nodes.at(%zu) is nullptr", j, j); + return false; + } + const auto &left_input_name = left_nodes.at(j)->GetName(); + const auto &right_input_name = right_nodes.at(j)->GetName(); + if (left_input_name != right_input_name) { + GELOGE(GRAPH_FAILED, + "Check failed with graph input_nodes_: " + "left inputNode name %s is different with right inputNode name %s at inputNodes index %zu.", + left_input_name.c_str(), right_input_name.c_str(), j); + return false; + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( + const ComputeGraph &r_graph) const { + return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") && + IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && + VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && + IsEqual(this->name_, r_graph.name_, "graph.name_") && + IsEqual(this->is_valid_flag_, r_graph.is_valid_flag_, "graph.is_valid_flag_") && + IsEqual(this->need_iteration_, r_graph.need_iteration_, "graph.need_iteration_") && + IsEqual(this->params_share_map_, r_graph.params_share_map_, "graph.params_share_map_") && + IsEqual(this->out_nodes_map_, r_graph.out_nodes_map_, "graph.out_nodes_map_") && + IsEqual(this->inputs_order_, r_graph.inputs_order_, "graph.inputs_order_") && + IsEqual(this->output_size_, r_graph.output_size_, "graph.output_size_") && + IsEqual(this->input_size_, r_graph.input_size_, "graph.input_size_") && + IsEqual(this->output_nodes_info_, r_graph.output_nodes_info_, "graph.output_nodes_info_")); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::operator==(const ComputeGraph &r_graph) const { + // Firstly: Graph's members equal + if ((!GraphMembersAreEqual(r_graph)) || (!GraphAttrsAreEqual(r_graph))) { + return false; + } + + // Secondly: Node equal means the link relationship between node and node itself equal + for (const auto &left_node : nodes_) { + if (left_node == nullptr) { + GELOGE(GRAPH_FAILED, "left_node is nullptr"); + return false; + } + const auto &node_name = left_node->GetName(); + // After TopologicalSorting, node order can change, so find node by name + const auto &right_node = r_graph.FindNode(node_name); + GE_IF_BOOL_EXEC(right_node == nullptr, GELOGE(GRAPH_FAILED, "right_node is NULL!!!"); return false); + if (!(*right_node == *left_node)) { + GELOGE(GRAPH_FAILED, "Compare graph failed, node name: %s.", node_name.c_str()); + return false; + } + } + + // Thirdly: Recursively determine whether the sub graphs are equal + for (size_t i = 0; i < this->sub_graph_.size(); i++) { + if (!(*((this->sub_graph_)[i]) == *((r_graph.sub_graph_)[i]))) { + return false; + } + } + return true; +} + +NodePtr ComputeGraph::AddNodeFront(NodePtr node) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); + return nullptr; + } + node->SetHostNode(is_valid_flag_); + node->GetOpDesc()->SetId(nodes_.size()); + if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { + (void)nodes_.insert(nodes_.begin() + 1, node); + } else { + (void)nodes_.insert(nodes_.begin(), node); + } + return node; +} + +NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) { + if (op == nullptr) { + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); + return nullptr; + } + op->SetId(nodes_.size()); + NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); + GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); + GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); + return AddNodeFront(node_ptr); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(NodePtr node) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return nullptr; + } + node->SetHostNode(is_valid_flag_); + node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize()); + nodes_.push_back(node); + return node; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpDescPtr op) { + if (op == nullptr) { + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); + return nullptr; + } + op->SetId(GetDirectNodesSize()); + NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); + GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); + GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); + return AddNode(node_ptr); +} + +NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. + if (op == nullptr) { + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); + return nullptr; + } + op->SetId(id); + NodePtr node = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); + GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); + GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); + node->SetHostNode(is_valid_flag_); + nodes_.push_back(node); + return node; +} + +NodePtr ComputeGraph::AddInputNode(NodePtr node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return nullptr; + } + input_nodes_.push_back(node); + if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { + GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed"); + } + return node; +} + +NodePtr ComputeGraph::AddOutputNode(NodePtr node) { + return AddOutputNodeByIndex(node, 0); +} + +NodePtr ComputeGraph::AddOutputNodeByIndex(NodePtr node, int32_t index) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null."); + return nullptr; + } + + bool already_have = false; + NodePtr result = node; + // [output_nodes_info_ : should not be null] + for (const auto &item : output_nodes_info_) { + if (item.first->GetName() == node->GetName() && item.second == index) { + already_have = true; + result = item.first; + break; + } + } + + if (!already_have) { + output_nodes_info_.emplace_back(std::make_pair(node, index)); + GELOGI("Push back node name:%s, index:%ld, into output_nodes_info_.", node->GetName().c_str(), index); + } + + if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { + GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed"); + } + return result; +} + +graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) { + GE_CHECK_NOTNULL(node); + + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) { + continue; + } + if (out_anchor->GetOwnerNode()->GetType() == CONSTANT || out_anchor->GetOwnerNode()->GetType() == CONSTANTOP) { + GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED, + "Remove edge from const op failed."); + if (out_anchor->GetOwnerNode()->GetOutNodes().size() == 0) { + GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str()); + auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode()); + if (iter != nodes_.end()) { + (void)nodes_.erase(iter); + } + } + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return GRAPH_FAILED; + } + + // delete const op for this node + (void)RemoveConstInput(node); + + // if the node save as input node, delete it + (void)RemoveInputNode(node); + + // if the node save as input node, delete it + (void)RemoveOutputNode(node); + + if (GRAPH_SUCCESS != IsolateNode(node)) { + GELOGE(GRAPH_FAILED, "Isolate node failed, node name: %s.", node->GetName().c_str()); + return GRAPH_FAILED; + } + + auto iter = find(nodes_.begin(), nodes_.end(), node); + if (iter != nodes_.end()) { + (void)nodes_.erase(iter); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +// Used in sub_graph scenes +graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return GRAPH_FAILED; + } + + auto iter = find(input_nodes_.begin(), input_nodes_.end(), node); + if (iter != input_nodes_.end()) { + (void)input_nodes_.erase(iter); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +// Used in sub_graph scenes +graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return GRAPH_FAILED; + } + + auto iter = output_nodes_info_.begin(); + bool find_node = false; + // [output_nodes_info_ : should not be null] + while (iter != output_nodes_info_.end()) { + if (node->GetName() == iter->first->GetName()) { + iter = output_nodes_info_.erase(iter); + find_node = true; + } else { + ++iter; + } + } + GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED); + return GRAPH_SUCCESS; +} + +std::shared_ptr ComputeGraph::AddSubGraph(std::shared_ptr sub_graph) { + if (sub_graph == nullptr) { + GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); + return nullptr; + } + sub_graph_.push_back(sub_graph); + names_to_subgraph_[sub_graph->GetName()] = sub_graph; + return sub_graph; +} + +graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr &sub_graph) { + if (sub_graph == nullptr) { + GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); + return GRAPH_FAILED; + } + + names_to_subgraph_.erase(sub_graph->GetName()); + auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph); + if (iter != sub_graph_.end()) { + (void)sub_graph_.erase(iter); + return GRAPH_SUCCESS; + } else { + GELOGW("find sub_graph failed"); + return GRAPH_SUCCESS; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr &subgraph) { + if (subgraph == nullptr) { + GE_LOGE("Try to add a null subgraph, name %s", name.c_str()); + return GRAPH_PARAM_INVALID; + } + auto parent_graph = subgraph->GetParentGraph(); + if (parent_graph == nullptr) { + GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str()); + return GRAPH_PARAM_INVALID; + } + auto parent_node = subgraph->GetParentNode(); + if (parent_node == nullptr) { + GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str()); + return GRAPH_PARAM_INVALID; + } + if (parent_node->GetOwnerComputeGraph() != parent_graph) { + GE_LOGE( + "Try to add a subgraph which parent node's parent graph is not equal to " + "the subgraph's parent graph, subgraph name %s, parent node name %s", + subgraph->GetName().c_str(), parent_graph->GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + if (!this->parent_graph_.expired()) { + GELOGW("The subgraphs should only be added to the root graph"); + } + if (name != subgraph->GetName()) { + GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); + } + if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) { + GE_LOGE("The subgraph %s existed", name.c_str()); + return GRAPH_PARAM_INVALID; + } + sub_graph_.push_back(subgraph); + names_to_subgraph_[name] = subgraph; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +ComputeGraph::AddSubgraph(const std::shared_ptr &subgraph) { + if (subgraph == nullptr) { + return GRAPH_PARAM_INVALID; + } + return AddSubgraph(subgraph->GetName(), subgraph); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) { + auto iter = names_to_subgraph_.find(name); + if (iter == names_to_subgraph_.end()) { + return; + } + for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) { + if (*vec_iter == iter->second) { + sub_graph_.erase(vec_iter); + break; + } + } + names_to_subgraph_.erase(iter); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph( + const std::shared_ptr &subgraph) { + if (subgraph != nullptr) { + RemoveSubgraph(subgraph->GetName()); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr ComputeGraph::GetSubgraph( + const std::string &name) const { + std::shared_ptr parent = parent_graph_.lock(); + if (parent == nullptr) { + auto iter = names_to_subgraph_.find(name); + return iter == names_to_subgraph_.end() ? nullptr : iter->second; + } else { + return parent->GetSubgraph(name); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector> +ComputeGraph::GetAllSubgraphs() const { + return sub_graph_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr ComputeGraph::GetParentGraph() { + return parent_graph_.lock(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph( + const shared_ptr &parent) { + parent_graph_ = parent; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr ComputeGraph::GetParentNode() { + return parent_node_.lock(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr &parent) { + parent_node_ = parent; +} + +/// +/// @brief Update input-mapping +/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input +/// @return graphStatus +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +ComputeGraph::UpdateInputMapping(const std::map &input_mapping) { + for (auto &input : nodes_) { + if (input->GetType() == DATA) { + uint32_t cur_index = 0; + if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { + continue; + } + auto iter = input_mapping.find(cur_index); + if (iter == input_mapping.end()) { + continue; + } + if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { + GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); + return GRAPH_FAILED; + } + } + } + + return GRAPH_SUCCESS; +} + +/// +/// @brief Update output-mapping +/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output +/// @return graphStatus +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +ComputeGraph::UpdateOutputMapping(const std::map &output_mapping) { + NodePtr net_output = FindFirstNodeMatchType(NETOUTPUT); + if (net_output == nullptr) { + GE_LOGE("UpdateOutputMapping failed: node type %s not exist in graph.", NETOUTPUT); + return GRAPH_FAILED; + } + OpDescPtr op_desc = net_output->GetOpDesc(); + if (op_desc == nullptr) { + GE_LOGE("UpdateOutputMapping failed: op_desc is NULL."); + return GRAPH_FAILED; + } + + size_t num = op_desc->GetAllInputsSize(); + for (size_t i = 0; i < num; i++) { + GeTensorDesc tensor = op_desc->GetInputDesc(i); + uint32_t cur_index = 0; + if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { + continue; + } + auto iter = output_mapping.find(cur_index); + if (iter == output_mapping.end()) { + continue; + } + if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { + GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); + return GRAPH_FAILED; + } + if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) { + GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i); + return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { + std::vector node_vec = nodes_; + for (const auto &node : GetDirectNode()) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGW("node or OpDescPtr is nullptr."); + continue; + } + GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should not be null."); return GRAPH_FAILED); + if (node->GetOpDesc()->GetType() == RECV) { + auto iter = find(node_vec.begin(), node_vec.end(), node); + if (iter == node_vec.end()) { + GELOGW("no node found."); + } else { + (void)node_vec.erase(iter); + } + + auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0)); + (void)node_vec.insert(dst_iter, node); + } + if (node->GetOpDesc()->GetType() == SEND) { + auto iter = find(node_vec.begin(), node_vec.end(), node); + if (iter == node_vec.end()) { + GELOGW("no node found."); + } else { + (void)node_vec.erase(iter); + } + + auto src_iter = find(node_vec.begin(), node_vec.end(), node->GetInControlNodes().at(0)); + (void)node_vec.insert(src_iter + 1, node); + } + } + nodes_.clear(); + for (size_t i = 0; i < node_vec.size(); ++i) { + NodePtr node = node_vec[i]; + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGW("node or OpDescPtr is nullptr."); + } else { + node->GetOpDesc()->SetId((int64_t)i); + nodes_.push_back(node); + } + } + return GRAPH_SUCCESS; +} + +graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, + std::map &map_in_edge_num, + std::vector &stack, bool reverse) { + GELOGD("Runing_Dfs_Sort: %s", name_.c_str()); + // Record the number of non data nodes but no input nodes + GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); + std::vector out_nodes; + auto stack_push = [&reverse, &stack](std::vector& out_nodes) { + if (reverse) { + std::reverse(out_nodes.begin(), out_nodes.end()); + } + stack.insert(stack.end(), out_nodes.begin(), out_nodes.end()); + out_nodes.clear(); + }; + // Only data nodes here + while (!stack.empty()) { + NodePtr node = stack.back(); + stack.pop_back(); + node_vec.push_back(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); + for (const auto &anchor : node->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(anchor); + for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { + GE_CHECK_NOTNULL(peer_in_anchor); + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && --iter->second == 0) { + out_nodes.push_back(peer_in_anchor->GetOwnerNode()); + } + } + stack_push(out_nodes); + for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { + GE_CHECK_NOTNULL(peer_in_anchor); + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && --iter->second == 0) { + out_nodes.push_back(peer_in_anchor->GetOwnerNode()); + } + } + stack_push(out_nodes); + } + GE_IF_BOOL_EXEC( + node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor + : node->GetOutControlAnchor()->GetPeerAnchors()) { + GE_CHECK_NOTNULL(peer_in_anchor); + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && --iter->second == 0) { + out_nodes.push_back(peer_in_anchor->GetOwnerNode()); + } + } + stack_push(out_nodes);) + } + + return GRAPH_SUCCESS; +} + +graphStatus ComputeGraph::BFSTopologicalSorting(std::vector &node_vec, + std::map &map_in_edge_num, + std::deque &stack) { + GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); + std::vector stack_input; + std::map breadth_node_map; + // Record the number of non data nodes but no input nodes + GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); + + // Only data nodes here + while (!stack_input.empty() || !stack.empty()) { + NodePtr node = nullptr; + if (!stack.empty()) { + node = stack.back(); + stack.pop_back(); + } else { + node = stack_input.back(); + stack_input.pop_back(); + } + + node_vec.push_back(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); + CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); + + for (const auto &name_node : breadth_node_map) { + (void)stack.push_front(name_node.second); + } + breadth_node_map.clear(); + } + return GRAPH_SUCCESS; +} + +graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, + std::map &breadth_node_map) { + for (const auto &anchor : node->GetAllOutDataAnchors()) { + for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && --iter->second == 0) { + (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); + } + } + + for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && --iter->second == 0) { + (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); + } + } + } + if (node->GetOutControlAnchor() != nullptr) { + for (AnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerAnchors()) { + auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); + if (iter != map_in_edge_num.end() && --iter->second == 0) { + (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); + } + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { + auto ret = TopologicalSortingGraph(); + if (ret != SUCCESS) { + GraphUtils::DumpGEGraphToOnnx(*this, "black_box"); + GELOGE(ret, "Graph [%s] topological sort failed, saved to file black_box", name_.c_str()); + return ret; + } + + if (sub_graph_.empty()) { + return SUCCESS; + } + + // partition sub graph + for (const auto &sub_graph : sub_graph_) { + ret = sub_graph->TopologicalSortingGraph(); + if (ret != SUCCESS) { + GELOGE(ret, "Sub graph topological sort Failed"); + return ret; + } + } + + std::vector> subgraphs; + auto nodes = AllGraphNodes(subgraphs); + for (size_t i = 0; i < nodes.size(); i++) { + NodePtr node = nodes.at(i); // [node: should not be null] + node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] + } + if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original + GELOGW("Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), subgraphs.size()); + return SUCCESS; + } + sub_graph_.swap(subgraphs); + return SUCCESS; +} + +graphStatus ComputeGraph::TopologicalSortingGraph(bool dfs_reverse) { + std::vector node_vec; + std::map map_in_edge_num; + bool use_BFS = IsUseBFS(); + if (use_BFS) { + std::deque stack; + if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } else { + std::vector stack; + if (DFSTopologicalSorting(node_vec, map_in_edge_num, stack, dfs_reverse) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + + // If they are not equal, there is a closed loop + if (node_vec.size() != nodes_.size()) { + std::set itered_nodes_set; + for (auto &node : node_vec) { + itered_nodes_set.insert(node.get()); + } + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"TopologicalSortingGraph", "exist closed loop in graph"}); + GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", nodes_.size(), + node_vec.size()); + for (auto &node : nodes_) { + if (itered_nodes_set.count(node.get()) == 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"TopologicalSortingGraph", "op[" + node->GetName() + "] does not itered when topological sorting"}); + GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); + } + } + return GRAPH_FAILED; + } + + nodes_.clear(); + for (size_t i = 0; i < node_vec.size(); i++) { + NodePtr node = node_vec[i]; // [node: should not be null] + node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] + nodes_.push_back(node); + } + + is_valid_flag_ = true; + return GRAPH_SUCCESS; +} + +graphStatus ComputeGraph::SortNodes(std::vector &stack, std::map &map_in_edge_num) { + // Record the number of non data nodes but no input nodes + uint32_t spec_node_size = 0; + bool verify_isolated = false; + string run_mode; + const int base = 10; + // Need verify isolated point in PREDICTION mode. + if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { + if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) < TRAIN) { + verify_isolated = true; + } + } + for (const auto &node : GetDirectNode()) { + GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); + map_in_edge_num[node] = static_cast(GetInEdgeSize(node)); + if (map_in_edge_num[node] == 0) { + if ((node->GetOpDesc()->GetType() != DATA) && (node->GetOpDesc()->GetType() != AIPPDATA) && + (node->GetOpDesc()->GetType() != INPUT_TYPE) && (node->GetOpDesc()->GetType() != ANN_DATA)) { + // At present, can only judge the isolated point without input and output. + // It is impossible to judge the situation with multiple output nodes. + if (verify_isolated && GetOutEdgeSize(node) == 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"SortNodes", "may has isolated node[" + node->GetName() + "] in graph"}); + GELOGE(GRAPH_FAILED, "May has isolated node in graph, node name: %s.", node->GetName().c_str()); + return GRAPH_FAILED; + } + (void)stack.insert(stack.begin(), node); + spec_node_size++; + continue; + } + // Need to insert the data nodes in reverse order + (void)stack.insert(stack.begin() + spec_node_size, node); + } + } + + /// Make sure the inputs order matches with user-designated + /// 1. Get the index of two input nodes in the user-inputs-order(inputs_order_) + /// 2. Compare two indices, if not match, swap the positions of two inputs + /// *: Remind: stack is reverse-order + for (size_t i = 0; i < stack.size(); ++i) { + // If not found in 'inputs_order_', skip it + auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); + GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); + auto inx_i = it_i - inputs_order_.begin(); + for (size_t j = i + 1; j < stack.size(); ++j) { + // If not found in 'inputs_order_', skip it + auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); + GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); + + // Compare index, swap them if it should be + auto inx_j = it_j - inputs_order_.begin(); + GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); + } + } + + return GRAPH_SUCCESS; +} + +size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { + size_t in_edge_size = 0; + if (node == nullptr) { + return in_edge_size; + } + for (const auto &anchor : node->GetAllInDataAnchors()) { + in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize(); + // Break flow control data loop. + OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); + if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) { + NodePtr out_node = out_anchor->GetOwnerNode(); + if (out_node == nullptr) { + GELOGW("out node is nullptr"); + continue; + } + if ((out_node->GetType() == NEXTITERATION) || (out_node->GetType() == REFNEXTITERATION)) { + GE_IF_BOOL_EXEC(in_edge_size == 0, GELOGE(GRAPH_FAILED, "If [in_edge_size = 0], the result will be reversed"); + return in_edge_size); + in_edge_size -= 1; + } + } + } + if (node->GetInControlAnchor() != nullptr) { + in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize(); + } + return in_edge_size; +} + +size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { + size_t out_edge_size = 0; + if (node == nullptr) { + return out_edge_size; + } + + // Break flow control data loop. + if ((node->GetType() != NEXTITERATION) && (node->GetType() != REFNEXTITERATION)) { + for (const auto &anchor : node->GetAllOutDataAnchors()) { + if (anchor != nullptr) { + out_edge_size = out_edge_size + anchor->GetPeerAnchors().size(); + } + } + } + if (node->GetOutControlAnchor() != nullptr) { + if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { + return 0; + } + out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size(); + } + return out_edge_size; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { + GELOGI("graph name = %s.", GetName().c_str()); + for (const auto &node : GetAllNodes()) { + GELOGD("node name = %s.", node->GetName().c_str()); + for (const auto &anchor : node->GetAllOutDataAnchors()) { + for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { + GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, + GELOGI("node name = %s, out data node name = %s.", node->GetName().c_str(), + peer_in_anchor->GetOwnerNode()->GetName().c_str())); + } + for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { + GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, + GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), + peer_in_anchor->GetOwnerNode()->GetName().c_str())); + } + } + auto out_control_anchor = node->GetOutControlAnchor(); + if (out_control_anchor != nullptr) { + for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) { + GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, + GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), + peer_in_anchor->GetOwnerNode()->GetName().c_str())); + } + for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) { + GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, + GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), + peer_in_anchor->GetOwnerNode()->GetName().c_str())); + } + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Swap(ComputeGraph &graph) { + this->AttrHolder::Swap(graph); + + origGraph_.swap(graph.origGraph_); + + name_.swap(graph.name_); + std::swap(graph_id_, graph.graph_id_); + attrs_.Swap(graph.attrs_); + nodes_.swap(graph.nodes_); + all_nodes_infos_.swap(graph.all_nodes_infos_); + target_nodes_info_.swap(graph.target_nodes_info_); + + input_nodes_.swap(graph.input_nodes_); + inputs_order_.swap(graph.inputs_order_); + std::swap(input_size_, graph.input_size_); + out_nodes_map_.swap(graph.out_nodes_map_); + std::swap(output_size_, graph.output_size_); + output_nodes_info_.swap(graph.output_nodes_info_); + + sub_graph_.swap(graph.sub_graph_); + names_to_subgraph_.swap(graph.names_to_subgraph_); + parent_graph_.swap(graph.parent_graph_); + parent_node_.swap(graph.parent_node_); + + // the members followed should not in the ComputeGraph class + std::swap(is_valid_flag_, graph.is_valid_flag_); + std::swap(is_summary_graph_, graph.is_summary_graph_); + std::swap(need_iteration_, graph.need_iteration_); + params_share_map_.swap(graph.params_share_map_); + op_name_map_.swap(graph.op_name_map_); + std::swap(session_id_, graph.session_id_); + std::swap(data_format_, graph.data_format_); + std::swap(is_unknown_shape_graph_, graph.is_unknown_shape_graph_); + + // Update Node owner. + SetNodesOwner(); + graph.SetNodesOwner(); +} + +void ComputeGraph::SetNodesOwner() { + for (const auto &node : nodes_) { + if (node == nullptr) { + continue; + } + node->SetOwnerComputeGraph(shared_from_this()); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto next_nodes = node->GetOutAllNodes(); + // If there is input data side + for (size_t i = 0; i < node->GetAllInDataAnchors().size(); i++) { + auto in_data_anchor = node->GetInDataAnchor(static_cast(i)); + auto pre_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + if (pre_out_data_anchor != nullptr) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_IF_BOOL_EXEC(pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANT || + pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANTOP, + continue); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + } + auto out_ctrl_anchor = node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(out_ctrl_anchor); + auto pre_out_ctrl_anchor = pre_out_data_anchor->GetOwnerNode()->GetOutControlAnchor(); + GE_CHECK_NOTNULL(pre_out_ctrl_anchor); + for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + } + } + + // If there is an input control side + auto in_ctrl_anchor = node->GetInControlAnchor(); + GE_CHECK_NOTNULL(in_ctrl_anchor); + for (const auto &pre_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_ctrl_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED, + "remove edge failed"); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + } + auto out_ctrl_anchor = node->GetOutControlAnchor(); + if (out_ctrl_anchor != nullptr) { + for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + } + } + + for (const auto &out_peer_data_anchor : in_ctrl_anchor->GetPeerOutDataAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_peer_data_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED, + "remove edge failed"); + for (const auto &next_node : next_nodes) { + auto next_in_control_anchor = next_node->GetInControlAnchor(); + GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(out_peer_data_anchor, next_in_control_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "add edge failed"); + } + } + + return RemoveExtraOutEdge(node); +} + +graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { + GE_CHECK_NOTNULL(node); + // Remove redundant output edges + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + } + + for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + } + } + auto out_ctrl_anchor = node->GetOutControlAnchor(); + if (out_ctrl_anchor != nullptr) { + for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, + return GRAPH_FAILED, "remove edge failed"); + } + } + return GRAPH_SUCCESS; +} + +graphStatus ComputeGraph::Verify() { + bool is_unknown_graph = GetGraphUnknownFlag(); + for (const auto &node_ptr : GetAllNodes()) { + GE_CHECK_NOTNULL(node_ptr); + GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); + GE_IF_BOOL_EXEC(is_unknown_graph, continue); + GE_CHK_BOOL_EXEC(node_ptr->GetOpDesc()->CommonVerify() == GRAPH_SUCCESS, return GRAPH_FAILED, + "Verifying %s failed.", node_ptr->GetName().c_str()); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferOriginFormat() { + return ge::FormatRefiner::InferOrigineFormat(shared_from_this()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferShapeInNeed() { + GE_CHK_BOOL_ONLY_LOG(TopologicalSorting() == GRAPH_SUCCESS, "Verifying failed."); + for (const auto &node_ptr : GetAllNodes()) { + GE_CHECK_NOTNULL(node_ptr); + auto op_desc = node_ptr->GetOpDesc(); + bool is_need_infer = false; + (void)ge::AttrUtils::GetBool(op_desc, NEED_INFER, is_need_infer); + if (is_need_infer) { + GE_CHK_BOOL_EXEC(node_ptr->Verify() == GRAPH_SUCCESS, return GRAPH_FAILED, "Verifying %s failed.", + node_ptr->GetName().c_str()); + + graphStatus status = node_ptr->InferShapeAndType(); + GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == DATA || GRAPH_PARAM_INVALID != status, break, + "Op %s does not have the IMPLEMT_INFERFUNC definition," + " and subsequent operators no longer perform shape inference.", + node_ptr->GetName().c_str()); + GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS, return GRAPH_FAILED, "Inferring %s failed.", + node_ptr->GetName().c_str()); + + for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc()); + auto output_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); + ge::TensorUtils::SetRealDimCnt(output_tensor, output_tensor.GetShape().GetDims().size()); + (void)out_anchor->GetOwnerNode()->GetOpDesc()->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor); + for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { + (void)peer_anchor->GetOwnerNode()->GetOpDesc()->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor); + } + } + } + } + return GRAPH_SUCCESS; +} + +ProtoAttrMapHelper ComputeGraph::MutableAttrMap() { return attrs_; } + +ConstProtoAttrMapHelper ComputeGraph::GetAttrMap() const { + return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); +} + +const std::map &ComputeGraph::GetAllNodesInfo() const { return all_nodes_infos_; } + +void ComputeGraph::SetUserDefOutput(const std::string &output_name) { + if (output_name.empty()) { + return; + } + + vector nodes = StringUtils::Split(output_name, ';'); + for (string node : nodes) { + vector item = StringUtils::Split(node, ':'); + if (item.size() != OUTPUT_PARAM_SIZE) { + GELOGW("invalid output param!input:%s", output_name.c_str()); + continue; + } + + int32_t index; + try { + index = stoi(StringUtils::Trim(item[1])); + } catch (const std::out_of_range &) { + GELOGW("outputname cause out of range execption!output_name:%s", output_name.c_str()); + continue; + } catch (const std::invalid_argument &) { + GELOGW("outputname cause invalid argument!output_name:%s", output_name.c_str()); + continue; + } catch (...) { + GELOGW("stoi fail! output_name:%s", output_name.c_str()); + continue; + } + auto iter = out_nodes_map_.find(item[0]); + if (iter == out_nodes_map_.end()) { + out_nodes_map_[item[0]] = std::vector(1, index); + } else { + auto idx_iter = std::find(iter->second.begin(), iter->second.end(), index); + if (idx_iter == iter->second.end()) { + iter->second.push_back(index); + } + } + } +} + +const std::string ComputeGraph::GetOutput() { + static const int resultDefaultSize = 2048; + string result; + result.reserve(resultDefaultSize); + auto iter = out_nodes_map_.begin(); + while (iter != out_nodes_map_.end()) { + auto idxes = iter->second; + for (auto idx : idxes) { + (void)result.append(iter->first).append(":").append(std::to_string(idx)).append(";"); + } + ++iter; + } + + return result.substr(0, result.length() - 1); +} +} // namespace ge diff --git a/metadef/graph/debug/ge_log.h b/metadef/graph/debug/ge_log.h new file mode 100644 index 00000000..03f0e3ac --- /dev/null +++ b/metadef/graph/debug/ge_log.h @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_DEBUG_GE_LOG_H_ +#define COMMON_GRAPH_DEBUG_GE_LOG_H_ + +#include "graph/ge_error_codes.h" +#include "framework/common/debug/ge_log.h" + +#define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) + +#define GE_LOGI_IF(condition, ...) \ + if ((condition)) { \ + GELOGI(__VA_ARGS__); \ + } + +#define GE_LOGW_IF(condition, ...) \ + if ((condition)) { \ + GELOGW(__VA_ARGS__); \ + } + +#define GE_LOGE_IF(condition, ...) \ + if ((condition)) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + } + +#define GE_CHK_STATUS_RET_NOLOG(expr) \ + do { \ + const ge::graphStatus _status = (expr); \ + if (ge::SUCCESS != _status) { \ + return _status; \ + } \ + } while (0) + +#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +#define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ + { \ + bool b = (expr); \ + if (!b) { \ + exec_expr; \ + } \ + } + +#define GE_IF_BOOL_EXEC(expr, exec_expr) \ + { \ + if (expr) { \ + exec_expr; \ + } \ + } + +#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ + do { \ + const ge::graphStatus _status = (expr); \ + if (_status) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// If expr is true, the log is printed and a custom statement is executed +#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + exec_expr; \ + } \ + } + +// Only check error log +#define GE_CHK_BOOL_ONLY_LOG(expr, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GELOGI(__VA_ARGS__); \ + } \ + } while (0) + +// If expr is not true, do not print the log and return the specified status +#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + return _status; \ + } \ + } while (0) + +// If expr is not true, the log is printed and a custom statement is executed +#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (!b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + exec_expr; \ + } \ + } + +// If expr is not true, the log is printed and a custom statement is executed +#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (!b) { \ + GELOGI(__VA_ARGS__); \ + exec_expr; \ + } \ + } + +// If expr is not GRAPH_SUCCESS, print the log and return the same value +#define GE_CHK_STATUS_RET(expr, ...) \ + do { \ + const ge::graphStatus _status = (expr); \ + if (ge::SUCCESS != _status) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ + try { \ + exec_expr0; \ + } catch (...) { \ + GELOGE(ge::FAILED, "Make shared failed"); \ + exec_expr1; \ + } + +#endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ + diff --git a/metadef/graph/debug/ge_op_types.h b/metadef/graph/debug/ge_op_types.h new file mode 100644 index 00000000..124ec077 --- /dev/null +++ b/metadef/graph/debug/ge_op_types.h @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ +#define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ + +#include "graph/compiler_options.h" + +namespace ge { +#define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name METADEF_ATTRIBUTE_UNUSED = str_name + +GE_REGISTER_OPTYPE(DATA, "Data"); +GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); +GE_REGISTER_OPTYPE(MATMUL, "MatMul"); +GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); +GE_REGISTER_OPTYPE(PERMUTE, "Permute"); +GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); +GE_REGISTER_OPTYPE(_WHILE, "_While"); +GE_REGISTER_OPTYPE(WHILE, "While"); +GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); +GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); +GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); +GE_REGISTER_OPTYPE(SWITCH, "Switch"); +GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); +GE_REGISTER_OPTYPE(SWITCHN, "SwitchN"); +GE_REGISTER_OPTYPE(MERGE, "Merge"); +GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); +GE_REGISTER_OPTYPE(ENTER, "Enter"); +GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); +GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); +GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); +GE_REGISTER_OPTYPE(CONSTANT, "Const"); +GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); +GE_REGISTER_OPTYPE(END, "End"); +GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); +GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); +GE_REGISTER_OPTYPE(INITDATA, "InitData"); +GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); +GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); + +GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); +GE_REGISTER_OPTYPE(VARIABLE, "Variable"); +GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); + +GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); + +// Horovod operator +GE_REGISTER_OPTYPE(HVDCALLBACKALLREDUCE, "hvdCallbackAllreduce"); +GE_REGISTER_OPTYPE(HVDCALLBACKALLGATHER, "hvdCallbackAllgather"); +GE_REGISTER_OPTYPE(HVDCALLBACKBROADCAST, "hvdCallbackBroadcast"); +GE_REGISTER_OPTYPE(HVDWAIT, "hvdWait"); + +GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); + +GE_REGISTER_OPTYPE(RECV, "Recv"); +GE_REGISTER_OPTYPE(SEND, "Send"); +}; // namespace ge +#endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ diff --git a/metadef/graph/debug/ge_util.h b/metadef/graph/debug/ge_util.h new file mode 100644 index 00000000..0bc32b42 --- /dev/null +++ b/metadef/graph/debug/ge_util.h @@ -0,0 +1,274 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_DEBUG_GE_UTIL_H_ +#define COMMON_GRAPH_DEBUG_GE_UTIL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_log.h" +#include "graph/ge_error_codes.h" + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define GE_DYNAMIC_CAST dynamic_cast +#define GE_DYNAMIC_POINTER_CAST std::dynamic_pointer_cast +#else +#define GE_DYNAMIC_CAST static_cast +#define GE_DYNAMIC_POINTER_CAST std::static_pointer_cast +#endif + +#define GE_RETURN_IF_ERROR(expr) \ + do { \ + const ::ge::optStatus _status = (expr); \ + if (_status) return _status; \ + } while (0) + +#define GE_RETURN_WITH_LOG_IF_INFO(expr, ...) \ + do { \ + const ::ge::optStatus _status = (expr); \ + if (_status) { \ + GELOGI(__VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// Verify whether the parameter is true. If yes, return graph failed and record the error log +#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ + do { \ + if (condition) { \ + GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +// Verify whether the parameter is false. If yes, return graph failed and record the error log +#define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ + do { \ + bool _condition = (condition); \ + if (!_condition) { \ + GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +// Verify whether the parameter is true. If yes, return GRAPH_PARAM_INVALID and record the error log +#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ + do { \ + if (condition) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Verify whether the parameter is false. If yes, return GRAPH_PARAM_INVALID and record the error log +#define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \ + do { \ + bool _condition = (condition); \ + if (!_condition) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log +#define GE_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log +#define GE_CHECK_NOTNULL_EXEC(val, expr) \ + do { \ + if (val == nullptr) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ + expr; \ + } \ + } while (0) + +// Verify whether the parameter is null. If yes, return false and record the error log +#define GE_RT_FALSE_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + GELOGE(ge::GRAPH_FAILED, "param[%s] must not be null.", #val); \ + return false; \ + } \ + } while (0) + +// Check whether the parameter is out of range +#define GE_CHECK_SIZE(size) \ + do { \ + if (size == 0) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +/// +/// @ingroup GE_common +/// eg:GE_DEFINE_BYTE_SIZE(filter_byte, filter.data().size(), sizeof(float)); +/// +#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ + uint32_t _var_name; \ + do { \ + uint32_t _expr_size = (_expr); \ + uint32_t _sizeof_size = (_sizeof); \ + if (_expr_size > (0xffffffff) / _sizeof_size) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "byte size : %s is out of range", #_var_name); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + _var_name = _sizeof_size * _expr_size; \ + } while (0); + +// Check whether the container is empty +#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ + do { \ + if (vector.empty()) { \ + GELOGE(ge::GRAPH_FAILED, "param[#vector] is empty", #vector); \ + return ge::GRAPH_FAILED; \ + } \ + } while (0) + +// Check whether the container is empty and return the specified status code +#define GE_CHECK_VECTOR_NOT_EMPTY_RET_STATUS(vector, _status) \ + do { \ + if (vector.empty()) { \ + GELOGE(_status, "param[%s] is empty", #vector); \ + return _status; \ + } \ + } while (0) + +/// +/// @ingroup GE_common +/// @brief This macro provides the ability to disable copying constructors and assignment operators. +/// It is usually placed under private +/// +#define GE_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName &) = delete; \ + void operator=(const TypeName &) = delete + +/// Check whether the size is 0 or out of range +/// @param:size:Size to be verified +#define GE_CHECK_SIZE_RANGE(size) \ + do { \ + if (size == 0 || size >= UINT_MAX / 4) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +#define GE_CHECK_SHORT_SIZE_RANGE(size) \ + do { \ + if (size == 0 || size >= UINT_MAX / 2) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ + do { \ + if (size <= 0) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not a positive number", #size); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +#define GE_CHECK_POSITIVE_SHORT_SIZE_RANGE(size) \ + do { \ + if (size <= 0 || size == 0 || size >= UINT_MAX / 4) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Verify that the value on the left is greater than or equal to the value on the right +#define GE_CHECK_GE(lhs, rhs) \ + do { \ + if (lhs < rhs) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is less than[%s]", #lhs, #rhs); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Check whether the parameters are equal +#define GE_CHECK_EQ(val1, val2) \ + do { \ + if (val1 != val2) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not equals to[%s]", #val1, #val2); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Verify that the value on the left is less than or equal to the value on the right +#define GE_CHECK_LE(lhs, rhs) \ + do { \ + if (lhs > rhs) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is greater than[%s]", #lhs, #rhs); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// Check whether the parameters are equal +#define GE_CHECK_EQ_WITH_LOG(val1, val2, ...) \ + do { \ + if (val1 != val2) { \ + GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ + return ge::GRAPH_PARAM_INVALID; \ + } \ + } while (0) + +// If expr is false, the custom statement is executed +#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + exec_expr; \ + } \ + } while (0) + +#define GE_DELETE_NEW_SINGLE(var) \ + do { \ + if (var != nullptr) { \ + delete var; \ + var = nullptr; \ + } \ + } while (0) + +#define GE_DELETE_NEW_ARRAY(var) \ + do { \ + if (var != nullptr) { \ + delete[] var; \ + var = nullptr; \ + } \ + } while (0) + +template +static inline std::shared_ptr ComGraphMakeShared(Args &&... args) { + using T_nc = typename std::remove_const::type; + std::shared_ptr ret(new (std::nothrow) T_nc(std::forward(args)...)); + return ret; +} + +#endif // COMMON_GRAPH_DEBUG_GE_UTIL_H_ diff --git a/metadef/graph/debug/graph_debug.cc b/metadef/graph/debug/graph_debug.cc new file mode 100644 index 00000000..20309a09 --- /dev/null +++ b/metadef/graph/debug/graph_debug.cc @@ -0,0 +1,248 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/debug/graph_debug.h" +#include +#include +#include +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" + +using namespace std; + +#define TAB " " +#define STR_FMT(str) (" \"" + std::string(str) + "\" ") +#define INPUT_ANCHOR_PORT(name) ("__input__" + (name)) +#define OUTPUT_ANCHOR_PORT(name) ("__output__" + (name)) + +namespace ge { +std::unordered_set control_anchor; +std::vector types = { + "DT_FLOAT", "DT_FLOAT16", "DT_INT8", "DT_INT32", "DT_UINT8", "", + "DT_INT16", "DT_UINT16", "DT_UINT32", "DT_INT64", "DT_UINT64", "DT_DOUBLE", + "DT_BOOL", "DT_DUAL", "DT_DUAL_SUB_INT8", "DT_DUAL_SUB_UINT8", "DT_UNDEFINED"}; + +std::vector formats = {"FORMAT_NCHW", + "FORMAT_NHWC", + "FORMAT_ND", + "FORMAT_NC1HWC0", + "FORMAT_FRACTAL_Z", + "FORMAT_NC1C0HWPAD", + "FORMAT_NHWC1C0", + "FORMAT_FSR_NCHW", + "FORMAT_FRACTAL_DECONV", + "FORMAT_C1HWNC0", + "FORMAT_FRACTAL_DECONV_TRANSPOSE", + "FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS", + "FORMAT_NC1HWC0_C04", + "FORMAT_FRACTAL_Z_C04", + "FORMAT_CHWN", + "FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS", + "FORMAT_HWCN", + "FORMAT_NC1KHKWHWC0", + "FORMAT_BN_WEIGHT", + "FORMAT_FILTER_HWCK", + "FORMAT_HASHTABLE_LOOKUP_LOOKUPS", + "FORMAT_HASHTABLE_LOOKUP_KEYS", + "FORMAT_HASHTABLE_LOOKUP_VALUE", + "FORMAT_HASHTABLE_LOOKUP_OUTPUT", + "FORMAT_HASHTABLE_LOOKUP_HITS", + "FORMAT_RESERVED"}; + +std::vector data_nodes = {"Const", "Data"}; + +void GraphDebugPrinter::DumpNodeToDot(const NodePtr node, std::ostringstream &out_) { + if (node == nullptr) { + GELOGI("Some nodes are null."); + return; + } + + bool in_control = false; + auto name = node->GetName(); + out_ << TAB << STR_FMT(name); + auto input_cnt = std::max(static_cast(1), node->GetAllInDataAnchors().size()); + auto output_cnt = std::max(static_cast(1), node->GetAllOutDataAnchors().size()); + if (control_anchor.find(node->GetName()) != control_anchor.end()) { + input_cnt++; + in_control = true; + } + auto max_col = input_cnt * output_cnt; + out_ << "[\n"; + if (find(data_nodes.begin(), data_nodes.end(), node->GetType()) != data_nodes.end()) { + out_ << TAB << TAB << "shape=plaintext, color=goldenrod\n"; + } else { + out_ << TAB << TAB << "shape=plaintext, color=deepskyblue\n"; + } + out_ << TAB << TAB << "label=<\n"; + out_ << TAB << TAB << R"(" << std::endl; + + auto input_anchors = node->GetAllInDataAnchors(); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return); + if (!input_anchors.empty()) { + out_ << TAB << TAB << ""; + } + for (const auto &anchor : input_anchors) { + string anchor_text = op_desc->GetInputNameByIndex(anchor->GetIdx()); + + out_ << ""; + } + if (in_control) { + string anchor_text = "ctrl"; + out_ << ""; + } + if (!input_anchors.empty()) { + out_ << "\n"; + } + // Node type + out_ << TAB << TAB << "\n"; + // Output + auto output_anchors = node->GetAllOutDataAnchors(); + if (!output_anchors.empty()) { + out_ << TAB << TAB << ""; + } + for (const auto &anchor : output_anchors) { + string anchor_text = op_desc->GetOutputNameByIndex(anchor->GetIdx()); + + out_ << ""; + } + + if (!output_anchors.empty()) { + out_ << "\n"; + } + out_ << TAB << TAB << "
" + << anchor_text << "" + << anchor_text << "
" + << "" << node->GetType() << "
" + << anchor_text << "
\n" << TAB << ">];\n"; +} + +void GraphDebugPrinter::DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag) { + if (node == nullptr) { + GELOGI("Some nodes are null."); + return; + } + auto all_out_anchor = node->GetAllOutDataAnchors(); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return); + for (const auto &anchor : all_out_anchor) { + auto src_anchor = anchor; + auto src_node_name = node->GetName(); + auto src_anchor_index = op_desc->GetOutputNameByIndex(static_cast(src_anchor->GetIdx())); + auto des_anchors = anchor->GetPeerAnchors(); + for (const auto &peer_in_anchor : des_anchors) { + auto in_data_anchor = Anchor::DynamicAnchorCast(peer_in_anchor); + std::string dst_node_name; + out_ << TAB << STR_FMT(src_node_name); + out_ << ":" << OUTPUT_ANCHOR_PORT(src_anchor_index); + auto op = peer_in_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op, continue); + if (in_data_anchor != nullptr) { + dst_node_name = in_data_anchor->GetOwnerNode()->GetName(); + string des_anchor_index = op->GetInputNameByIndex(static_cast(in_data_anchor->GetIdx())); + out_ << " -> " << STR_FMT(dst_node_name); + out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index); + out_ << "["; + } + auto in_control_anchor = Anchor::DynamicAnchorCast(peer_in_anchor); + if (in_control_anchor != nullptr) { + dst_node_name = in_control_anchor->GetOwnerNode()->GetName(); + string des_anchor_index = "ctrl"; + out_ << " -> " << STR_FMT(dst_node_name); + out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index); + out_ << "["; + out_ << " style=dashed "; + } + if (flag != DOT_NOT_SHOW_EDGE_LABEL && in_data_anchor) { + string label; + auto src_ops = src_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(src_ops, return); + auto src_shape = src_ops->GetOutputDesc(src_anchor->GetIdx()).GetShape(); + auto dim = src_shape.GetDims(); + std::ostringstream tensor_info; + if (dim.size() > 0) { + for (size_t i = 0; i < dim.size(); i++) { + if (i != dim.size() - 1) { + tensor_info << dim[i] << "x"; + } else { + tensor_info << dim[i]; + } + } + } else { + tensor_info << "?"; + } + auto src_tensor_desc = src_ops->GetOutputDescPtr(src_anchor->GetIdx()); + GE_CHECK_NOTNULL_EXEC(src_tensor_desc, return); + auto format = src_tensor_desc->GetFormat(); + auto datatype = src_tensor_desc->GetDataType(); + tensor_info << " : " << formats[format] << " : " << types[datatype]; + label = tensor_info.str(); + out_ << "label=" << STR_FMT(label); + } + out_ << "]" << std::endl; + } + } +} + +graphStatus GraphDebugPrinter::DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name, + uint32_t flag) { + auto compute_graph = GraphUtils::GetComputeGraph(graph); + if (compute_graph == nullptr) { + GELOGI("Compute graph is NULL ."); + return GRAPH_SUCCESS; + } + return DumpGraphDotFile(compute_graph, output_dot_file_name, flag); +} + +graphStatus GraphDebugPrinter::DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name, + uint32_t flag) { + if (graph == nullptr) { + GELOGI("graph is null."); + return GRAPH_SUCCESS; + } + std::ostringstream out_; + out_ << "digraph G{\n"; + out_ << TAB << R"(ratio=compress;size="8, 100")" << std::endl; + out_ << TAB << R"(node[fontname="Consolas"])" << std::endl; + out_ << TAB << R"(edge[fontsize = "8" fontname = "Consolas" color="dimgray" ])" << std::endl; + auto all_nodes = graph->GetAllNodes(); + for (const auto &node : all_nodes) { + for (const auto &temp : node->GetAllOutDataAnchors()) { + for (const auto &peer : temp->GetPeerAnchors()) { + auto temp_control_anchor = Anchor::DynamicAnchorCast(peer); + if (temp_control_anchor) { + (void)control_anchor.insert(peer->GetOwnerNode()->GetName()); + } + } + } + } + for (const auto &node : all_nodes) { + DumpNodeToDot(node, out_); + } + for (const auto &node : all_nodes) { + DumpEdgeToDot(node, out_, flag); + } + out_ << "}"; + std::ofstream output_file(output_dot_file_name); + if (output_file.is_open()) { + output_file << out_.str(); + } else { + GELOGW("%s open error.", output_dot_file_name.c_str()); + } + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/metadef/graph/debug/graph_debug.h b/metadef/graph/debug/graph_debug.h new file mode 100644 index 00000000..74aa0e7b --- /dev/null +++ b/metadef/graph/debug/graph_debug.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ +#define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ +#include +#include +#include +#include +#include +#include "external/graph/graph.h" +#include "./ge_error_codes.h" +#include "graph/compute_graph.h" +#include "graph/debug/ge_log.h" +#include "graph/node.h" +#include "utils/graph_utils.h" + +namespace ge { +enum DotFileFlag { + // Show nodes, edges, size, type and format + DOT_FLAG_DEFAULT = 0, + DOT_NOT_SHOW_EDGE_LABEL = 1, +}; +class GraphDebugPrinter { + public: + static graphStatus DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name, + uint32_t flag = DOT_FLAG_DEFAULT); + static graphStatus DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name, + uint32_t flag = DOT_FLAG_DEFAULT); + static void DumpNodeToDot(const NodePtr node, std::ostringstream &out_); + static void DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag = DOT_FLAG_DEFAULT); +}; +} // namespace ge + +#endif // COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ diff --git a/metadef/graph/detail/attributes_holder.cc b/metadef/graph/detail/attributes_holder.cc new file mode 100644 index 00000000..e0ea09e8 --- /dev/null +++ b/metadef/graph/detail/attributes_holder.cc @@ -0,0 +1,242 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "detail/attributes_holder.h" +#include +#include "debug/ge_log.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_attr_value.h" +#include "proto/ge_ir.pb.h" + + +namespace ge { +using std::map; +using std::unordered_set; +void AttrHolder::CopyAttrsFrom(const AttrHolder &holder) { MutableAttrMap().CopyValueFrom(holder.GetAttrMap()); } +graphStatus AttrHolder::SetAttr(const std::string &name, const GeAttrValue &value) { + if (value.IsEmpty()) { + GELOGE(GRAPH_FAILED, "value is empty, key of the attr is %s", name.c_str()); + return GRAPH_FAILED; + } + auto proto_map = MutableAttrMap().GetProtoMsg(); + auto proto_val = value.value_.GetProtoMsg(); + if (proto_map == nullptr || proto_val == nullptr) { + return GRAPH_FAILED; + } + auto it = proto_map->find(name); + if (it != proto_map->end()) { + if (it->second.value_case() != proto::AttrDef::VALUE_NOT_SET && + it->second.value_case() != proto_val->value_case()) { + return GRAPH_FAILED; + } + } + (*proto_map)[name] = *proto_val; + return GRAPH_SUCCESS; +} + +graphStatus AttrHolder::AddRequiredAttr(const std::string &name) { + if (HasAttr(name)) { + return GRAPH_FAILED; + } + requiredAttrs_.push_back(name); + return GRAPH_SUCCESS; +} + +graphStatus AttrHolder::GetAttr(const std::string &name, GeAttrValue &value) const { + auto proto_map = GetAttrMap().GetProtoMsg(); + auto proto_val = value.value_.GetProtoMsg(); + if (proto_map == nullptr || proto_val == nullptr) { + return GRAPH_FAILED; + } + auto it = proto_map->find(name); + if (it != proto_map->end()) { + *proto_val = it->second; + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +bool AttrHolder::HasAttr(const std::string &name) const { + auto proto_map = GetAttrMap().GetProtoMsg(); + if (proto_map != nullptr) { + if (proto_map->find(name) != proto_map->end()) { + return true; + } + } + return std::find(requiredAttrs_.begin(), requiredAttrs_.end(), name) != requiredAttrs_.end(); +} + +graphStatus AttrHolder::DelAttr(const std::string &name) { + auto proto_map = MutableAttrMap().GetProtoMsg(); + if (proto_map == nullptr) { + return GRAPH_FAILED; + } + auto it = proto_map->find(name); + if (it != proto_map->end()) { + (void)proto_map->erase(it); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +const std::map AttrHolder::GetAllAttrs() const { + std::map attr_value_map; + auto proto_map = GetAttrMap().GetProtoMsg(); + if (proto_map != nullptr) { + auto proto_owner = GetAttrMap().GetProtoOwner(); + GE_CHK_BOOL_EXEC(proto_owner != nullptr, return attr_value_map, "proto_owner is nullptr"); + for (const auto &it : *proto_map) { + attr_value_map[it.first] = GeAttrValue(proto_owner, const_cast(&it.second)); + } + } + return attr_value_map; +} + +const std::unordered_set AttrHolder::GetAllAttrNames() const { + std::unordered_set names; + auto proto_map = GetAttrMap().GetProtoMsg(); + if (proto_map != nullptr) { + for (const auto &it : *proto_map) { + (void)names.insert(it.first); + } + } + for (const string &it : requiredAttrs_) { + (void)names.insert(it); + } + return names; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::AttrDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::TensorDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::ShapeDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::NamedAttrs make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); + return; + } + protoMsg_ = proto_owner.get(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); + return; + } + protoMsg_ = proto_owner->mutable_attr(); + protoOwner_ = proto_owner; +} + +template <> +void GeIrProtoHelper::InitDefault() { + std::shared_ptr proto_owner; + proto_owner = ComGraphMakeShared(); + if (proto_owner == nullptr) { + GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); + return; + } + protoMsg_ = &proto_owner->attr(); + protoOwner_ = proto_owner; +} +} // namespace ge diff --git a/metadef/graph/format_refiner.cc b/metadef/graph/format_refiner.cc new file mode 100644 index 00000000..690e0445 --- /dev/null +++ b/metadef/graph/format_refiner.cc @@ -0,0 +1,513 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "format_refiner.h" + +#include +#include +#include +#include +#include + +#include "graph/ref_relation.h" +#include "./compute_graph.h" +#include "./ge_error_codes.h" +#include "./graph/ge_tensor.h" +#include "./operator.h" +#include "./operator_factory.h" +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "utils/node_utils.h" +#include "utils/op_desc_utils.h" +#include "utils/tensor_utils.h" +#include "utils/type_utils.h" + +using namespace ge; +using namespace std; +namespace ge { +namespace { +const size_t kDimSize4d = 4; +const std::unordered_set kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; +const string kIsGraphInferred = "_is_graph_inferred"; +thread_local RefRelations reflection_builder; +} // namespace + +graphStatus ReflectionProcess(const std::unordered_set &reflection, + std::deque &nodes, + ge::Format to_be_set_format) { + for (const auto &cell : reflection) { + auto node = cell.node; + auto in_out_idx = cell.in_out_idx; + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (cell.in_out == ge::NODE_IN) { + auto desc = node->GetOpDesc()->GetInputDesc(static_cast(in_out_idx)); + desc.SetOriginFormat(to_be_set_format); + desc.SetFormat(to_be_set_format); + (void)node->GetOpDesc()->UpdateInputDesc(static_cast(in_out_idx), desc); + } else { + auto desc = node->GetOpDesc()->GetOutputDesc(static_cast(in_out_idx)); + desc.SetOriginFormat(to_be_set_format); + desc.SetFormat(to_be_set_format); + (void)node->GetOpDesc()->UpdateOutputDesc(static_cast(in_out_idx), desc); + } + nodes.push_back(cell.node); + } + + return GRAPH_SUCCESS; +} + +graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) { + // 5 meas dim num + if (node_ptr->GetType() != "BiasAdd") { + return GRAPH_SUCCESS; + } + std::unordered_map kTfFormatFix = { + {"NHWC", FORMAT_NDHWC}, + {"NCHW", FORMAT_NCDHW} + }; + for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) { + auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i); + GE_CHECK_NOTNULL(in_desc); + if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num + continue; + } + auto format = in_desc->GetOriginFormat(); + auto key = TypeUtils::FormatToSerialString(format); + auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; + in_desc->SetOriginFormat(fixed_format); + in_desc->SetFormat(fixed_format); + GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", + i, node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::FormatToSerialString(fixed_format).c_str()); + } + for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) { + auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i); + GE_CHECK_NOTNULL(out_desc); + if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num + continue; + } + auto format = out_desc->GetOriginFormat(); + auto key = TypeUtils::FormatToSerialString(format); + auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; + out_desc->SetOriginFormat(fixed_format); + out_desc->SetFormat(fixed_format); + GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", + i, node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::FormatToSerialString(fixed_format).c_str()); + } + return GRAPH_SUCCESS; +} + +graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(graph); + GE_CHECK_NOTNULL(op_desc); + if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) { + ConstGeTensorPtr tensor_value; + if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { + GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str()); + return GRAPH_FAILED; + } + GE_CHECK_NOTNULL(tensor_value); + (void)op_desc->UpdateOutputDesc(0, tensor_value->GetTensorDesc()); + } + return GRAPH_SUCCESS; +} + +graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points, + std::vector &data_nodes, + std::unordered_map &node_status) { + if (graph == nullptr) { + GELOGE(GRAPH_FAILED, "input graph is null"); + return GRAPH_FAILED; + } + anchor_points.clear(); + // Get all anchor point nodes and switch nodes + for (auto &node_ptr : graph->GetAllNodes()) { + if (node_ptr == nullptr) { + return GRAPH_FAILED; + } + auto op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + return GRAPH_FAILED; + } + graphStatus status = RefreshConstantOutProcess(graph, op_desc); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "refresh constant out process failed!"); + return GRAPH_FAILED; + } + // consider special node save process + // get all input desc format + bool node_is_all_nd = false; + auto input_size = static_cast(op_desc->GetAllInputsSize()); + for (uint32_t i = 0; i < input_size; i++) { + // Operator pre-set format but not origin format + GE_IF_BOOL_EXEC(op_desc->MutableInputDesc(i) == nullptr, continue); + auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); + // Pre-save data node (only main graph data) and default infer fail + if (node_ptr->GetType() == DATA) { + data_nodes.push_back(node_ptr); + } + if (input_format != FORMAT_ND && input_format != FORMAT_RESERVED) { + node_is_all_nd = true; + } + } + // Get all output desc format + auto output_size = static_cast(op_desc->GetOutputsSize()); + for (uint32_t i = 0; i < output_size; i++) { + GE_IF_BOOL_EXEC(op_desc->MutableOutputDesc(i) == nullptr, continue); + auto output_format = op_desc->MutableOutputDesc(i)->GetFormat(); + if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { + node_is_all_nd = true; + } + } + // check anchor point valid + if (!node_is_all_nd) { + continue; + } + // special process for biasAdd op + // In tensorflow, biasAdd's format is alwayse NHWC even though set the arg + // "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism + // so here do special process + status = BiasAddFormatFixProcess(node_ptr); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "fix biasAdd process failed!"); + return GRAPH_FAILED; + } + + GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str()); + anchor_points.push_back(node_ptr); + } + GELOGI("anchor_points number is %zu", anchor_points.size()); + return GRAPH_SUCCESS; +} +graphStatus FormatRefiner::AnchorProcess(const ge::NodePtr &anchor_node, + std::unordered_map &node_status) { + if (anchor_node == nullptr) { + GELOGE(GRAPH_FAILED, "anchor node is null!"); + return GRAPH_FAILED; + } + std::deque nodes; + nodes.push_back(anchor_node); + while (!nodes.empty()) { + ge::NodePtr node = nodes.front(); + nodes.pop_front(); + graphStatus status = BackInferProcess(nodes, node, node_status); + if (status != GRAPH_SUCCESS && node != nullptr) { + GELOGE(status, "BackInferProcess failed!node name [%s]", node->GetName().c_str()); + return status; + } + status = ForwardInferProcess(nodes, node, node_status); + if (status != GRAPH_SUCCESS && node != nullptr) { + GELOGE(status, "ForwardInferProcess failed!node name [%s]", node->GetName().c_str()); + return status; + } + } + return GRAPH_SUCCESS; +} +graphStatus FormatRefiner::BackInferProcess(std::deque &nodes, ge::NodePtr &node, + std::unordered_map &node_status) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + + GELOGD("Enter back infer process!Node is [%s]", (node->GetName()).c_str()); + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); + auto in_data_anchor_idx = in_anchor->GetIdx(); + auto input_desc = node->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); + GE_IF_BOOL_EXEC(input_desc == nullptr, continue); + auto to_be_set_format = input_desc->GetOriginFormat(); + if (to_be_set_format == FORMAT_ND) { + GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); + continue; + } + auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_data_anchor == nullptr) { + GELOGW("Node[%s] %dth in data anchor's peer_out_anchor is null", (node->GetName()).c_str(), in_data_anchor_idx); + continue; + } + auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); + if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { + GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (node->GetName()).c_str()); + continue; + } + // Check format whether have been set + int idx = peer_out_data_anchor->GetIdx(); + // do peer_out_node name and index as key to lookup reflections + ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx); + std::unordered_set reflection; + auto status = reflection_builder.LookUpRefRelations(key, reflection); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge", + (peer_out_data_node->GetName()).c_str(), idx); + return GRAPH_FAILED; + } + + auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast(idx)); + if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { + auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); + if (dim_num == 0) { + GELOGD("node name:%s idx:%d out is scalar. stop back infer!", peer_out_data_node->GetName().c_str(), idx); + continue; + } + /// Check whether node to change dims () + /// Because some node will calculate with 5D, C dim maybe multi meaning + auto peer_out_data_node_type = peer_out_data_node->GetType(); + auto iter1 = kChangeDimNodes.find(peer_out_data_node_type); + // 4 means dims num + if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { + GELOGD("Node[%s] is change dim node and shape is smaller than 4. do not modify format", + (peer_out_data_node->GetName()).c_str()); + continue; + } + + if (reflection.empty()) { + ge_tensor_desc.SetOriginFormat(to_be_set_format); + ge_tensor_desc.SetFormat(to_be_set_format); + (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast(idx), ge_tensor_desc); + + // Call operator infer format api (forward) to get out format + GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); + status = peer_out_data_node->InferOriginFormat(); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); + return GRAPH_FAILED; + } + nodes.push_back(peer_out_data_node); + } else { + auto status = ReflectionProcess(reflection, nodes, to_be_set_format); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "reflection process failed!"); + return GRAPH_FAILED; + } + } + } + } + return GRAPH_SUCCESS; +} +graphStatus FormatRefiner::ForwardInferProcess(std::deque &nodes, ge::NodePtr &node, + std::unordered_map &node_status) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + GELOGD("Enter forward infer process!Node is [%s]", (node->GetName()).c_str()); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); + GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); + auto out_data_anchor_idx = out_data_anchor->GetIdx(); + auto to_be_set_format = + node->GetOpDesc()->MutableOutputDesc(static_cast(out_data_anchor_idx))->GetOriginFormat(); + if (to_be_set_format == FORMAT_ND) { + GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); + continue; + } + for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue); + + auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); + GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue); + GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue); + + // Check format whether have been set + int idx = peer_in_data_anchor->GetIdx(); + // do peer_out_node name and index as key to lookup reflections + ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx); + std::unordered_set reflection; + auto status = reflection_builder.LookUpRefRelations(key, reflection); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge", + (peer_in_data_node->GetName()).c_str(), idx); + return GRAPH_FAILED; + } + auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast(idx)); + if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { + auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); + if (dim_num == 0) { + GELOGI("node name:%s idx:%d in is scalar. stop forward infer!", peer_in_data_node->GetName().c_str(), idx); + continue; + } + /// Check whether node to change dims () + /// Because some node will calculate with 5D, C dim maybe multi meaning + auto peer_in_data_node_type = peer_in_data_node->GetType(); + auto iter1 = kChangeDimNodes.find(peer_in_data_node_type); + // 4 means dims num + if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { + GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str()); + continue; + } + + if (reflection.empty()) { + ge_tensor_desc.SetOriginFormat(to_be_set_format); + ge_tensor_desc.SetFormat(to_be_set_format); + (void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast(idx), ge_tensor_desc); + + /// Because netoutput node added before infer format ,so netoutput is end condition + /// must set netoutput format , because saved result depend on format + if (peer_in_data_node_type == NETOUTPUT) { + continue; + } + + // Call operator infer format api (forward) to get out format + GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); + status = peer_in_data_node->InferOriginFormat(); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); + return GRAPH_FAILED; + } + nodes.push_back(peer_in_data_node); + } else { + auto status = ReflectionProcess(reflection, nodes, to_be_set_format); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "reflection process failed!"); + return GRAPH_FAILED; + } + } + } + } + } + return GRAPH_SUCCESS; +} + +void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector &anchor_points) { + for (const auto &node : anchor_points) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + continue; + } + for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { + // single op support private format set, its origin format should not be override + auto ori_format = input_desc->GetOriginFormat(); + if (input_desc != nullptr && (ori_format == FORMAT_ND || ori_format == FORMAT_RESERVED)) { + input_desc->SetOriginFormat(input_desc->GetFormat()); + } + } + for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { + auto ori_format = output_desc->GetOriginFormat(); + if (output_desc != nullptr && (ori_format == FORMAT_ND || ori_format == FORMAT_RESERVED)) { + output_desc->SetOriginFormat(output_desc->GetFormat()); + } + } + } +} + +graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, + ge::Format data_format, + std::unordered_map &node_status) { + if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) { + GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph), + TypeUtils::FormatToSerialString(data_format).c_str()); + return GRAPH_SUCCESS; + } + GELOGD("Enter DataNodeFormatProcess"); + std::vector uninfered_data_nodes; + // Check and renew data nodes format + for (const auto &data_node : data_nodes) { + GE_CHECK_NOTNULL(data_node); + auto op_desc = data_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + auto input_desc = op_desc->MutableInputDesc(0); + auto output_desc = op_desc->MutableOutputDesc(0); + GE_CHECK_NOTNULL(input_desc); + GE_CHECK_NOTNULL(output_desc); + + auto curr_format = output_desc->GetOriginFormat(); + if (curr_format != FORMAT_ND) { + // Data format has been infered , continue + continue; + } + // keep data format be ND because lacking of defination when input shape num is smaller than 4 + if (input_desc->MutableShape().GetDimNum() < kDimSize4d) { + continue; + } + // Set format for un-infered data node + input_desc->SetOriginFormat(data_format); + input_desc->SetFormat(data_format); + output_desc->SetOriginFormat(data_format); + output_desc->SetFormat(data_format); + uninfered_data_nodes.push_back(data_node); + } + // Reinfer format from uninfered data nodes + for (const auto &node : uninfered_data_nodes) { + if (node == nullptr) { + continue; + } + GELOGD("data node [%s] start infer format process", node->GetName().c_str()); + auto status = AnchorProcess(node, node_status); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "data node [%s] infer format process failed!", node->GetName().c_str()); + return GRAPH_FAILED; + } + } + GELOGD("DataNodeFormatProcess success"); + return GRAPH_SUCCESS; +} + +graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) { + GELOGI("Enter InferOrigineFormat process!"); + + // True: infered false:no-infered + std::unordered_map node_status; + std::vector anchor_points; + std::vector data_nodes; + + if (graph == nullptr) { + GELOGE(GRAPH_FAILED, "input graph is null"); + return GRAPH_FAILED; + } + // build reflection relations of boundary + (void)reflection_builder.Clear(); + auto status = reflection_builder.BuildRefRelations(*graph); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!"); + return GRAPH_FAILED; + } + // User set global net format + status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!"); + return GRAPH_FAILED; + } + // Refresh origin format of anchor point + RefreshOriginFormatOfAnchor(anchor_points); + // Infer format process + for (const auto &anchor_node : anchor_points) { + if (anchor_node == nullptr) { + continue; + } + status = AnchorProcess(anchor_node, node_status); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Anchor node [%s] process failed!", anchor_node->GetName().c_str()); + return GRAPH_FAILED; + } + } + /// According to discuss with sys-enginer, data node default format is ND.Its format + /// should be set by infered.But if some data-node can not be got by infer, set context's + /// format for these data nodes. + /// Notice: ignore 5D formats + auto data_format = graph->GetDataFormat(); + status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status); + + (void)AttrUtils::SetBool(graph, kIsGraphInferred, true); + + return status; +} + +bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { + bool is_graph_inferred = false; + return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); +} +} // namespace ge diff --git a/metadef/graph/format_refiner.h b/metadef/graph/format_refiner.h new file mode 100644 index 00000000..6a64bef2 --- /dev/null +++ b/metadef/graph/format_refiner.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_FORMAT_REFINER_H_ +#define COMMON_GRAPH_FORMAT_REFINER_H_ + +#include +#include +#include +#include +#include "./compute_graph.h" +#include "./external/graph/types.h" +#include "./ge_error_codes.h" + +namespace ge { +// ShapeRefiner performs shape inference for compute graphs +class FormatRefiner { + public: + static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); + + private: + static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); + static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points, + std::vector &data_nodes, + std::unordered_map &node_status); + static graphStatus AnchorProcess(const ge::NodePtr &anchor_node, std::unordered_map &node_status); + static void RefreshOriginFormatOfAnchor(std::vector &anchor_points); + static graphStatus BackInferProcess(std::deque &nodes, ge::NodePtr &node, + std::unordered_map &node_status); + static graphStatus ForwardInferProcess(std::deque &nodes, ge::NodePtr &node, + std::unordered_map &node_status); + static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, + ge::Format data_format, std::unordered_map &node_status); + static bool IsGraphInferred(const ComputeGraphPtr &graph); +}; +} // namespace ge +#endif // COMMON_GRAPH_FORMAT_REFINER_H_ diff --git a/metadef/graph/ge_attr_define.cc b/metadef/graph/ge_attr_define.cc new file mode 100644 index 00000000..e2886c49 --- /dev/null +++ b/metadef/graph/ge_attr_define.cc @@ -0,0 +1,1123 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace ge { +// Public attribute +const std::string ATTR_NAME_FORCE_UNKNOWN_SHAPE = "_force_unknown_shape"; + +const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; + +const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; + +const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE = "_unknown_shape_type"; + +const std::string ATTR_NAME_NAME = "name"; + +const std::string ATTR_NAME_TYPE = "type"; + +const std::string ATTR_NAME_WEIGHT_NAME = "weight_name"; + +const std::string ATTR_NAME_IS_QUANTIZE_FACTOR = "quantize_factor"; + +const std::string ATTR_NAME_ALPHA = "alpha"; + +const std::string ATTR_NAME_BETA = "beta"; + +const std::string ATTR_NAME_PADMODE = "pad_mode"; + +const std::string ATTR_NAME_PADMODES = "padding"; + +const std::string ATTR_NAME_MODE = "mode"; + +const std::string ATTR_NAME_FILTER = "filter"; + +const std::string ATTR_NAME_BIAS = "bias"; + +const std::string ATTR_NAME_BIAS_TERM = "bias_term"; + +const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; + +const std::string ATTR_NAME_PAD = "pad"; + +const std::string ATTR_NAME_PADS = "pad"; + +const std::string ATTR_NAME_PAD_SIZE = "pad size"; + +const std::string ATTR_NAME_PAD_MODE = "pad mode"; + +const std::string ATTR_NAME_SCALE = "scale"; + +const std::string ATTR_NAME_WINDOWS = "windows"; + +const std::string ATTR_NAME_GLOBAL_POOLING = "global_pooling"; + +const std::string ATTR_NAME_CEIL_MODE = "ceil_mode"; + +const std::string ATTR_NAME_RELUMODE = "relu_mode"; + +const std::string ATTR_NAME_STRIDE_SIZE = "stride size"; + +const std::string ATTR_NAME_RELU_FLAG = "relu_flag"; + +const std::string ATTR_NAME_ALGO = "algo"; + +const std::string ATTR_NAME_FORMAT = "format"; + +const std::string ATTR_NAME_STORAGE_FORMAT = "storage_format"; + +const std::string ATTR_NAME_STORAGE_SHAPE = "storage_shape"; + +const std::string ATTR_NAME_FILTER_FORMAT = "filter_format"; + +const std::string ATTR_NAME_LRN_K = "lrn_k"; + +const std::string ATTR_NAME_LRN_NORM_REGION = "lrn_normregion"; + +const std::string ATTR_NAME_LRN_LOCAL_SIZE = "lrn_localsize"; + +const std::string ATTR_NAME_LRN_ALPHA = "lrn_alpha"; + +const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; + +const std::string ATTR_NAME_AXIS = "axis"; +const std::string ATTR_NAME_BROADCAST = "broadcast"; + +const std::string ATTR_NAME_OUTPUT = "output"; +const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; +const std::string ATTR_NAME_TIDX = "t_idx"; + +const std::string ATTR_NAME_TPADDINGS = "t_paddings"; +const std::string ATTR_IMG_H = "img_h"; +const std::string ATTR_IMG_W = "img_w"; +const std::string ATTR_NET_H = "net_h"; +const std::string ATTR_NET_W = "net_w"; + +const std::string ATTR_NAME_TMULTIPLES = "t_multiples"; + +const std::string ATTR_NAME_MULTIPLES = "multiples"; + +const std::string ATTR_NAME_T = "T"; +const std::string ATTR_NAME_N = "N"; + +const std::string ATTR_NAME_TSHAPE = "Tshape"; +const std::string ATTR_NAME_NAN_OPT = "nan_opt"; + +const std::string ATTR_NAME_AIPP = "aipp"; +const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; + +const std::string ATTR_NAME_AIPP_INPUTS = "_aipp_inputs"; +const std::string ATTR_NAME_AIPP_OUTPUTS = "_aipp_outputs"; + +const std::string ATTR_NAME_INPUT_DIMS = "input_dims"; +const std::string ATTR_DYNAMIC_AIPP_INPUT_DIMS = "_dynamic_aipp_input_dims"; +const std::string ATTR_DATA_RELATED_AIPP_MODE = "_data_related_aipp_mode"; +const std::string ATTR_DATA_AIPP_DATA_NAME_MAP = "_data_aipp_data_name_map"; + +const std::string ATTR_NAME_GRAPH_HAS_BEEN_ADDED = "_graph_has_been_added"; + +const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; +const std::string ATTR_NAME_PARENT_GRAPH_NAME = "_parent_graph_name"; + +const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; +const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; +const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; + +const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; +const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; + +const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; +const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; +const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; +const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; +const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; + +const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; +const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; + +const std::string ATTR_NAME_INFERRED_FORMAT = "inferred_format"; +const std::string ATTR_NAME_PRED_PERMUTE_DELETED = "pred_permute_deleted"; +const std::string ATTR_NAME_IGNORE_PRED_FORMAT = "ignore_pred_format"; +const std::string ATTR_NAME_WEIGHTS = "value"; +const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; +const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; +const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; +const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; +const std::string ATTR_NAME_RTS_LABEL_NODE = "_rts_label_node"; +const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; +const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; +const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; +const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; +const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; +const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS = "_dynamic_output_dims"; +const std::string ATTR_NAME_INPUT_ORIGIN_SIZE = "input_origin_size"; + +const std::string ATTR_NAME_ROOT_GRAPH_ID = "_root_graph_id"; + +// Identify node connecting to input and output +const std::string ATTR_NAME_NODE_CONNECT_INPUT = "_is_connected_to_data"; +const std::string ATTR_NAME_NODE_CONNECT_OUTPUT = "_is_connected_to_netoutput"; + +// To be deleted +const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; +const std::string PERMUTE_RESHAPE_FUSION = "permute_reshape_fusion"; +const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL = "fusion_conv_proposal"; +const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX = "fusion_conv_decodebbox"; +const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM = "box_type_num"; +const std::string SSD_MBOX_LOC_FUSION = "permute_flatten_fusion"; +const std::string SSD_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; +const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; +const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; +const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; + +// Refinedet +const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; + +const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; +const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; +const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; +const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; + + +// _Arg +const std::string ATTR_NAME_INDEX = "index"; +// _RetVal +const std::string RETVAL_ATTR_NAME_INDEX = "retval_index"; +// Data +const std::string DATA_ATTR_NAME_DATA_TYPE = "data_type"; + +// Send +const std::string SEND_ATTR_EVENT_ID = "event_id"; + +// Recv +const std::string RECV_ATTR_EVENT_ID = "event_id"; + +// convolution +const std::string ATTR_NAME_COEF = "coef"; + +const std::string ATTR_NAME_STRIDE = "stride"; + +const std::string ATTR_NAME_STRIDES = "stride"; + +const std::string ATTR_NAME_DILATION = "dilation"; + +const std::string ATTR_NAME_DILATIONS = "dilation"; + +const std::string CONV_ATTR_NAME_MODE = "mode"; + +const std::string CONV_ATTR_NAME_ALGO = "algo"; + +const std::string CONV_ATTR_NAME_GROUP = "group"; + +const std::string CONV_ATTR_NAME_PAD_MODE = "pad_mode"; + +const std::string CONV_ATTR_NAME_PAD = "pad"; + +const std::string CONV_ATTR_NAME_STRIDE = "stride"; + +const std::string CONV_ATTR_NAME_DILATION = "dilation"; + +const std::string CONV_ATTR_NAME_NUM_OUTPUT = "num_output"; + +const std::string CONV_ATTR_NAME_KERNEL = "kernel"; + +const std::string CONV_ATTR_NAME_FILTER = "filter"; + +const std::string CONV_ATTR_NAME_BIAS = "bias"; + +const std::string CONV_ATTR_NAME_RELU_FLAG = "relu_flag"; + +const std::string CONV_ATTR_NAME_ADJ = "adj"; + +const std::string CONV_ATTR_NAME_TARGET_SHAPE = "target_shape"; + +const std::string CONV_ATTR_NAME_BEFORE_PAD = "before_pad"; + +const std::string CONV_ATTR_NAME_HAS_BIAS = "has_bias"; + +const std::string NEED_INFER = "isNeedInfer"; + +// Pooling +const std::string POOLING_ATTR_MODE = "mode"; +const std::string POOLING_ATTR_NAN_OPT = "nan_opt"; +const std::string POOLING_ATTR_PAD_MODE = "pad_mode"; +const std::string POOLING_ATTR_GLOBAL_POOLING = "global_pooling"; +const std::string POOLING_ATTR_WINDOW = "window"; +const std::string POOLING_ATTR_PAD = "pad"; +const std::string POOLING_ATTR_STRIDE = "stride"; +const std::string POOLING_ATTR_CEIL_MODE = "ceil_mode"; +const std::string POOLING_ATTR_DATA_MODE = "data_mode"; +const std::string POOLING_ATTR_BEFORE_PAD = "before_pad"; +const std::string POOLING_ATTR_NAME_ALGO = "algo"; + +// Eltwise +const std::string ELTWISE_ATTR_MODE = "mode"; +const std::string ELTWISE_ATTR_COEFF = "coeff"; +const std::string ELTWISE_ATTR_WEIGHT = "weight"; +const std::string ELTWISE_ATTR_RELU_FLAG = "relu_flag"; +const std::string ELTWISE_ATTR_ALPHA = "alpha"; +const std::string ELTWISE_ATTR_BETA = "beta"; + +// BatchNorm +const std::string BATCHNORM_ATTR_MODE = "mode"; +const std::string BATCHNORM_ATTR_EPSILON = "epsilon"; +const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS = "use_global_stats"; +const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION = "moving_average_fraction"; +const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; +const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; +const std::string BATCHNORM_ATTR_SCALE = "scale"; +const std::string BATCHNORM_ATTR_BIAS = "bias"; +const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; +const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; +const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; + +// huberloss +const std::string HUBER_LOSS_ATTR_DELTA = "delta"; + +// SSDRealDivTileMul +const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; + +// SSDSumMulRealDivMean +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; + +// ConcatFive2Four +// ConcatFour2Five +const std::string SSD_BOX_TYPE_NUM = "box_type_num"; +const std::string SSD_CLASS_NUM = "class_num"; +const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; +const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; +const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; +const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; + +// Scale +const std::string SCALE_ATTR_SCALE = "scale"; +const std::string SCALE_ATTR_BIAS = "bias"; + +// FullConnection +const std::string FULL_CONNECTION_ATTR_FILTER = "filter"; +const std::string FULL_CONNECTION_ATTR_BIAS = "bias"; +const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT = "num_output"; +const std::string FULL_CONNECTION_ATTR_RELU_FLAG = "relu_flag"; +const std::string FULL_ATTR_NAME_ALGO = "algo"; + +// SoftmaxOpParams +const std::string SOFTMAX_ATTR_ALGO = "algo"; +const std::string SOFTMAX_ATTR_MODE = "mode"; + +// SparseSoftmaxCrossEntropy +const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE = "cross_entropy_mode"; +const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD = "cross_entropy_is_grad"; +// Attr labelSmoothing +const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING = "labelSmoothing"; + +// ApplyMomentum +const std::string APPLYMENTUM_ATTR_IS_GRAPH_FUSION = "applymomentum_is_graph_fusion"; + +// Activation +const std::string ACTIVATION_ATTR_MODE = "mode"; +const std::string ACTIVATION_ATTR_COEF = "coef"; + +// Concat +const std::string CONCAT_ATTR_NAME_AXIS = "axis"; + +// Const +const std::string CONST_ATTR_NAME_DATA_TRANSTYPE = "data_transtype"; +const std::string CONST_ATTR_NAME_OUTPUT_FORMAT = "output_format"; +const std::string CONST_ATTR_NAME_OUTPUT_TYPE = "output_type"; +const std::string CONST_ATTR_NAME_INPUT = "is_const"; + +// Roipooling +const std::string ROIPOOLING_ATTR_NAME_POOLED_H = "pooled_h"; +const std::string ROIPOOLING_ATTR_NAME_POOLED_W = "pooled_w"; +const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE = "spatial_scale"; +const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE = "rio_pooling_mode"; +const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE = "pooling_mode"; +const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO = "sampling_ratio"; + +// DetectionOutput +const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES = "num_classes"; +const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES = "ocr_num_classes"; +const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD = "nms_threshold"; +const std::string DETECTIONOUTPUT_ATTR_TOP_K = "top_k"; +const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD = "confidence_threshold"; +const std::string DETECTIONOUTPUT_ATTR_IMG_H = "img_h"; +const std::string DETECTIONOUTPUT_ATTR_IMG_W = "img_w"; +const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE = "batch_size"; +// Ssd DetectionOutput +const std::string DETECTIONOUTPUT_ATTR_ETA = "eta"; +const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION = "shared_location"; +const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID = "background_label_id"; +const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE = "code_type"; +const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET = "variance_encoded_in_target"; +const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K = "keep_top_k"; +// Refinedet DetectionOutput +const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE = "objectness_score"; +// yolo DetectionOutput +const std::string DETECTIONOUTPUT_ATTR_ClASSES = "classes"; +const std::string DETECTIONOUTPUT_ATTR_BIASES = "biases"; +const std::string DETECTIONOUTPUT_ATTR_RELATIVE = "relative"; +const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD = "objectness_threshold"; +const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD = "class_threshold"; +const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K = "post_top_k"; +const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY = "iou_threshold_decay"; +const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR = "coor_scale_factor"; +const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION = "yolo_version"; + +// DetectionPostprocess +const std::string POSTPROCESS_ATTR_NAME_CLS_NUM = "cls_num"; +const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH = "conf_thresh"; +const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH = "nms_thresh"; +const std::string POSTPROCESS_ATTR_POST_NMS_TOPN = "post_nms_topn"; +const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT = "bbox_reg_weights"; + +// Spatialtransfrom +const std::string SPTIALTF_ATTR_NAME_OUTPUT_H = "output_h"; +const std::string SPTIALTF_ATTR_NAME_OUTPUT_W = "output_w"; +const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE = "border_value"; +const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM = "affine_transform"; + +// Proposa +const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE = "feat_stride"; +const std::string PROPOSAL_ATTR_NAME_BASE_SIZE = "base_size"; +const std::string PROPOSAL_ATTR_NAME_MIN_SIZE = "min_size"; +const std::string PROPOSAL_ATTR_NAME_RATIO = "ratio"; +const std::string PROPOSAL_ATTR_NAME_SCALE = "scale"; +const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN = "pre_nms_topn"; +const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN = "post_nms_topn"; +const std::string PROPOSAL_ATTR_NAME_NMS_THRESH = "nms_thresh"; +const std::string PROPOSAL_ATTR_NAME_TOP_SIZE = "top_size"; +const std::string PROPOSAL_ATTR_IMG_H = "img_h"; +const std::string PROPOSAL_ATTR_IMG_W = "img_w"; +// Softmax +const std::string SOFTMAX_ATTR_AXIS = "axis"; + +// Permute +const std::string PERMUTE_ATTR_ORDER = "order"; +const std::string PERMUTE_ATTR_PERM = "perm"; + +// SSD Normalize +const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; +const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED = "channel_shared"; +const std::string SSDNORMALIZE_ATTR_EPS = "eps"; + +// Flatten +const std::string FLATTEN_ATTR_AXIS = "axis"; +const std::string FLATTEN_ATTR_END_AXIS = "end_axis"; + +// SsdPRIORBOX +const std::string SSD_PRIOR_BOX_ATTR_FLIP = "flip"; +const std::string SSD_PRIOR_BOX_ATTR_CLIP = "clip"; +const std::string SSD_PRIOR_BOX_ATTR_IMG_H = "img_h"; +const std::string SSD_PRIOR_BOX_ATTR_IMG_W = "img_w"; +const std::string SSD_PRIOR_BOX_ATTR_STEP_H = "step_h"; +const std::string SSD_PRIOR_BOX_ATTR_STEP_W = "step_w"; +const std::string SSD_PRIOR_BOX_ATTR_OFFSET = "offset"; +const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE = "min_size"; +const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE = "max_size"; +const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM = "min_size_num"; +const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM = "max_size_num"; +const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO = "aspect_ratio"; +const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; +const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; +const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; + +// RefinedetDetectionOutput +const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; +const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; + +// PRelu +const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; + +// Psroi pooling +const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE = "spatial_scale"; +const std::string PSROIPOOLING_ATTR_OUTPUT_DIM = "output_dim"; +const std::string PSROIPOOLING_ATTR_GROUP_SIZE = "group_size"; + +// Power +const std::string POWER_ATTR_NAME_POWER = "power"; +const std::string POWER_ATTR_NAME_SCALE = "scale"; +const std::string POWER_ATTR_NAME_SHIFT = "shift"; + +// log +const std::string LOG_ATTR_NAME_SCALE = "scale"; +const std::string LOG_ATTR_NAME_SHIFT = "shift"; +const std::string LOG_ATTR_NAME_BASE = "base"; +// Pack +const std::string PACK_ATTR_NAME_NUM = "N"; + +// Unpack +const std::string UNPACK_ATTR_NAME_NUM = "num"; +const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; +// Gathernd +const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; +const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; + +// Argmax +const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; +const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; +const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; +const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; +const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; +const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; +const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; + +// upsample +const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; +const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; + +// Relu +const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; + +// FreeSpaceExtract +const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT = "org_height"; + +// Split +const std::string SPLIT_ATTR_NAME_SLICE_POINT = "slice_point"; +const std::string SPLIT_ATTR_NAME_SIZE_SPLIT = "size_split"; +const std::string SPLIT_ATTR_NAME_NUM_SPLIT = "num_split"; + +// Tvm +const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; +const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; +const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; +const std::string TVM_ATTR_NAME_WORKSPACE_TYPE = "tvm_workspace_type"; + +// Squeeze +const std::string SQUEEZE_ATTR_AXIS = "axis"; +const std::string SQUEEZE_ATTR_DIMS = "squeeze_dims"; +const std::string SQUEEZE_OP_NAME = "Squeeze"; + +// Stride slice +const std::string STRIDE_SLICE_ATTR_BEGIN_MASK = "begin_mask"; +const std::string STRIDE_SLICE_ATTR_END_MASK = "end_mask"; +const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK = "ellipsis_mask"; +const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK = "new_axis_mask"; +const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK = "shrink_axis_mask"; + +// Slice +const std::string SLICE_ATTR_NAME_BEGINS = "begins"; +const std::string SLICE_ATTR_NAME_SIZES = "sizes"; + +// Roialign +const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; +const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; +const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; +const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; + +// Generate_rpn_proposal +const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; +const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK = "post_nms_topk"; +const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE = "rpn_mini_size"; +const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH = "rpn_proposal_nms_thresh"; +const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH = "rpn_proposal_filter_thresh"; +// Decode_bbox +const std::string DECODE_BBOX_ATTR_DECODECLIP = "decodeClip"; + +// Cast +const std::string CAST_ATTR_DSTT = "DstT"; +const std::string CAST_ATTR_SRCT = "SrcT"; +const std::string CAST_ATTR_DST_TYPE = "dst_type"; +const std::string CAST_ATTR_TRUNCATE = "truncate"; + +// Fastrcnnn predications +const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK = "fsr_topk"; +const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD = "fsr_score_thres"; +const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD = "fsr_nms_thres"; +const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES = "fsr_num_classes"; + +// REORG +const std::string REORG_ATTR_STRIDE = "stride"; +const std::string REORG_ATTR_REVERSE = "reverse"; + +// MERGE +const std::string MERGE_DEAD_INDEX = "merge_dead_index"; +const std::string MERGE_PRENODE_FLAG = "merge_prenode_flag"; +const std::string TO_BE_OUTPUT = "to_be_output"; + +// ENTER +const std::string ENTER_ATTR_FRAME_NAME = "frame_name"; +const std::string ENTER_ATTR_CONSTANT_FLAG = "is_constant"; + +// Concatv2 +const std::string CONCAT_V2_ATTR_TIDX = "Tidx"; +const std::string CONCAT_V2_ATTR_N = "N"; +// SUM +const std::string SUM_ATTR_TIDX = "Tidx"; +const std::string SUM_ATTR_AXIS = "axis"; +const std::string SUM_ATTR_KEEP_DIMS = "keep_dims"; + +// ResizeBilinear +const std::string RESIZE_BILINEAR_ATTR_MODE = "mode"; +const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS = "align_corners"; +const std::string RESIZE_BILINEAR_ATTR_HEIGHT = "height"; +const std::string RESIZE_BILINEAR_ATTR_WIDTH = "width"; +const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR = "zoom_factor"; +const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR = "shrink_factor"; +const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN = "pad_begin"; +const std::string RESIZE_BILINEAR_ATTR_PAD_END = "pad_end"; +const std::string RESIZE_BILINEAR_ATTR_ALPHA = "alpha"; +const std::string RESIZE_BILINEAR_ATTR_BETA = "beta"; + +// RetinaNet +const std::string RETINANET_FILTER_BACKGROUND_TRUE = "retina_conv_filter_background"; +const std::string RETINANET_ANCHOR_FUSION = "retina_anchor_fusion"; + +// MatMul +const std::string MATMUL_TRANSPOSE_X = "transposeX"; +const std::string MATMUL_TRANSPOSE_W = "transposeW"; +const std::string MATMUL_HAS_BIAS = "has_bias"; +const std::string MATMUL_ATTR_IS_TRAINING = "matmul_is_training"; + +// Flatten +const std::string FLATTEN_START_AXIS = "start_axis"; +const std::string FLATTEN_END_AXIS = "end_axis"; + +// Reshape +const std::string RESHAPE_ATTR_AXIS = "axis"; +const std::string RESHAPE_ATTR_NUM_AXES = "num_axes"; +const std::string RESHAPE_ATTR_FORMAT = "format"; +const std::string RESHAPE_ATTR_SHAPE = "shape"; +const std::string RESHAPE_ATTR_ALPHA = "alpha"; +const std::string RESHAPE_ATTR_BETA = "beta"; + +// Frameoworkop +const std::string T_IN_DATATYPE = "t_in_datatype"; +const std::string T_OUT_DATATYPE = "t_out_datatype"; +const std::string ATTR_NAME_OUT_N = "out_n"; +const std::string ATTR_NAME_OUT_C = "out_c"; +const std::string ATTR_NAME_OUT_H = "out_h"; +const std::string ATTR_NAME_OUT_W = "out_w"; +const std::string ATTR_PAD_DEPTH_CONV = "pad_depth_conv"; +const std::string ATTR_PAD_CONV = "pad_conv"; + +const std::string ATTR_NAME_BEFORE_PAD = "before_pad"; +const std::string ANN_MEAN_KEEPDIMS = "AnnMeanKeepDims"; +const std::string PAD_ATTR_PADDINGDS = "paddings"; +const std::string PAD_ATTR_CONSTANT_VALUE = "padvalue"; + +// ConvGradFilter +const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape"; +// ConvGradInput +const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; + +// Rnn +const std::string RNN_MODE_STATIC = "rnn_static"; +const std::string MUTI_RNN = "multi_rnn"; +const std::string CNN_RNN = "cnn_rnn"; +const std::string RNN_MODE_ = "rnn_"; + + +const std::string CELL_MODE = "mode"; +const std::string LSTM_CELL = "lstm_cell"; +const std::string GRU_CELL = "gru_cell"; +const std::string RNN_HT = "ht"; +const std::string RNN_XT_HT = "xt_ht"; +const std::string RNN_BATCH_SIZE = "batch_size"; +const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; +const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; +const std::string LSTM_ACTIVATE = "lstm_activate"; +const std::string LSTM_OUT_MAP = "lstm_out_map"; +const std::string LSTM_OUT_MODE = "lstm_out_mode"; +const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; +const std::string LSTM_TIME_MAJOR = "lstm_time_major"; +const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; + +// Upsample +const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; + +// PadV2 +const std::string PADV2_ATTR_NAME_MODE = "mode"; +const std::string PADV2_ATTR_NAME_PADS = "paddings"; +const std::string PADV2_ATTR_NAME_T = "T"; +const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; +const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; + +// MirrorPad +const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; +const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; +const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; +const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; +// Filler +const std::string FILLER_TYPE = "filler_type"; +const std::string FILLER_VALUE = "filler_value"; + +// Shufflechannel +const std::string SHUFFLE_CHANNEL_GROUP = "group"; + +// TopKV2 +const std::string TOPKV2_ATTR_K = "k"; + +// Calibaration +const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; +const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; +const std::string PAD_TOP_INDEX = "PAD_TOP_INDEX"; +const std::string PAD_BOTTOM_INDEX = "PAD_BOTTOM_INDEX"; +const std::string PAD_RIGHT_INDEX = "PAD_RIGHT_INDEX"; +const std::string PAD_LEFT_INDEX = "PAD_LEFT_INDEX"; +const std::string QUANTIZE_ALGO_ATTR = "quantize_algo"; +const std::string SCALE_TYPE_ATTR = "scale_type"; + +const std::string QUANTIZE_SCALE_MODE = "quantize_scale_mode"; +const std::string QUANTIZE_SCALE_VALUE = "quantize_scale_value"; +const std::string QUANTIZE_SCALE_OFFSET = "quantize_scale_offset"; +const std::string QUANTIZE_OFFSET_DATA_VALUE = "quantize_offset_data_value"; +const std::string QUANTIZE_OFFSET_DATA_OFFSET = "quantize_offset_data_offset"; +const std::string QUANTIZE_OFFSET_WEIGHT_VALUE = "quantize_offset_weight_value"; +const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET = "quantize_offset_weight_offset"; +const std::string QUANTIZE_OFFSET_PAD_VALUE = "quantize_offset_pad_value"; +const std::string QUANTIZE_OFFSET_PAD_OFFSET = "quantize_offset_pad_offset"; + +const std::string DEQUANTIZE_SCALE_MODE = "dequantize_scale_mode"; +const std::string DEQUANTIZE_SCALE_VALUE = "dequantize_scale_value"; +const std::string DEQUANTIZE_SCALE_OFFSET = "dequantize_scale_offset"; +const std::string DEQUANTIZE_OFFSET_DATA_TYPE = "dequantize_offset_data_value"; +const std::string DEQUANTIZE_OFFSET_DATA_OFFSET = "dequantize_offset_data_offset"; +const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE = "dequantize_offset_weight_value"; +const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET = "dequantize_offset_weight_offset"; +const std::string DEQUANTIZE_OFFSET_PAD_VALUE = "dequantize_offset_pad_value"; +const std::string DEQUANTIZE_OFFSET_PAD_OFFSET = "dequantize_offset_pad_offset"; + +const std::string REQUANTIZE_SCALE_MODE = "requantize_scale_mode"; +const std::string REQUANTIZE_SCALE_VALUE = "requantize_scale_value"; +const std::string REQUANTIZE_SCALE_OFFSET = "requantize_scale_offset"; +const std::string REQUANTIZE_OFFSET_DATA_VALUE = "requantize_offset_data_value"; +const std::string REQUANTIZE_OFFSET_DATA_OFFSET = "requantize_offset_data_offset"; +const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE = "requantize_offset_weight_value"; +const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET = "requantize_offset_weight_offset"; +const std::string REQUANTIZE_OFFSET_PAD_VALUE = "requantize_offset_pad_value"; +const std::string REQUANTIZE_OFFSET_PAD_OFFSET = "requantize_offset_pad_offset"; + +const std::string ATTR_NAME_IS_CONST = "attr_name_is_const"; + +const std::string ATTR_NAME_GROUP = "group"; +const std::string ATTR_NAME_DILATION_SIZE = "dilation_size"; +const std::string ATTR_NAME_EPSILON = "epsilon"; +const std::string ATTR_NAME_POOLING_MODE = "mode"; +const std::string ATTR_NAME_CLASS_NUM = "class_num"; +// model +const std::string ATTR_MODEL_TARGET_TYPE = "target_type"; + +const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; + +const std::string ATTR_MODEL_EVENT_NUM = "event_num"; + +const std::string ATTR_MODEL_HUGE_STREAM_LIST = "huge_stream_list"; + +const std::string ATTR_MODEL_LABEL_NUM = "label_num"; + +const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; + +const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; + +const std::string ATTR_MODEL_P2P_MEMORY_SIZE = "p2p_memory_size"; + +const std::string ATTR_MODEL_OUT_NODES_NAME = "attr_model_out_nodes_name"; + +const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; + +const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; + +const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; + +const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR = "task_gen_variable_addr"; + +const std::string ATTR_MODEL_VAR_SIZE = "variable_size"; + +const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; + +const std::string ATTR_MODEL_CORE_TYPE = "core_type"; + +const std::string ATTR_MODEL_ATC_VERSION = "atc_version"; + +const std::string ATTR_MODEL_OPP_VERSION = "opp_version"; + +// Public attribute +const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; + +const std::string ATTR_NAME_BYTE_SIZE = "op_byte_size"; + +const std::string ATTR_NAME_FUSION_INFERENCE_ID = "fusion_inference_id"; + +const std::string ATTR_NAME_FUSION_OPDEF = "fusion_opdef"; + +const std::string ATTR_NAME_IO_OP = "io_op"; + +const std::string ATTR_NAME_FUSION_SCOPE = "fusion_scope"; + +const std::string ATTR_NAME_OPATTR = "opattr"; + +const std::string ATTR_NAME_RELUFLAG = "relu_flag"; + +const std::string ATTR_NAME_SEQLEN_INDEX = "seqlen_index"; + +const std::string ATTR_NAME_X_INDEX = "x_index"; + +const std::string ATTR_NAME_CONT_INDEX = "cont_index"; + +const std::string ATTR_NAME_XSTATIC_INDEX = "xstatic_index"; + +const std::string TARGET_TYPE_MINI = "MINI"; + +const std::string TARGET_TYPE_TINY = "TINY"; + +const std::string TARGET_TYPE_LITE = "LITE"; + +// l2_normalize +const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; +const std::string L2_NORMALIZE_ATTR_EPS = "eps"; + +const std::string POOL_PARAMA_ATTR_WINDOW = "window"; +const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; +const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; +const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; +const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; +const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; + +// HCOM +const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; +const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; + +const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; +const std::string HCOM_ATTR_GROUP = "group"; +const std::string HCOM_ATTR_SR_TAG = "sr_tag"; +const std::string HCOM_ATTR_SRC_RANK = "src_rank"; +const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; +const std::string HCOM_ATTR_FUSION = "fusion"; +const std::string HCOM_ATTR_SHAPE = "shape"; +const std::string HCOM_ATTR_DATA_TYPE = "dtype"; + +// SpaceToDepth/DepthToSpace +const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; + +// SparseSoftmaxCrossEntropyWithLogits +const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; + +// MaxPoolGradWithArgmax +const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; + +// AvgPoolGrad +const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; + +// Pad +const std::string ATTR_PAD_FORMAT = "attr_pad_format"; + +// Varible +const std::string VAR_ATTR_FORMAT = "_var_format"; +const std::string VAR_ATTR_NAME = "var_name"; +const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; +const std::string VAR_ATTR_4D_FORMAT = "4D"; +const std::string VAR_ATTR_5D_FORMAT = "5D"; +const std::string VAR_ATTR_DATA_TYPE = "data_format"; +const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; +const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; +const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; +const std::string VAR_ATTR_SHAPE = "shape"; +const std::string HALF_VAR_NAME_END = "_fp16"; +const std::string VAR_ATTR_INITED = "var_is_inited"; + +const std::string VAR_ATTR_CONTAINER = "container"; +const std::string VAR_ATTR_SHARED_NAME = "shared_name"; +const std::string VAR_ATTR_DTYPE = "dtype"; + +const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; +const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; +const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; +const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; +const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; +const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; + +// Assign +const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; +const std::string ASSIGN_VAR_NAME = "_assign_var_name"; + +// space2bacth batch2space +const std::string BATCH_SPACE_ATTR_BLOCK = "block"; +const std::string BATCH_SPACE_ATTR_PADDING = "padding"; + +// depth_to_space space_to_depth +const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; + +// FakeQuantWithMinMaxVars +const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; +const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; + +// mobilenet_ssd_conv_fusion +const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; +const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; +const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; + +// lsh project +const std::string LSH_PROJ_TYPE = "lsh_project_type"; + +// log time stamp +const std::string LOG_TIME_STAMP_LOGID = "logid"; +const std::string LOG_TIME_STAMP_NOTIFY = "notify"; + +// ShapeN +const std::string SHAPEN_ATTR_N = "N"; +const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; +const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; + +// GatherV2 attr def +const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; +const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; +const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; + +// Reshape attr def +const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; +const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; + +// axis attr def +const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; + +const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; + +const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; +const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; + +// For constant folding +const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; + +const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; + +const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC = "continuous_input_alloc"; + +const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; + +const std::string ATTR_NAME_REFERENCE = "reference"; + +const std::string ATTR_NAME_NOTASK = "_no_task"; + +const std::string ATTR_NAME_OUTPUT_REUSE_INPUT = "_output_reuse_input"; + +const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX = "_reuse_input_on_dim_index"; + +const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT = "_no_padding_continuous_input"; + +const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT = "_no_padding_continuous_output"; + +const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; + +// Used for mark the active label list stream of activated node +const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; + +// Used for l2cache, true: the memory of all inputs is used for the last time. +const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_lifecycle"; + +// Multi batch +const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; +const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; +const std::string ATTR_NAME_BATCH_LABEL = "_batch_label"; +const std::string ATTR_NAME_COMBINED_BATCH = "_combined_batch"; + +// Control flow +const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; +const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; +const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; +const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; +const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; +const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; +const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE = "subgraph_first_active"; +const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS = "combined_dynamic_dims"; + +const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; +const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; +const std::string ATTR_NAME_SWITCH_DATA_TYPE = "_switch_data_type"; +const std::string ATTR_NAME_ORIG_NODE_NAME = "_original_node_name"; +const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG = "_cyclic_dependence_flag"; + +const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; + +// Function Op +const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; + +// Used for mark the active node is for loop, type:bool +const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; + +const std::string ATTR_NAME_MEMORY_TYPE_INPUT = "memory_type_input"; + +const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT = "memory_type_output"; + +const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; + +const std::string ATTR_NAME_MEMORY_TYPE_RANGE = "_memory_type_range"; + +const std::string MODEL_ATTR_SESSION_ID = "session_id"; + +// lx fusion +const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; +const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; +const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; +const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; +const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; +const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; +const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; +const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; +const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; +const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; +const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; +const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; +const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; +const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; +const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; +const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; +const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; +const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; +const std::string ATTR_NAME_ENGINE_NAME_FOR_LX = "_lxfusion_engine_name"; +const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX = "_lxfusion_op_kernel_lib_name"; +const std::string ATTR_NAME_NEED_LX_FUSION = "_lx_fusion"; +const std::string ATTR_NAME_OPTIMIZE_GROUP = "_optimize_group"; +const std::string ATTR_NAME_OP_COMPILE_STRATEGY = "_op_compile_strategy"; +const std::string ATTR_NAME_TBE_KERNEL_NAME = "_tbe_kernel_name"; +const std::string ATTR_NAME_TBE_KERNEL_BUFFER = "_tbe_kernel_buffer"; +const std::string ATTR_NAME_DATA_SLICE = "_data_slice"; +const std::string ATTR_NAME_NEED_RECOVER_ATTR = "_need_recover_attr"; + +// used for memory allocate +const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; +const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; +const std::string ATTR_NAME_WORKSPACE_TYPE_LIST = "_workspace_type"; +const std::string ATTR_NAME_TENSOR_MEM_TYPE = "_tensor_memory_type"; + +// Op debug attrs +const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; +const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; + +// Atomic addr clean attrs +const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; +const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; +const std::string ATOMIC_ATTR_IS_FUSION_NODE = "is_fusion_node"; +const std::string EXT_ATTR_ATOMIC_WORKSPACE_INFO = "sub_node_workspace_info"; +const std::string EXT_ATTR_ATOMIC_WORKSPACE_OFFSET = "sub_node_workspace_offset"; +const std::string ATOMIC_ATTR_IS_ATOMIC_NODE = "is_atomic_node"; + +// Source/dst format for Op FormatTransfer +const std::string FORMAT_TRANSFER_SRC_FORMAT = "src_format"; +const std::string FORMAT_TRANSFER_DST_FORMAT = "dst_format"; + +// For compile op by ge call +const std::string ATTR_NEED_COMPILE = "_node_need_compile"; + +const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; + +const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; + +const std::string ATTR_DYNAMIC_TYPE = "mbatch_dynamic_type"; + +const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER = "user_designate_shape_order"; + +// For inserted op +const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; + +// For compress weight +const std::string ATTR_NAME_COMPRESS_WEIGHT = "_is_compress_weight"; + +// For data dump +const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; +const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; +const std::string ATTR_NAME_DATA_DUMP_SUB_SPLITER_INDEX = "_datadump_sub_spliter_index"; +const std::string ATTR_NAME_DATA_DUMP_GROUP_OP_NAME = "_datadump_group_op_name"; +const std::string ATTR_NAME_DATA_DUMP_ORIGIN_NAME = "_datadump_origin_name"; +const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_output_index"; +const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; +const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; + +// functional ops attr +const std::string ATTR_NAME_IF_THEN_BRANCH = "then_branch"; +const std::string ATTR_NAME_IF_ELSE_BRANCH = "else_branch"; +const std::string ATTR_NAME_WHILE_COND = "cond"; +const std::string ATTR_NAME_WHILE_BODY = "body"; + +// used for label switch +const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; +const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; +const std::string ATTR_NAME_SUBGRAPH_END_NODE = "_subgraph_end_node"; + +const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; +const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; + +// used for LX tiling +const std::string ATTR_NAME_OP_L1_SPACE = "_l1_space"; +const std::string ATTR_NAME_FUSION_TYPE_LIST = "_fusion_type_list"; +const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_list_list"; +const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; +const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; +const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_output_offset_list_list"; + +// for unregistered op +const std::string ATTR_NAME_UNREGST_OPPATH = "_unregst_oppath"; +const std::string ATTR_NAME_UNREGST_ATTRLIST = "_unregst_attrlist"; + +// used for Horovod +const std::string ATTR_INTER_EVENT_IDENTIFY = "event_id"; +const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; +// used for allreduce tailing optimization +const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; +const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; + +// dynamic shape attr +const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr"; +const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index"; +const std::string ATTR_DYNAMIC_SHAPE_SINGLE_AICPU = "_single_aicpu_dynamic"; + +// op dynamic input +const std::string ATTR_NAME_DYNAMIC_INPUT_START = "_dynamic_input_index_start"; +const std::string ATTR_NAME_DYNAMIC_INPUT_END = "_dynamic_input_index_end"; + +// atc user def dtype&format +const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type"; +const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format"; + +// atc user def dtype&format +const std::string ATTR_ATC_USER_DEFINE_OUTPUT_NODES = "_user_defined_output_nodes"; + +// for fusion op plugin +const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; + +// graph partition for aicpu +const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME = "pld_front_node_engine_name"; +const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME = "end_rear_node_engine_name"; + +// input and output memory type +const std::string ATTR_VARIABLE_PLACEMENT = "_variable_placement"; +const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type"; +const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type"; +const std::string ATTR_NAME_SPECIAL_OUTPUT_SIZE = "_special_output_size"; + +// stage +const std::string ATTR_STAGE_LEVEL = "_stage_level"; + +// input_output_offset +const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset"; +const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset"; + +// The processing mode of INF and NAN during floating-point number calculation. +const std::string ATTR_FP_CEILING_MODE = "_fp_ceiling_mode"; +// count of data from getnext_sink +const std::string ATTR_GETNEXT_SINK_DATA_COUNT = "N"; +const std::string ATTR_GETNEXT_SINK_SHAPE_INFO = "shape_info"; + +// getnext_sink marked on NetOutput +const std::string ATTR_GETNEXT_SINK_DYNMAIC = "getnext_sink_dynamic"; +const std::string ATTR_ALL_GEARS_INFO = "all_gears_info"; + +// Calculate the operator output memory +const std::string ATTR_NAME_MEMORY_SIZE_CALC_TYPE = "_memory_size_calc_type"; +} // namespace ge diff --git a/metadef/graph/ge_attr_value.cc b/metadef/graph/ge_attr_value.cc new file mode 100644 index 00000000..ff9ffacc --- /dev/null +++ b/metadef/graph/ge_attr_value.cc @@ -0,0 +1,1289 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/ge_attr_value.h" +#include "graph/ge_tensor.h" +#include "external/graph/graph.h" +#include "utils/attr_utils.h" +#include "framework/common/debug/ge_log.h" +#include "graph/model_serialize.h" +#include "proto/ge_ir.pb.h" +#include "detail/model_serialize_imp.h" +#include "debug/ge_attr_define.h" +#include "debug/ge_log.h" +#include "debug/ge_util.h" + +using std::map; +using std::string; +using std::vector; + +namespace ge { +NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } + +NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) + : named_attrs_(owner, proto_msg) {} // lint !e1744 + +void NamedAttrs::SetName(const std::string &name) { + auto proto_msg = named_attrs_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_name(name); + } +} + +string NamedAttrs::GetName() const { + auto proto_msg = named_attrs_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->name(); + } + return string(); +} + +GeAttrValue NamedAttrs::GetItem(const string &key) const { + GeAttrValue value; + (void)GetAttr(key, value); + return value; +} + +ProtoAttrMapHelper NamedAttrs::MutableAttrMap() { + auto proto_msg = named_attrs_.GetProtoMsg(); + if (proto_msg != nullptr) { + return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr()); + } + return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); +} + +ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const { + auto proto_msg = named_attrs_.GetProtoMsg(); + if (proto_msg != nullptr) { + return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr()); + } + return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); +} + +class GeAttrValueImp { + public: + static map attr_val_one_type_map_; + static map attr_val_list_type_map_; + + static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::INT val); + static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::FLOAT val); + static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::BOOL val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::STR &val); + static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val); + static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::TENSOR_DESC &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::BYTES &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::NAMED_ATTRS &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::GRAPH &val); + static bool SetValue(proto::AttrDef &attr_def, const vector &val); + static bool SetValue(proto::AttrDef &attr_def, const vector &val); + static bool SetValue(proto::AttrDef &attr_def, const vector &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_FLOAT &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BOOL &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_STR &val); + static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); + static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); + static bool SetValue(proto::AttrDef &attr_def, const vector &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_TENSOR_DESC &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BYTES &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_NAMED_ATTRS &val); + static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_GRAPH &val); + + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::INT &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::FLOAT &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BOOL &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::STR &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::TENSOR &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::TENSOR_DESC &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BYTES &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::NAMED_ATTRS &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::GRAPH &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_INT &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_FLOAT &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_BOOL &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_STR &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_TENSOR &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_TENSOR_DESC &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_BYTES &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_NAMED_ATTRS &val); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + GeAttrValue::LIST_GRAPH &val); + // Value will be moved + static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer); + static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer); + // Value will be moved + static bool SetZeroCopyListBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + vector &list_buffer); + static bool GetZeroCopyListBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + vector &list_buffer); + + static bool SetValue(proto::AttrDef &attr_def, const vector> &value); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + vector> &value); + static bool SetValue(proto::AttrDef &attr_def, const vector &value); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, + vector &value); + + static bool SetValue(proto::AttrDef &attr_def, const ge::DataType &value); + static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ge::DataType &value); +}; + +map GeAttrValueImp::attr_val_one_type_map_ = { + {proto::AttrDef::kI, GeAttrValue::VT_INT}, + {proto::AttrDef::kF, GeAttrValue::VT_FLOAT}, + {proto::AttrDef::kB, GeAttrValue::VT_BOOL}, + {proto::AttrDef::kS, GeAttrValue::VT_STRING}, + {proto::AttrDef::kT, GeAttrValue::VT_TENSOR}, + {proto::AttrDef::kTd, GeAttrValue::VT_TENSOR_DESC}, + {proto::AttrDef::kG, GeAttrValue::VT_GRAPH}, + {proto::AttrDef::kBt, GeAttrValue::VT_BYTES}, + {proto::AttrDef::kFunc, GeAttrValue::VT_NAMED_ATTRS}, + {proto::AttrDef::kListListInt, GeAttrValue::VT_LIST_LIST_INT}, + {proto::AttrDef::kDt, GeAttrValue::VT_DATA_TYPE}, +}; +map GeAttrValueImp::attr_val_list_type_map_ = { + {proto::AttrDef_ListValue_ListValueType_VT_LIST_INT, GeAttrValue::VT_LIST_INT}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT, GeAttrValue::VT_LIST_FLOAT}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_BOOL, GeAttrValue::VT_LIST_BOOL}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_STRING, GeAttrValue::VT_LIST_STRING}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, GeAttrValue::VT_LIST_TENSOR}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, GeAttrValue::VT_LIST_TENSOR_DESC}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, GeAttrValue::VT_LIST_GRAPH}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, GeAttrValue::VT_LIST_BYTES}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, GeAttrValue::VT_LIST_NAMED_ATTRS}, + {proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, GeAttrValue::VT_LIST_DATA_TYPE}, +}; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue::GeAttrValue() { value_.InitDefault(); } + +GeAttrValue::GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val) : value_(proto_owner, val) {} + +GeAttrValue::ValueType GeAttrValue::GetValueType() const { + auto proto_msg = value_.GetProtoMsg(); + if (proto_msg != nullptr) { + auto val_case = proto_msg->value_case(); + if (val_case != proto::AttrDef::kList) { + auto it = GeAttrValueImp::attr_val_one_type_map_.find(val_case); + if (it != GeAttrValueImp::attr_val_one_type_map_.end()) { + return it->second; + } + } else { + auto it = GeAttrValueImp::attr_val_list_type_map_.find(proto_msg->list().val_type()); + if (it != GeAttrValueImp::attr_val_list_type_map_.end()) { + return it->second; + } + } + } + return GeAttrValue::VT_NONE; +} + +bool GeAttrValue::IsEmpty() const { return GetValueType() == VT_NONE; } + +GeAttrValue GeAttrValue::Copy() const { + GeAttrValue valueRet; + auto proto_msg = value_.GetProtoMsg(); + auto proto_msg_ret = valueRet.value_.GetProtoMsg(); + if (proto_msg != nullptr && proto_msg_ret != nullptr) { + *proto_msg_ret = *proto_msg; + } + return valueRet; +} + +#define ATTR_VALUE_SET_GET_IMP(type) \ + graphStatus GeAttrValue::SetValue(const type &val) { \ + auto proto_msg = value_.GetProtoMsg(); \ + if (proto_msg) { \ + if (GeAttrValueImp::SetValue(*proto_msg, val)) { \ + return GRAPH_SUCCESS; \ + } \ + } \ + return GRAPH_FAILED; \ + } \ + \ + graphStatus GeAttrValue::GetValue(type &val) const { \ + auto proto_msg = value_.GetProtoMsg(); \ + if (proto_msg) { \ + if (GeAttrValueImp::GetValue(*proto_msg, value_.GetProtoOwner(), val)) { \ + return GRAPH_SUCCESS; \ + } \ + } \ + return GRAPH_FAILED; \ + } + +ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524 +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR_DESC) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::GRAPH) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES) +ATTR_VALUE_SET_GET_IMP(vector) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) +ATTR_VALUE_SET_GET_IMP(vector) +/*lint -e665*/ +ATTR_VALUE_SET_GET_IMP(vector>) +/*lint +e665*/ +ATTR_VALUE_SET_GET_IMP(vector) // lint !e665 +ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665 + +#undef ATTR_VALUE_SET_GET_IMP + +graphStatus GeAttrValue::MutableTensor(GeTensorPtr &tensor) { return GetValue(tensor); } + +graphStatus GeAttrValue::MutableListTensor(vector &list_tensor) { return GetValue(list_tensor); } + +class AttrUtilsHelper { + public: + inline static bool GetValueCheckType(const proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) { + if (attr_def.value_case() != proto_case) { + GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case); + return false; + } + return true; + } + + inline static bool GetValueCheckListType( + const proto::AttrDef &attr_def, proto::AttrDef_ListValue_ListValueType proto_list_case, + const std::function item_check_fun) { + if (attr_def.value_case() != proto::AttrDef::kList) { + GELOGW("Check ListType Failed, value_case %u", attr_def.value_case()); + return false; + } + auto &list = attr_def.list(); + if (list.val_type() == proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE) { + return item_check_fun(attr_def); + } + if (list.val_type() != proto_list_case) { + GELOGW("Check ListType Failed, val_type %u, expected %u", list.val_type(), proto_list_case); + return false; + } + return true; + } + + inline static bool SetValueCheckType(proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) { + if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto_case) { + GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case); + return false; + } + return true; + } + + inline static bool SetValueCheckAndSetListType(proto::AttrDef &attr_def, + proto::AttrDef_ListValue_ListValueType proto_list_case) { + if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto::AttrDef::kList) { + GELOGW("AttrUtils::Check Type Failed, value_case %u", attr_def.value_case()); + return false; + } + auto list = attr_def.mutable_list(); + if (list == nullptr) { + GELOGE(GRAPH_FAILED, "list is nullptr"); + return false; + } + if (list->val_type() != proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE && + list->val_type() != proto_list_case) { + GELOGW("AttrUtils::Check ListType Type Failed, val_type %d, expected %d", static_cast(list->val_type()), + static_cast(proto_list_case)); + return false; + } + list->set_val_type(proto_list_case); + return true; + } + + static bool GetAttrMapItem(const AttrHolder *obj, const string &name, const proto::AttrDef *&attr_def) { + if (obj == nullptr) { + GELOGE(FAILED, "%s obj is nullptr", name.c_str()); + return false; + } + auto attr_map = obj->GetAttrMap().GetProtoMsg(); + if (attr_map == nullptr) { + GELOGE(FAILED, "%s attr map is nullptr", name.c_str()); + return false; + } + auto it = attr_map->find(name); + if (it == attr_map->end()) { + return false; + } + attr_def = &it->second; + return true; + } + + inline static bool MutableAttrMapItem(AttrHolder *obj, const string &name, proto::AttrDef *&attr_def) { + if (obj == nullptr) { + GELOGE(FAILED, " %s obj is nullptr", name.c_str()); + return false; + } + auto attr_map = obj->MutableAttrMap().GetProtoMsg(); + if (attr_map == nullptr) { + GELOGE(FAILED, "%s attr map is nullptr", name.c_str()); + return false; + } + // Get or add + attr_def = &((*attr_map)[name]); + return true; + } +}; + +#define ATTR_VALUE_IMP_SET_ONE(ValType, proto_case, protoItem) \ + bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \ + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \ + return false; \ + } \ + proto_attr_val.set_##protoItem(value); \ + return true; \ + } + +#define ATTR_VALUE_IMP_SET_LIST(ValType, proto_list_case, protoItem) \ + bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \ + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, \ + proto::AttrDef_ListValue_ListValueType_##proto_list_case)) { \ + return false; \ + } \ + auto list = proto_attr_val.mutable_list(); \ + list->clear_##protoItem(); \ + for (const auto &item : value) { \ + list->add_##protoItem(item); \ + } \ + return true; \ + } + +ATTR_VALUE_IMP_SET_ONE(int64_t, kI, i) +ATTR_VALUE_IMP_SET_ONE(float, kF, f) +ATTR_VALUE_IMP_SET_ONE(const string &, kS, s) +ATTR_VALUE_IMP_SET_ONE(bool, kB, b) + +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_FLOAT, f) +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_STRING, s) +ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_BOOL, b) + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensorDesc &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) { + return false; + } + auto proto_msg = value.tensor_descriptor_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + *proto_attr_val.mutable_td() = *proto_msg; + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_td(); + for (const auto &item : value) { + auto proto_msg = item.tensor_descriptor_.GetProtoMsg(); + if (proto_msg == nullptr) { + proto_attr_val.clear_list(); + return false; + } + *list->add_td() = *proto_msg; + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ConstGeTensorPtr &value) { + if (value) { + return SetValue(proto_attr_val, *value); + } else { + return SetValue(proto_attr_val, GeTensor()); + } +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensor &val) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { + return false; + } + auto proto_msg = val.tensor_def_.GetProtoMsg(); + if (proto_msg == nullptr) { + GELOGE(FAILED, "Proto msg is nullptr"); + return false; + } + *proto_attr_val.mutable_t() = *proto_msg; + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + vector constList(value.size()); + std::copy(value.begin(), value.end(), constList.begin()); + return SetValue(proto_attr_val, constList); +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_t(); + for (const auto &item : value) { + if (item == nullptr) { + GELOGE(GRAPH_FAILED, "AttrUtils::SetListTensor item is nullptr"); + proto_attr_val.clear_list(); + return false; + } + auto proto_msg = item->tensor_def_.GetProtoMsg(); + if (proto_msg == nullptr) { + GELOGE(FAILED, "Proto msg is nullptr"); + proto_attr_val.clear_list(); + return false; + } + *list->add_t() = *proto_msg; + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_t(); + for (const auto &item : value) { + auto proto_msg = item.tensor_def_.GetProtoMsg(); + if (proto_msg == nullptr) { + GELOGE(FAILED, "Proto msg is nullptr"); + proto_attr_val.clear_list(); + return false; + } + *list->add_t() = *proto_msg; + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::BYTES &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { + return false; + } + size_t val_size = value.GetSize(); + proto_attr_val.set_bt(value.GetData(), val_size); + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_bt(); + for (const auto &item : value) { + list->add_bt(item.GetData(), item.GetSize()); + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { + return false; + } + auto proto_msg = value.named_attrs_.GetProtoMsg(); + if (proto_msg == nullptr) { + GELOGE(FAILED, "Proto msg is nullptr"); + return false; + } + *proto_attr_val.mutable_func() = *proto_msg; + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_na(); + for (const auto &item : value) { + auto proto_msg = item.named_attrs_.GetProtoMsg(); + if (proto_msg == nullptr) { + proto_attr_val.clear_list(); + return false; + } + *list->add_na() = *proto_msg; + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::ComputeGraphPtr &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kG)) { + return false; + } + ModelSerializeImp imp; + if (!imp.SerializeGraph(value, proto_attr_val.mutable_g())) { + GELOGE(GRAPH_FAILED, "AttrUtils::SetGraph SerializeGraph Failed"); + proto_attr_val.clear_g(); + return false; + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_g(); + + ModelSerializeImp imp; + for (const auto &item : value) { + if (!imp.SerializeGraph(item, list->add_g())) { + GELOGE(GRAPH_FAILED, "AttrUtils::SetListGraph SerializeGraph"); + proto_attr_val.clear_list(); + return false; + } + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector> &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) { + return false; + } + proto_attr_val.clear_list_list_int(); + auto list_list_int = proto_attr_val.mutable_list_list_int(); + GE_CHECK_NOTNULL_EXEC(list_list_int, return false); + for (auto &list_int : value) { + auto list_item = list_list_int->add_list_list_i(); + GE_CHECK_NOTNULL_EXEC(list_item, return false); + for (auto &int_item : list_int) { + list_item->add_list_i(int_item); + } + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_dt(); + for (const auto &item : value) { + list->add_dt(static_cast(item)); + } + return true; +} + +bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::DataType &value) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) { + return false; + } + proto_attr_val.set_dt(static_cast(value)); + + return true; +} + +#define ATTR_VALUE_IMP_GET_ONE(ValType, proto_case, protoItem) \ + bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ValType value) { \ + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \ + return false; \ + } \ + value = proto_attr_val.protoItem(); \ + return true; \ + } + +#define ListValueItemCheck(protoItem) \ + [](const proto::AttrDef &proto_attr_val) { return proto_attr_val.list().protoItem##_size() > 0; } + +#define ATTR_VALUE_IMP_GET_LIST(ValType, proto_list_case, protoItem) \ + bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, vector &value) { \ + value.clear(); \ + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, \ + proto::AttrDef_ListValue_ListValueType_##proto_list_case, \ + ListValueItemCheck(protoItem))) { \ + return false; \ + } \ + auto &list = proto_attr_val.list(); \ + for (const auto &item : list.protoItem()) { \ + value.push_back(item); \ + } \ + return true; \ + } + +ATTR_VALUE_IMP_GET_ONE(int64_t &, kI, i) +ATTR_VALUE_IMP_GET_ONE(float &, kF, f) +ATTR_VALUE_IMP_GET_ONE(string &, kS, s) +ATTR_VALUE_IMP_GET_ONE(bool &, kB, b) + +ATTR_VALUE_IMP_GET_LIST(int64_t, VT_LIST_INT, i) +ATTR_VALUE_IMP_GET_LIST(float, VT_LIST_FLOAT, f) +ATTR_VALUE_IMP_GET_LIST(string, VT_LIST_STRING, s) +ATTR_VALUE_IMP_GET_LIST(bool, VT_LIST_BOOL, b) + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeTensorDesc &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) { + return false; + } + auto proto_msg = value.tensor_descriptor_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + *proto_msg = proto_attr_val.td(); + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &value) { + if (!AttrUtilsHelper::GetValueCheckListType( + proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, ListValueItemCheck(td))) { + return false; + } + auto &list = proto_attr_val.list(); + for (const auto &item : list.td()) { + value.emplace_back(GeTensorDesc()); + auto proto_msg = value.back().tensor_descriptor_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + *proto_msg = item; + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, + GeTensorPtr &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { + return false; + } + value = std::shared_ptr( + new (std::nothrow) GeTensor(proto_owner, const_cast(proto_attr_val).mutable_t())); + GE_CHK_BOOL_RET_STATUS(value != nullptr, false, "value is nullptr"); + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, + vector &value) { + value.clear(); + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, + ListValueItemCheck(t))) { + return false; + } + auto list = const_cast(proto_attr_val).mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + for (auto &item : *(list->mutable_t())) { + std::shared_ptr temp_value = std::shared_ptr(new (std::nothrow) GeTensor(proto_owner, &item)); + GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); + value.push_back(temp_value); + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeAttrValue::BYTES &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { + return false; + } + auto &proto_val = proto_attr_val.bt(); + GE_LOGI_IF(proto_val.size() == 0, "size res is 0."); + value = Buffer::CopyFrom(reinterpret_cast(proto_val.data()), proto_val.size()); + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &value) { + value.clear(); + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, + ListValueItemCheck(bt))) { + return false; + } + auto &list = proto_attr_val.list(); + for (const auto &item : list.bt()) { + value.push_back(Buffer::CopyFrom((const uint8_t *)item.data(), item.size())); + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + GeAttrValue::NAMED_ATTRS &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { + return false; + } + auto proto_msg = value.named_attrs_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + *proto_msg = proto_attr_val.func(); + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &value) { + value.clear(); + if (!AttrUtilsHelper::GetValueCheckListType( + proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { + return false; + } + auto &list = proto_attr_val.list(); + for (const auto &item : list.na()) { + value.emplace_back(GeAttrValue::NAMED_ATTRS()); + if (value.empty()) { + return false; + } + auto proto_msg = value.back().named_attrs_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + *proto_msg = item; + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ComputeGraphPtr &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kG)) { + return false; + } + ComputeGraphPtr graph = nullptr; + std::shared_ptr graph_def; + graph_def = ComGraphMakeShared(proto_attr_val.g()); + if (graph_def == nullptr) { + GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); + graph_def = nullptr; + return false; // lint !e665 + } else { + ModelSerializeImp imp; + imp.SetProtobufOwner(graph_def); + if (!imp.UnserializeGraph(graph, *graph_def)) { + GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); + return false; + } // lint !e514 + value = graph; + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &value) { + value.clear(); + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, + ListValueItemCheck(g))) { + return false; + } + auto &list = proto_attr_val.list(); + for (const auto &item : list.g()) { + std::shared_ptr graph_def; + graph_def = ComGraphMakeShared(item); + if (graph_def == nullptr) { + GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); + graph_def = nullptr; + return false; // lint !e665 + } else { + ComputeGraphPtr graph = nullptr; + ModelSerializeImp imp; + imp.SetProtobufOwner(graph_def); + if (!imp.UnserializeGraph(graph, *graph_def)) { + GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); + return false; + } // lint !e514 + value.push_back(graph); + } + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector> &value) { + value.clear(); + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) { + return false; + } + + auto &list_listint = proto_attr_val.list_list_int().list_list_i(); + for (auto &list_int : list_listint) { + vector list_item(list_int.list_i().size()); + if (!list_int.list_i().empty()) { + (void)std::copy(list_int.list_i().begin(), list_int.list_i().end(), list_item.begin()); + } + value.push_back(list_item); + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &value) { + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, + ListValueItemCheck(dt))) { + return false; + } + auto &list = proto_attr_val.list(); + for (const auto &item : list.dt()) { + value.emplace_back(static_cast(item)); + } + return true; +} + +bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ge::DataType &value) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) { + return false; + } + value = static_cast(proto_attr_val.dt()); + return true; +} + +GE_FUNC_HOST_VISIBILITY bool GeAttrValueImp::SetZeroCopyBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + Buffer &&buffer) { + if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { + return false; + } + auto proto_msg = buffer.data_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + proto_attr_val.set_bt(std::move(*proto_msg->mutable_bt())); + return true; +} + +bool GeAttrValueImp::GetZeroCopyBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, + Buffer &buffer) { + if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { + return false; + } + buffer = Buffer(proto_owner, &const_cast(proto_attr_val)); + return true; +} + +bool GeAttrValueImp::SetZeroCopyListBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, + vector &list_buffer) { + if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, + proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { + return false; + } + auto list = proto_attr_val.mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + list->clear_bt(); + for (auto &item : list_buffer) { + auto proto_msg = item.data_.GetProtoMsg(); + if (proto_msg == nullptr) { + return false; + } + list->add_bt(std::move(*proto_msg->mutable_bt())); + } + return true; +} + +bool GeAttrValueImp::GetZeroCopyListBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, + vector &list_buffer) { + list_buffer.clear(); + if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, + ListValueItemCheck(bt))) { + return false; + } + auto list = const_cast(proto_attr_val).mutable_list(); + GE_CHECK_NOTNULL_EXEC(list, return false); + for (auto &item : *(list->mutable_bt())) { + list_buffer.emplace_back(Buffer(proto_owner, &item)); + } + return true; +} + +bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const string &name) { + if (!obj) { + return false; + } + return obj->HasAttr(name); +} + +#define ATTR_UTILS_SET_IMP(FuncName, Type) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Set##FuncName( \ + AttrHolderAdapter &&obj, const string &name, const Type &value) { \ + proto::AttrDef *proto_attr_val = nullptr; \ + if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ + return false; \ + } \ + if (!GeAttrValueImp::SetValue(*proto_attr_val, value)) { \ + GELOGW("Set" #FuncName " failed key %s", name.c_str()); \ + return false; \ + } \ + return true; \ + } + +#define ATTR_UTILS_GET_IMP(FuncName, Type) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Get##FuncName(ConstAttrHolderAdapter &&obj, \ + const string &name, Type &value) { \ + const proto::AttrDef *proto_attr_val = nullptr; \ + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ + return false; \ + } \ + if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value)) { \ + GELOGW("Get" #FuncName " failed key %s", name.c_str()); \ + return false; \ + } \ + return true; \ + } + +#define ATTR_UTILS_SET_GET_IMP(FuncName, Type) \ + ATTR_UTILS_SET_IMP(FuncName, Type) \ + ATTR_UTILS_GET_IMP(FuncName, Type) + +ATTR_UTILS_SET_GET_IMP(Int, int64_t) +ATTR_UTILS_SET_GET_IMP(Float, float) +ATTR_UTILS_SET_GET_IMP(Bool, bool) +ATTR_UTILS_SET_GET_IMP(Str, string) +ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc) +ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) +ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) +ATTR_UTILS_SET_IMP(Tensor, GeTensor) +ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) +ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) +ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) +/*lint -e665*/ +ATTR_UTILS_SET_GET_IMP(ListListInt, vector>) +/*lint +e665*/ +ATTR_UTILS_SET_GET_IMP(ListInt, vector) +ATTR_UTILS_SET_IMP(ListInt, vector) +ATTR_UTILS_SET_IMP(ListInt, vector) +ATTR_UTILS_SET_GET_IMP(ListFloat, vector) +ATTR_UTILS_SET_GET_IMP(ListBool, vector) +ATTR_UTILS_SET_GET_IMP(ListStr, vector) +ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector) +ATTR_UTILS_SET_IMP(ListTensor, vector) +ATTR_UTILS_SET_IMP(ListTensor, vector) +ATTR_UTILS_SET_IMP(ListTensor, vector) +ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector) +ATTR_UTILS_SET_GET_IMP(ListBytes, vector) +ATTR_UTILS_SET_GET_IMP(ListGraph, vector) +ATTR_UTILS_SET_GET_IMP(ListDataType, vector) // lint !e665 +ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665 + +bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, + std::initializer_list &&value) { + return SetListTensor(std::move(obj), name, vector(value)); +} + +bool AttrUtils::GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value) { + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + GeTensorPtr tensor; + if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) { + return false; + } + value = tensor; + return true; +} + +bool AttrUtils::GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector &value) { + value.clear(); + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + vector tensor; + if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) { + return false; + } + value.insert(value.begin(), tensor.begin(), tensor.end()); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::MutableTensor(AttrHolderAdapter &&obj, + const string &name, GeTensorPtr &value) { + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value); +} + +bool AttrUtils::MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector &value) { + value.clear(); + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value); +} + +bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value) { + proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::SetValue(*proto_attr_val, value); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, + int32_t &value) { + int64_t int64_val = 0; + if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { + return false; + } + if (int64_val > INT32_MAX) { + GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to int32_t", int64_val); + return false; + } + value = static_cast(int64_val); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, + uint32_t &value) { + int64_t int64_val = 0; + if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { + return false; + } + if (int64_val > UINT32_MAX) { + GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to uint32_t", int64_val); + return false; + } + value = static_cast(int64_val); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, + const string &name, vector &value) { + value.clear(); + vector int64_list; + if (!GetListInt(std::move(obj), name, int64_list)) { + return false; + } + + for (size_t i = 0; i < int64_list.size(); ++i) { + if (int64_list[i] > INT32_MAX) { + GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); + return false; + } + } + value.insert(value.begin(), int64_list.begin(), int64_list.end()); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, + const string &name, vector &value) { + value.clear(); + vector int64_list; + if (!GetListInt(std::move(obj), name, int64_list)) { + return false; + } + + for (size_t i = 0; i < int64_list.size(); ++i) { + if (int64_list[i] > UINT32_MAX) { + GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); + return false; + } + } + value.insert(value.begin(), int64_list.begin(), int64_list.end()); + return true; +} + +bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value) { + if (obj) { + vector bytes_vals; + for (auto &item : value) { + ModelSerialize serialize; + auto buffer = serialize.SerializeOpDesc(item); + if (buffer.GetSize() == 0) { + return false; + } + bytes_vals.push_back(buffer); + } + return SetZeroCopyListBytes(std::move(obj), name, bytes_vals); + } + return false; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, + const string &name, + const vector &value) { + if (obj) { + vector bytes_vals; + for (auto &item : value) { + ModelSerialize serialize; + auto buffer = serialize.SerializeOpDesc(item); + if (buffer.GetSize() == 0) { + return false; + } + bytes_vals.push_back(buffer); + } + return SetZeroCopyListBytes(std::move(obj), name, bytes_vals); + } + return false; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(ConstAttrHolderAdapter &&obj, + const string &name, + vector &value) { + value.clear(); + + vector bytes_vals; + if (!GetZeroCopyListBytes(std::move(obj), name, bytes_vals)) { + return false; + } + for (const auto &item : bytes_vals) { + ModelSerialize serialize; + auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732 + value.push_back(op_desc); + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetZeroCopyBytes(AttrHolderAdapter &&obj, + const string &name, Buffer &&buffer) { + // Value will be moved + proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::SetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), std::move(buffer)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, + const string &name, Buffer &buffer) { + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::GetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), buffer); +} + +bool AttrUtils::SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector &list_buffer) { + // Value will be moved + proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::SetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer); +} + +bool AttrUtils::GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &list_buffer) { + list_buffer.clear(); + const proto::AttrDef *proto_attr_val = nullptr; + if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { + return false; + } + return GeAttrValueImp::GetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { + if (org_op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "org_op_desc is null"); + return nullptr; + } + std::shared_ptr op_def; + op_def = ComGraphMakeShared(); + if (op_def == nullptr) { + GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); + return nullptr; // lint !e665 + } + ModelSerializeImp imp; + (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); + + imp.SetProtobufOwner(op_def); + OpDescPtr op_desc = nullptr; + GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); + op_desc->extAttrs_ = org_op_desc->extAttrs_; + + // This function may be called by some passes of fusion engine, in this condition, do not need these attribute + if (!op_desc->input_name_idx_.empty()) { + op_desc->input_name_idx_.clear(); + } + if (!op_desc->output_name_idx_.empty()) { + op_desc->output_name_idx_.clear(); + } + if (!op_desc->optional_input_names_.empty()) { + op_desc->optional_input_names_.clear(); + } + + return op_desc; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { + if (org_op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "org_op_desc is null"); + return nullptr; + } + std::shared_ptr op_def = ComGraphMakeShared(); + if (op_def == nullptr) { + GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); + return nullptr; + } + ModelSerializeImp imp; + (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); + + imp.SetProtobufOwner(op_def); + OpDescPtr op_desc = nullptr; + GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); + + op_desc->extAttrs_ = org_op_desc->extAttrs_; + + op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end()); + op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(), + org_op_desc->optional_input_names_.end()); + op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); + + op_desc->infer_func_ = org_op_desc->infer_func_; + op_desc->infer_format_func_ = org_op_desc->infer_format_func_; + op_desc->verifier_func_ = org_op_desc->verifier_func_; + + return op_desc; +} +std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) { + auto holder = obj.get(); + if (holder == nullptr) { + return ""; + } + auto attrs_map = holder->GetAttrMap(); + if (attrs_map.GetProtoMsg() == nullptr) { + return ""; + } + + std::map ordered_attrs; + for (auto &attr : *(attrs_map.GetProtoMsg())) { + ordered_attrs[attr.first] = attr.second.SerializeAsString(); + } + + std::stringstream ss; + for (auto &attr : ordered_attrs) { + ss << attr.first << ":" << attr.second << ";"; + } + return ss.str(); +} +} // namespace ge diff --git a/metadef/graph/ge_tensor.cc b/metadef/graph/ge_tensor.cc new file mode 100644 index 00000000..bc7f3b1b --- /dev/null +++ b/metadef/graph/ge_tensor.cc @@ -0,0 +1,1027 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/ge_tensor.h" +#include +#include +#include +#include +#include "debug/ge_attr_define.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_attr_value.h" +#include "graph/model_serialize.h" +#include "proto/ge_ir.pb.h" +#include "utils/attr_utils.h" +#include "utils/ge_ir_utils.h" +#include "utils/tensor_utils.h" +#include "utils/type_utils.h" + +namespace ge { +namespace{ +const char *const kKeyDataTypeSelfDefined = "__tensor_desc_data_type__"; +const std::map kDataTypeMap = { + {DT_UNDEFINED, proto::DT_UNDEFINED}, + {DT_FLOAT, proto::DT_FLOAT}, + {DT_FLOAT16, proto::DT_FLOAT16}, + {DT_INT8, proto::DT_INT8}, + {DT_UINT8, proto::DT_UINT8}, + {DT_INT16, proto::DT_INT16}, + {DT_UINT16, proto::DT_UINT16}, + {DT_INT32, proto::DT_INT32}, + {DT_INT64, proto::DT_INT64}, + {DT_UINT32, proto::DT_UINT32}, + {DT_UINT64, proto::DT_UINT64}, + {DT_BOOL, proto::DT_BOOL}, + {DT_DOUBLE, proto::DT_DOUBLE}, + {DT_DUAL, proto::DT_DUAL}, + {DT_DUAL_SUB_INT8, proto::DT_DUAL_SUB_INT8}, + {DT_DUAL_SUB_UINT8, proto::DT_DUAL_SUB_UINT8}, + {DT_COMPLEX64, proto::DT_COMPLEX64}, + {DT_COMPLEX128, proto::DT_COMPLEX128}, + {DT_QINT8, proto::DT_QINT8}, + {DT_QINT16, proto::DT_QINT16}, + {DT_QINT32, proto::DT_QINT32}, + {DT_QUINT8, proto::DT_QUINT8}, + {DT_QUINT16, proto::DT_QUINT16}, + {DT_RESOURCE, proto::DT_RESOURCE}, + {DT_STRING_REF, proto::DT_STRING_REF}, + {DT_STRING, proto::DT_STRING}, +}; + +const std::map kDataTypeSelfDefinedMap = { + {DT_DUAL, 13}, {DT_DUAL_SUB_INT8, 14}, {DT_DUAL_SUB_UINT8, 15}, {DT_COMPLEX64, 16}, {DT_COMPLEX128, 17}, + {DT_QINT8, 18}, {DT_QINT16, 19}, {DT_QINT32, 20}, {DT_QUINT8, 21}, {DT_QUINT16, 22}, +}; +} + + + +GeShape::GeShape() { shape_def_.InitDefault(); } + +// Default +GeShape::GeShape(std::vector s) : GeShape() { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto i : s) { + proto_msg->add_dim(i); + } + } +} + +size_t GeShape::GetDimNum() const { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + if (proto_msg->dim_size() >= 0) { + // check whether contain -2, if true, return -1 + for (auto i : proto_msg->dim()) { + if (i == UNKNOWN_DIM_NUM) { + return 0; + } + } + return proto_msg->dim_size(); + } else { + return 0; + } + } + return 0; +} + +int64_t GeShape::GetDim(size_t idx) const { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + if (proto_msg->dim_size() > static_cast(idx)) { + return proto_msg->dim(static_cast(idx)); + } + } + return 0; +} + +graphStatus GeShape::SetDim(size_t idx, int64_t value) { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + auto dims = proto_msg->mutable_dim(); + GE_CHECK_NOTNULL(dims); + if (dims->empty()) { + GELOGE(GRAPH_FAILED, "shape is empty"); + return GRAPH_FAILED; + } + if (static_cast(idx) >= dims->size()) { + GELOGE(GRAPH_FAILED, "idx is out of range"); + return GRAPH_FAILED; + } + proto_msg->set_dim(static_cast(idx), value); + } + return GRAPH_SUCCESS; +} + +std::vector GeShape::GetDims() const { + vector dims; + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto i : proto_msg->dim()) { + dims.push_back(i); + } + } + return dims; +} + +std::string GeShape::ToString() const { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg == nullptr) { + return ""; + } + + std::stringstream ss; + bool first = true; + for (auto i : proto_msg->dim()) { + if (first) { + first = false; + } else { + ss << ","; + } + ss << i; + } + return ss.str(); +} + +int64_t GeShape::GetShapeSize() const { + int64_t res = 1; + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + if (proto_msg->dim().empty()) { + return 0; + } + for (auto i : proto_msg->dim()) { + // if unknown shape, return -1 + if (i == UNKNOWN_DIM || i == UNKNOWN_DIM_NUM) { + return UNKNOWN_DIM; + } + res *= i; + } + } + return res; +} + +/// +/// @brief Check is unknown shape +/// @return bool +/// /// +bool GeShape::IsUnknownShape() const { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto i : proto_msg->dim()) { + if (i < 0) { + return true; + } + } + } + return false; +} + +/// +/// @brief Check is a scalar +/// @return bool +/// +bool GeShape::IsScalar() const { + auto proto_msg = shape_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->dim().empty(); + } + return false; +} + +const string TENSOR_UTILS_SIZE = "size"; +const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; +const string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; +const string TENSOR_UTILS_OUTPUT_TENSOR = "output_tensor"; +const string TENSOR_UTILS_DEVICE_TYPE = "device_type"; +const string TENSOR_UTILS_INPUT_TENSOR = "input_tensor"; +const string TENSOR_UTILS_REAL_DIM_CNT = "real_dim_cnt"; +const string TENSOR_UTILS_REUSE_INPUT_INDEX = "reuse_input_index"; +const string TENSOR_UTILS_DATA_OFFSET = "data_offset"; +const string TENSOR_UTILS_CMPS_SIZE = "cmps_size"; +const string TENSOR_UTILS_CMPS_TAB = "cmps_tab"; +const string TENSOR_UTILS_CMPS_TAB_OFFSET = "cmps_tab_offset"; +const string TENSOR_UTILS_CMPSINFO = "cmps_info"; +const string TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO = "alloffset_quantize_info"; +const string TENSOR_UTILS_RC = "rc"; +const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; +const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; +const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; +const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; +const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index"; + +GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} + +GeShape::GeShape(const GeShape &other) : GeShape() { shape_def_.CopyValueFrom(other.shape_def_); } + +GeShape::GeShape(GeShape &&other) : GeShape() { shape_def_.MoveValueFrom(std::move(other.shape_def_)); } + +GeShape &GeShape::operator=(const GeShape &other) { + if (&other != this) { + shape_def_.CopyValueFrom(other.shape_def_); + } + return *this; +} + +GeShape &GeShape::operator=(GeShape &&other) { + if (&other != this) { + shape_def_.CopyValueFrom(std::move(other.shape_def_)); + } + return *this; +} + +GeTensorDesc::GeTensorDesc() { + tensor_descriptor_.InitDefault(); + SetDataType(DT_FLOAT); + Init(); +} + +// Default +GeTensorDesc::GeTensorDesc(GeShape shape, Format format, DataType dt) : GeTensorDesc() { + SetFormat(format); + SetDataType(dt); + ShapeReference() = std::move(shape); +} + +// Default +GeTensorDesc::GeTensorDesc(const GeTensorDesc &desc) : GeTensorDesc() { + tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_); +} + +// Default +GeTensorDesc::GeTensorDesc(GeTensorDesc &&desc) : GeTensorDesc() { + tensor_descriptor_.MoveValueFrom(std::move(desc.tensor_descriptor_)); +} + +GeTensorDesc::GeTensorDesc(const ProtoMsgOwner &proto_owner, proto::TensorDescriptor *proto_msg) + : tensor_descriptor_(proto_owner, proto_msg) { + if (proto_msg != nullptr && !proto_msg->has_out_attr()) { + proto_msg->set_has_out_attr(true); + + int64_t size = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_SIZE, size); + proto_msg->set_size(size); + + int64_t weight_size = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_WEIGHT_SIZE, weight_size); + proto_msg->set_weight_size(weight_size); + + bool reuse_input = false; + (void)AttrUtils::GetBool(this, TENSOR_UTILS_REUSE_INPUT, reuse_input); + proto_msg->set_reuse_input(reuse_input); + + bool output_tensor = false; + (void)AttrUtils::GetBool(this, TENSOR_UTILS_OUTPUT_TENSOR, output_tensor); + proto_msg->set_output_tensor(output_tensor); + + string device_type = "NPU"; + (void)AttrUtils::GetStr(this, TENSOR_UTILS_DEVICE_TYPE, device_type); + proto_msg->set_device_type(device_type); + + bool input_tensor = false; + (void)AttrUtils::GetBool(this, TENSOR_UTILS_INPUT_TENSOR, input_tensor); + proto_msg->set_input_tensor(input_tensor); + + int64_t real_dim_cnt = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_REAL_DIM_CNT, real_dim_cnt); + proto_msg->set_real_dim_cnt(real_dim_cnt); + + int64_t reuse_input_index = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_REUSE_INPUT_INDEX, reuse_input_index); + proto_msg->set_reuse_input_index(reuse_input_index); + + int64_t data_offset = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_DATA_OFFSET, data_offset); + proto_msg->set_data_offset(data_offset); + + int64_t cmps_size = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_SIZE, cmps_size); + proto_msg->set_cmps_size(cmps_size); + + string cmps_tab; + (void)AttrUtils::GetStr(this, TENSOR_UTILS_CMPS_TAB, cmps_tab); + proto_msg->set_cmps_tab(cmps_tab); + + int64_t cmps_tab_offset = 0; + (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_TAB_OFFSET, cmps_tab_offset); + proto_msg->set_cmps_tab_offset(cmps_tab_offset); + } +} + +bool GeTensorDesc::GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const { + const auto &tensor_descriptor = this->tensor_descriptor_.GetProtoMsg(); + const auto &r_tensor_descriptor = r_ge_tensor_desc.tensor_descriptor_.GetProtoMsg(); + if ((tensor_descriptor != nullptr) && (r_tensor_descriptor != nullptr)) { + // Message TensorDescriptor in ge_ir.proto + return (IsEqual(tensor_descriptor->name(), r_tensor_descriptor->name(), "TensorDescriptor.name()") && + IsEqual(tensor_descriptor->dtype(), r_tensor_descriptor->dtype(), "TensorDescriptor.dtype()") && + // Message ShapeDef in ge_ir.proto + IsEqual(ToString(tensor_descriptor->shape().dim()), ToString(r_tensor_descriptor->shape().dim()), + "TensorDescriptor.shape().dim()") && + IsEqual(tensor_descriptor->layout(), r_tensor_descriptor->layout(), "TensorDescriptor.layout()") && + IsEqual(tensor_descriptor->has_out_attr(), r_tensor_descriptor->has_out_attr(), + "TensorDescriptor.has_out_attr()") && + IsEqual(tensor_descriptor->size(), r_tensor_descriptor->size(), "TensorDescriptor.size()") && + IsEqual(tensor_descriptor->weight_size(), r_tensor_descriptor->weight_size(), + "TensorDescriptor.weight_size()") && + IsEqual(tensor_descriptor->reuse_input(), r_tensor_descriptor->reuse_input(), + "TensorDescriptor.reuse_input()") && + IsEqual(tensor_descriptor->output_tensor(), r_tensor_descriptor->output_tensor(), + "TensorDescriptor.output_tensor()") && + IsEqual(tensor_descriptor->device_type(), r_tensor_descriptor->device_type(), + "TensorDescriptor.device_type()") && + IsEqual(tensor_descriptor->input_tensor(), r_tensor_descriptor->input_tensor(), + "TensorDescriptor.input_tensor()") && + IsEqual(tensor_descriptor->real_dim_cnt(), r_tensor_descriptor->real_dim_cnt(), + "TensorDescriptor.real_dim_cnt()") && + IsEqual(tensor_descriptor->reuse_input_index(), r_tensor_descriptor->reuse_input_index(), + "TensorDescriptor.reuse_input_index()") && + IsEqual(tensor_descriptor->data_offset(), r_tensor_descriptor->data_offset(), + "TensorDescriptor.data_offset()") && + IsEqual(tensor_descriptor->cmps_size(), r_tensor_descriptor->cmps_size(), "TensorDescriptor.cmps_size()") && + IsEqual(tensor_descriptor->cmps_tab(), r_tensor_descriptor->cmps_tab(), "TensorDescriptor.cmps_tab()") && + IsEqual(tensor_descriptor->cmps_tab_offset(), r_tensor_descriptor->cmps_tab_offset(), + "TensorDescriptor.cmps_tab_offset()")); + } else { + return ((tensor_descriptor == nullptr) && (r_tensor_descriptor == nullptr)); + } +} + +bool GeTensorDesc::operator==(const GeTensorDesc &r_ge_tensor_desc) const { + return GeTensorDescAttrsAreEqual(r_ge_tensor_desc); +} + +GeShape &GeTensorDesc::ShapeReference() const { + if (tensor_descriptor_.GetProtoMsg() != nullptr) { + GeShape refShape(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_shape()); + __shape_.RefTo(refShape); + } else { + GeShape refShape(tensor_descriptor_.GetProtoOwner(), nullptr); + __shape_.RefTo(refShape); + } + return __shape_; +} + +void GeTensorDesc::Init() { + SetFormat(FORMAT_ND); + SetOriginFormat(FORMAT_ND); + TensorUtils::SetDeviceType(*this, DeviceType::NPU); + if (tensor_descriptor_.GetProtoMsg() == nullptr) { + GELOGE(GRAPH_FAILED, "ProtoType nullptr."); + return; + } + tensor_descriptor_.GetProtoMsg()->set_has_out_attr(true); +} + +ProtoAttrMapHelper GeTensorDesc::MutableAttrMap() { + if (tensor_descriptor_.GetProtoMsg() != nullptr) { + return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_attr()); + } + return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr); +} + +ConstProtoAttrMapHelper GeTensorDesc::GetAttrMap() const { + if (tensor_descriptor_.GetProtoMsg() != nullptr) { + return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), + tensor_descriptor_.GetProtoMsg()->mutable_attr()); + } + return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr); +} + +void GeTensorDesc::Update(GeShape shape, Format format, DataType dt) { + ShapeReference() = std::move(shape); + SetFormat(format); + SetDataType(dt); +} +GeShape GeTensorDesc::GetShape() const { return ShapeReference(); } + +GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); } + +void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); } + +// set shape with -2, it stand for unknown shape +void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); } + +// for unknown shape +graphStatus GeTensorDesc::SetShapeRange(const std::vector> &range) { + std::vector> shape_range; + for (const auto &ele : range) { + shape_range.emplace_back(std::vector({ele.first, ele.second})); + } + auto ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); + return ret ? GRAPH_SUCCESS : GRAPH_FAILED; +} +graphStatus GeTensorDesc::GetShapeRange(std::vector> &range) const { + std::vector> shape_range; + (void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); + + for (const auto &ele : shape_range) { + // here must be only two elemenet because pair + if (ele.size() != 2) { + GELOGE(GRAPH_FAILED, "shape_range must contain only 2 value but really is %lu", ele.size()); + return GRAPH_FAILED; + } + std::pair pair({ele[0], ele[1]}); + range.emplace_back(pair); + } + + return GRAPH_SUCCESS; +} + +GeShape GeTensorDesc::GetOriginShape() const { + vector origin_shape; + if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) { + return GeShape(); + } + return GeShape(origin_shape); +} + +void GeTensorDesc::SetOriginShape(const GeShape &origin_shape) { + std::vector origin_shape_tmp = origin_shape.GetDims(); + (void)AttrUtils::SetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape_tmp); +} + +Format GeTensorDesc::GetFormat() const { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + return TypeUtils::SerialStringToFormat(tensor_descriptor_msg->layout()); + } + return FORMAT_RESERVED; +} + +void GeTensorDesc::SetFormat(Format format) { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_layout(TypeUtils::FormatToSerialString(format)); + } +} + +void GeTensorDesc::SetName(const std::string &name) { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_name(name); + return; + } + GELOGW("[SetName]tensor_descriptor_msg is null."); +} + +const std::string GeTensorDesc::GetName() const { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + return tensor_descriptor_msg->name(); + } + GELOGW("[GetName]tensor_descriptor_msg is null."); + return ""; +} + +Format GeTensorDesc::GetOriginFormat() const { + std::string origin_format_str; + if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str)) { + // Can not get the certificate and it's not set, return directly + return FORMAT_RESERVED; + } + if (origin_format_str == "RESERVED") { + return FORMAT_RESERVED; + } + return TypeUtils::SerialStringToFormat(origin_format_str); +} + +void GeTensorDesc::SetOriginFormat(Format origin_format) { + std::string origin_format_str = "RESERVED"; + if (origin_format != FORMAT_RESERVED) { + origin_format_str = TypeUtils::FormatToSerialString(origin_format); + } + (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str); +} + +DataType GeTensorDesc::GetDataType() const { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg == nullptr) { + return DT_UNDEFINED; + } + auto &attr_map = *(tensor_descriptor_msg->mutable_attr()); + // Data type + auto it_data_type = attr_map.find(kKeyDataTypeSelfDefined); + if (it_data_type != attr_map.end()) { + int64_t data_type_proto = it_data_type->second.i(); + for (auto it : kDataTypeSelfDefinedMap) { + if (it.second == data_type_proto) { + return it.first; + } + } + } else { + auto data_type_proto = tensor_descriptor_msg->dtype(); + for (auto it : kDataTypeMap) { + if (it.second == data_type_proto) { + return it.first; + } + } + } + return DT_UNDEFINED; +} + +void GeTensorDesc::SetDataType(DataType dataType) { + auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg == nullptr) { + return; + } + auto &attr_maps = *(tensor_descriptor_msg->mutable_attr()); + (void)attr_maps.erase(kKeyDataTypeSelfDefined); + + // Data type + auto it = kDataTypeMap.find(dataType); + if (it != kDataTypeMap.end()) { + tensor_descriptor_msg->set_dtype(it->second); + return; + } + auto it2 = kDataTypeSelfDefinedMap.find(dataType); + if (it2 != kDataTypeSelfDefinedMap.end()) { + attr_maps[kKeyDataTypeSelfDefined].set_i(it2->second); + } +} + +void GeTensorDesc::SetOriginDataType(DataType origin_data_type) { + std::string origin_data_type_str = "RESERVED"; + if (origin_data_type != DT_UNDEFINED) { + origin_data_type_str = TypeUtils::DataTypeToSerialString(origin_data_type); + } + (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str); +} + +DataType GeTensorDesc::GetOriginDataType() const { + std::string origin_data_type_str; + if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str)) { + return DT_UNDEFINED; + } + if (origin_data_type_str == "RESERVED") { + return DT_UNDEFINED; + } + return TypeUtils::SerialStringToDataType(origin_data_type_str); +} + +std::vector GeTensorDesc::GetRefPortIndex() const { + vector ref_port_index; + (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index); + return ref_port_index; +} + +void GeTensorDesc::SetRefPortByIndex(const std::vector &index) { + (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index); +} + +graphStatus GeTensorDesc::IsValid() const { + auto dtype = this->GetDataType(); + auto format = this->GetFormat(); + if (dtype == DT_UNDEFINED && format == FORMAT_RESERVED) { + return GRAPH_PARAM_INVALID; + } + return GRAPH_SUCCESS; +} + +GeTensorDesc GeTensorDesc::Clone() const { return *this; } + +GeTensorDesc &GeTensorDesc::operator=(const GeTensorDesc &desc) { + if (&desc != this) { + tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_); + } + return *this; +} + +GeTensorDesc &GeTensorDesc::operator=(GeTensorDesc &&desc) { + if (&desc != this) { + tensor_descriptor_.CopyValueFrom(std::move(desc.tensor_descriptor_)); + } + return *this; +} + +GeTensor::GeTensor::GeTensor() { + tensor_def_.InitDefault(); + // Default init desc + DescReference() = GeTensorDesc(); +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc) : GeTensor() { DescReference() = tensor_desc; } + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const vector &data) : GeTensor() { + DescReference() = tensor_desc; + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_data(data.data(), data.size()); + } +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const uint8_t *data, size_t size) : GeTensor() { + DescReference() = tensor_desc; + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr && data != nullptr) { + proto_msg->set_data(data, size); + } +} + +GeTensor::GeTensor(GeTensorDesc &&tensor_desc, vector &&data) : GeTensor() { + DescReference() = std::move(tensor_desc); + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_data(data.data(), data.size()); + } +} + +GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const Buffer &data) : GeTensor() { + DescReference() = tensor_desc; + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + if (data.size() == 0) { + GELOGI("GetSize res is 0."); + } + if (data.data() == nullptr) { + GELOGI("data addr is null."); + } + proto_msg->set_data(data.GetData(), data.GetSize()); + } +} + +GeTensor::GeTensor(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg) + : tensor_def_(proto_owner, proto_msg) {} + +GeTensorDesc GeTensor::GetTensorDesc() const { return DescReference(); } + +GeTensorDesc &GeTensor::MutableTensorDesc() { return DescReference(); } + +GeTensorDesc &GeTensor::DescReference() const { + if (tensor_def_.GetProtoMsg() != nullptr) { + GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), tensor_def_.GetProtoMsg()->mutable_desc()); + __desc_.RefTo(tensor_desc); + } else { + GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), nullptr); + __desc_.RefTo(tensor_desc); + } + return __desc_; +} + +void GeTensor::SetTensorDesc(const GeTensorDesc &tensor_desc) { DescReference() = tensor_desc; } + +const Buffer GeTensor::GetData() const { + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data()); + } + return Buffer(); +} + +Buffer GeTensor::MutableData() { + auto proto_msg = tensor_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data()); + } + return Buffer(); +} + +graphStatus GeTensor::SetData(vector &&data) { + auto proto_msg = tensor_def_.GetProtoMsg(); + GE_CHECK_NOTNULL(proto_msg); + proto_msg->set_data(data.data(), data.size()); + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(const vector &data) { + auto proto_msg = tensor_def_.GetProtoMsg(); + GE_CHECK_NOTNULL(proto_msg); + proto_msg->set_data(data.data(), data.size()); + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(const uint8_t *data, size_t size) { + if (size > 0) { + GE_CHECK_NOTNULL(data); + } + auto proto_msg = tensor_def_.GetProtoMsg(); + GE_CHECK_NOTNULL(proto_msg); + proto_msg->set_data(data, size); + return GRAPH_SUCCESS; +} + +graphStatus GeTensor::SetData(const Buffer &data) { + auto proto_msg = tensor_def_.GetProtoMsg(); + GE_CHECK_NOTNULL(proto_msg); + if (data.size() == 0) { + GELOGI("GetSize res is 0."); + } + if (data.data() == nullptr) { + GELOGI("data addr is null."); + } + proto_msg->set_data(data.data(), data.size()); + return GRAPH_SUCCESS; +} + +GeTensor GeTensor::Clone() const { + GeTensor tensor; + tensor.tensor_def_.CopyValueFrom(tensor_def_); + return tensor; +} + +GeTensor::GeTensor(const GeTensor &other) { tensor_def_ = other.tensor_def_; } + +GeTensor &GeTensor::operator=(const GeTensor &other) { + if (&other != this) { + tensor_def_ = other.tensor_def_; + } + return *this; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, + int64_t &size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + size = static_cast(tensor_descriptor_msg->size()); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_size(size); + } +} + +uint32_t TensorUtils::GetWeightSize(const GeTensorDesc &tensor_desc) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + return static_cast(tensor_descriptor_msg->weight_size()); + } + return 0; +} + +uint32_t TensorUtils::GetWeightSize(const GeTensor &tensor) { return GetWeightSize(tensor.GetTensorDesc()); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t TensorUtils::GetWeightSize(const ConstGeTensorPtr &tensor_ptr) { + if (tensor_ptr == nullptr) { + return 0; + } + return GetWeightSize(*tensor_ptr); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint8_t *TensorUtils::GetWeightAddr(const ConstGeTensorPtr &tensor_ptr, + uint8_t *base) { + if (tensor_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "tensor_ptr is null."); + return nullptr; + } + return GetWeightAddr(*tensor_ptr, base); +} + +uint8_t *TensorUtils::GetWeightAddr(const GeTensor &tensor, uint8_t *base) { + if (base == nullptr) { + GELOGE(GRAPH_FAILED, "base is null."); + return nullptr; + } + int64_t weight_data_offset = 0; + if (GetDataOffset(tensor.GetTensorDesc(), weight_data_offset) != GRAPH_SUCCESS) return nullptr; + + if (weight_data_offset == 0) { + // The weight of offset 0 is still in const op, still get from ATTR_NAME_WEIGHTS. + return const_cast(tensor.GetData().data()); + } + + return base + weight_data_offset; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetWeightSize(GeTensorDesc &tensor_desc, + uint32_t size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_weight_size(size); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetReuseInput(const GeTensorDesc &tensor_desc, + bool &flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + flag = tensor_descriptor_msg->reuse_input(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInput(GeTensorDesc &tensor_desc, bool flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_reuse_input(flag); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetOutputTensor(const GeTensorDesc &tensor_desc, + bool &flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + flag = tensor_descriptor_msg->output_tensor(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetOutputTensor(GeTensorDesc &tensor_desc, bool flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_output_tensor(flag); + } +} + +static map device_to_str_map{ + {0, "NPU"}, {1, "CPU"}, +}; +static map str_to_device_map{ + {"NPU", 0}, {"CPU", 1}, +}; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDeviceType(const GeTensorDesc &tensor_desc, + DeviceType &type) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + string type_str = tensor_descriptor_msg->device_type(); + type = DeviceType(str_to_device_map[type_str]); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDeviceType(GeTensorDesc &tensor_desc, + DeviceType type) { + auto type_str = device_to_str_map[type]; + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_device_type(type_str); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetInputTensor(const GeTensorDesc &tensor_desc, + bool &flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + flag = tensor_descriptor_msg->input_tensor(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetInputTensor(GeTensorDesc &tensor_desc, bool flag) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_input_tensor(flag); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRealDimCnt(const GeTensorDesc &tensor_desc, + uint32_t &cnt) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + cnt = static_cast(tensor_descriptor_msg->real_dim_cnt()); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRealDimCnt(GeTensorDesc &tensor_desc, + uint32_t cnt) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_real_dim_cnt(cnt); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetReuseInputIndex(const GeTensorDesc &tensor_desc, uint32_t &idx) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + GE_CHECK_NOTNULL(tensor_descriptor_msg); + + idx = static_cast(tensor_descriptor_msg->reuse_input_index()); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInputIndex(GeTensorDesc &tensor_desc, + uint32_t idx) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_reuse_input_index(idx); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDataOffset(const GeTensorDesc &tensor_desc, + int64_t &offset) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + offset = tensor_descriptor_msg->data_offset(); + return GRAPH_SUCCESS; + } else { + GELOGW("tensor_descriptor_msg is nullptr."); + return GRAPH_FAILED; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDataOffset(GeTensorDesc &tensor_desc, + int64_t offset) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_data_offset(offset); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsSize(const GeTensorDesc &tensor_desc, + uint32_t &cmp_size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + cmp_size = static_cast(tensor_descriptor_msg->cmps_size()); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsSize(GeTensorDesc &tensor_desc, + uint32_t cmp_size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_cmps_size(cmp_size); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsTab(const GeTensorDesc &tensor_desc, + vector &vec) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + string str = tensor_descriptor_msg->cmps_tab(); + vec.assign(str.begin(), str.end()); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTab(GeTensorDesc &tensor_desc, + const uint8_t *data, size_t size) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + GE_CHK_BOOL_EXEC(data != nullptr, return, "data is null."); + string str((const char *)data, size); + tensor_descriptor_msg->set_cmps_tab(str); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetCmpsTabOffset(const GeTensorDesc &tensor_desc, int64_t &tab_offset) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tab_offset = tensor_descriptor_msg->cmps_tab_offset(); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTabOffset(GeTensorDesc &tensor_desc, + int64_t tab_offset) { + auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor_msg != nullptr) { + tensor_descriptor_msg->set_cmps_tab_offset(tab_offset); + } +} + +graphStatus TensorUtils::GetCmpsInfo(const GeTensorDesc &tensor_desc, CompressInfo &info) { + GeAttrValue attr_value; + if (tensor_desc.GetAttr(TENSOR_UTILS_CMPSINFO, attr_value) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + return attr_value.GetValue(info); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsInfo(GeTensorDesc &tensor_desc, + const CompressInfo &info) { + (void)tensor_desc.SetAttr(TENSOR_UTILS_CMPSINFO, GeAttrValue::CreateFrom(info)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool TensorUtils::HasAlloffsetQuantizeInfo( + const GeTensorDesc &tensor_desc) { + return tensor_desc.HasAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetAlloffsetQuantizeInfo(const GeTensorDesc &tensor_desc, AllOffsetQuantizeInfo &info) { + GeAttrValue attr_value; + if (tensor_desc.GetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, attr_value) != GRAPH_SUCCESS) { + GELOGW("get attr alloffset_quantize_info fail."); + } + return attr_value.GetValue(info); +} + +void TensorUtils::SetAlloffsetQuantizeInfo(GeTensorDesc &tensor_desc, const AllOffsetQuantizeInfo &info) { + (void)tensor_desc.SetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, GeAttrValue::CreateFrom(info)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRC(const GeTensorDesc &tensor_desc, + uint32_t &rc) { + return AttrUtils::GetInt(&tensor_desc, TENSOR_UTILS_RC, rc) ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRC(GeTensorDesc &tensor_desc, uint32_t rc) { + (void)AttrUtils::SetInt(&tensor_desc, TENSOR_UTILS_RC, rc); +} +} // namespace ge diff --git a/metadef/graph/gnode.cc b/metadef/graph/gnode.cc new file mode 100644 index 00000000..e75593cc --- /dev/null +++ b/metadef/graph/gnode.cc @@ -0,0 +1,877 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/gnode.h" + +#include +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/anchor.h" +#include "graph/node.h" +#include "graph/utils/node_adapter.h" +#include "graph/utils/tensor_adapter.h" +#include +#include "graph/debug/ge_attr_define.h" +#include "graph/debug/ge_op_types.h" +#include "utils/node_utils.h" +#include "utils/op_desc_utils.h" + +namespace ge { +class NodeImpl { + public: + NodeImpl() = default; + ~NodeImpl() = default; + + NodeImpl(NodeImpl &) = delete; + NodeImpl &operator=(const NodeImpl &) = delete; + + std::weak_ptr node_ptr_; +}; + +NodePtr NodeAdapter::GNode2Node(const ge::GNode &graph_node) { + if (graph_node.impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GNode2Node: gnode impl is nullptr."); + return nullptr; + } + + return graph_node.impl_->node_ptr_.lock(); +} + +GNode NodeAdapter::Node2GNode(const ge::NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Node2GNode: node is nullptr"); + return GNode(); + } + + GNode graph_node; + if (graph_node.impl_ == nullptr) { + GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str()); + return graph_node; + } + graph_node.impl_->node_ptr_ = node; + + return graph_node; +} + +GNodePtr NodeAdapter::Node2GNodePtr(const ge::NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Node2GNodePtr: node is nullptr"); + return nullptr; + } + + GNodePtr gnode = std::shared_ptr(new (std::nothrow) GNode()); + if (gnode == nullptr) { + GELOGE(GRAPH_FAILED, "Node2GNodePtr: gnode is nullptr, node[%s].", node->GetName().c_str()); + return nullptr; + } + + if (gnode->impl_ == nullptr) { + GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str()); + return nullptr; + } + gnode->impl_->node_ptr_ = node; + + return gnode; +} + +GNode::GNode() { impl_ = ComGraphMakeShared(); } + + +graphStatus GNode::GetType(AscendString &type) const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetType: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetType: the shared ptr is not valid."); + return GRAPH_FAILED; + } + std::string node_type = node_ptr->GetType(); + AscendString ascend_type(node_type.c_str()); + type = ascend_type; + + return GRAPH_SUCCESS; +} + +graphStatus GNode::GetName(AscendString &name) const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetName: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetName: the shared ptr is not valid."); + return GRAPH_FAILED; + } + std::string node_name = node_ptr->GetName(); + AscendString ascend_name(node_name.c_str()); + name = ascend_name; + + return GRAPH_SUCCESS; +} + +std::pair GNode::GetInDataNodesAndPortIndexs(const int32_t index) const { + pair gnode_idx = {nullptr, 0xFF}; + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr."); + return gnode_idx; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid."); + return gnode_idx; + } + + auto in_anchor = node_ptr->GetInDataAnchor(index); + if (in_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node[%s], the anchor does not exist", + index, node_ptr->GetName().c_str()); + return gnode_idx; + } + + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node [%s], the data input does not exist", + index, node_ptr->GetName().c_str()); + return gnode_idx; + } + + NodePtr peer_node_ptr = out_anchor->GetOwnerNode(); + GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr); + if (gnode == nullptr) { + GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); + return gnode_idx; + } + + return {gnode, out_anchor->GetIdx()}; +} + +std::vector GNode::GetInControlNodes() const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr."); + return {}; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid."); + return {}; + } + + std::vector gnodes; + auto in_control_nodes = node_ptr->GetInControlNodes(); + for (auto &in_control_node : in_control_nodes) { + GNodePtr gnode = NodeAdapter::Node2GNodePtr(in_control_node); + if (gnode == nullptr) { + GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); + return {}; + } + gnodes.emplace_back(gnode); + } + + return gnodes; +} + +std::vector> GNode::GetOutDataNodesAndPortIndexs(const int32_t index) const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr."); + return {}; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid."); + return {}; + } + + auto out_anchor = node_ptr->GetOutDataAnchor(index); + if (out_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to get out data node of index %d from node %s, the anchor does not exists", + index, node_ptr->GetName().c_str()); + return {}; + } + + vector> gnode_index; + auto in_data_anchors = out_anchor->GetPeerInDataAnchors(); + for (auto &in_data_anchor : in_data_anchors) { + if (in_data_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "In data anchor of node[%s] is nullptr.", node_ptr->GetName().c_str()); + return {}; + } + NodePtr peer_node_ptr = in_data_anchor->GetOwnerNode(); + GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr); + if (gnode == nullptr) { + GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); + return {}; + } + gnode_index.emplace_back(std::pair(gnode, in_data_anchor->GetIdx())); + } + + return gnode_index; +} + +std::vector GNode::GetOutControlNodes() const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetOutControlNodes: node impl is nullptr."); + return {}; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetOutControlNodes: the node shared ptr is not valid."); + return {}; + } + + std::vector gnodes; + auto out_control_nodes = node_ptr->GetOutControlNodes(); + for (auto &out_control_node : out_control_nodes) { + GNodePtr gnode = NodeAdapter::Node2GNodePtr(out_control_node); + if (gnode == nullptr) { + GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); + return {}; + } + gnodes.emplace_back(gnode); + } + + return gnodes; +} + +graphStatus GNode::GetInputConstData(const int32_t index, Tensor &data) const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetInputConstData: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetInputConstData: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + NodePtr input_data_node = NodeUtils::GetInDataNodeByIndex(*node_ptr, index); + GE_CHECK_NOTNULL(input_data_node); + string op_type = input_data_node->GetType(); + if (op_type == CONSTANT || op_type == CONSTANTOP) { + Operator const_op = OpDescUtils::CreateOperatorFromNode(input_data_node); + if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.", + input_data_node->GetName().c_str(), node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + return SUCCESS; + } else if (op_type == DATA) { + auto parent_node = NodeUtils::GetParentInput(input_data_node); + while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { + parent_node = NodeUtils::GetParentInput(parent_node); + } + if ((parent_node != nullptr) && + ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { + Operator const_op = OpDescUtils::CreateOperatorFromNode(parent_node); + if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.", + parent_node->GetName().c_str(), node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } + } + + GELOGE(GRAPH_NODE_WITHOUT_CONST_INPUT, "Node[%s] has no const input.", node_ptr->GetName().c_str()); + return GRAPH_NODE_WITHOUT_CONST_INPUT; +} + +graphStatus GNode::GetInputIndexByName(const AscendString &name, int32_t &index) { + const char* ascend_name = name.GetString(); + if (ascend_name == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "GetInputIndexByName: ascend string error."); + return GRAPH_PARAM_INVALID; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetInputIndexByName: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetInputIndexByName: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + OpDescPtr op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + std::string node_name = ascend_name; + index = op_desc->GetInputIndexByName(node_name); + + return GRAPH_SUCCESS; +} + +graphStatus GNode::GetOutputIndexByName(const AscendString &name, int32_t &index) { + const char* ascend_name = name.GetString(); + if (ascend_name == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "GetOutputIndexByName: ascend string error."); + return GRAPH_PARAM_INVALID; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetOutputIndexByName: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetOutputIndexByName: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + OpDescPtr op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + std::string node_name = ascend_name; + index = op_desc->GetOutputIndexByName(node_name); + + return GRAPH_SUCCESS; +} + +size_t GNode::GetInputsSize() const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetInputsSize: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetInputsSize: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + OpDescPtr op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + return op_desc->GetInputsSize(); +} + +size_t GNode::GetOutputsSize() const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetOutputsSize: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetOutputsSize: the shared ptr is not valid."); + return GRAPH_FAILED; + } + + OpDescPtr op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + return op_desc->GetOutputsSize(); +} + +graphStatus GNode::GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const { + if (index < 0) { + GELOGE(GRAPH_PARAM_INVALID, "GetInputDesc: index[%d] cannot be less than zero.", index); + return GRAPH_PARAM_INVALID; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetInputDesc: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetInputDesc: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + OpDescPtr op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetInputDescPtr(static_cast(index)); + if (ge_tensor_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc); + + return GRAPH_SUCCESS; +} + +graphStatus GNode::UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc) { + if (index < 0) { + GELOGE(GRAPH_PARAM_INVALID, "UpdateInputDesc: index[%d] cannot be less than zero.", index); + return GRAPH_PARAM_INVALID; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateInputDesc: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateInputDesc: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + OpDescPtr op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc); + if (op_desc->UpdateInputDesc(static_cast(index), ge_tensor_desc) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +graphStatus GNode::GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const { + if (index < 0) { + GELOGE(GRAPH_PARAM_INVALID, "GetOutputDesc: index[%d] cannot be less than zero.", index); + return GRAPH_PARAM_INVALID; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetOutputDesc: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetOutputDesc: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + OpDescPtr op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetOutputDescPtr(static_cast(index)); + if (ge_tensor_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc); + + return GRAPH_SUCCESS; +} + +graphStatus GNode::UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc) { + if (index < 0) { + GELOGE(GRAPH_PARAM_INVALID, "Gnode: index[%d] cannot be less than zero.", index); + return GRAPH_PARAM_INVALID; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateOutputDesc: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateOutputDesc: the shared ptr is not valid."); + return GRAPH_FAILED; + } + + OpDescPtr op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc); + if (op_desc->UpdateOutputDesc(static_cast(index), ge_tensor_desc) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +#define NODE_ATTR_GET_IMP(ArgType) \ + graphStatus GNode::GetAttr(const AscendString &name, ArgType &attr_value) const { \ + const char* ascend_name = name.GetString(); \ + if (ascend_name == nullptr) { \ + GELOGE(GRAPH_PARAM_INVALID, "GetAttr: ascend string error."); \ + return GRAPH_PARAM_INVALID; \ + } \ + \ + if (impl_ == nullptr) { \ + GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr."); \ + return GRAPH_FAILED; \ + } \ + \ + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); \ + if (node_ptr == nullptr) { \ + GELOGE(GRAPH_FAILED, "GetAttr: the shared ptr is not valid."); \ + return GRAPH_FAILED; \ + } \ + \ + std::string node_name = ascend_name; \ + Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); \ + if (op.GetAttr(node_name, attr_value) != GRAPH_SUCCESS) { \ + GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str()); \ + return GRAPH_FAILED; \ + } \ + \ + return GRAPH_SUCCESS; \ + } + +#define NODE_ATTR_SET_IMP(ArgType) \ + graphStatus GNode::SetAttr(const AscendString &name, ArgType &attr_value) const { \ + const char* ascend_name = name.GetString(); \ + if (ascend_name == nullptr) { \ + GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error."); \ + return GRAPH_PARAM_INVALID; \ + } \ + \ + if (impl_ == nullptr) { \ + GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr."); \ + return GRAPH_FAILED; \ + } \ + \ + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); \ + if (node_ptr == nullptr) { \ + GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid."); \ + return GRAPH_FAILED; \ + } \ + \ + std::string node_name = ascend_name; \ + Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); \ + (void)op.SetAttr(node_name, attr_value); \ + return GRAPH_SUCCESS; \ + } + +NODE_ATTR_GET_IMP(int64_t) +NODE_ATTR_GET_IMP(int32_t) +NODE_ATTR_GET_IMP(uint32_t) +NODE_ATTR_GET_IMP(float) +NODE_ATTR_GET_IMP(bool) +NODE_ATTR_GET_IMP(Tensor) +NODE_ATTR_GET_IMP(std::vector) +NODE_ATTR_GET_IMP(std::vector) +NODE_ATTR_GET_IMP(std::vector) +NODE_ATTR_GET_IMP(std::vector) +NODE_ATTR_GET_IMP(std::vector) +NODE_ATTR_GET_IMP(std::vector) +NODE_ATTR_GET_IMP(OpBytes) +NODE_ATTR_GET_IMP(std::vector>) +NODE_ATTR_GET_IMP(std::vector) +NODE_ATTR_GET_IMP(ge::DataType) +NODE_ATTR_GET_IMP(AttrValue) + +NODE_ATTR_SET_IMP(int64_t) +NODE_ATTR_SET_IMP(int32_t) +NODE_ATTR_SET_IMP(uint32_t) +NODE_ATTR_SET_IMP(float) +NODE_ATTR_SET_IMP(bool) +NODE_ATTR_SET_IMP(Tensor) +NODE_ATTR_SET_IMP(std::vector) +NODE_ATTR_SET_IMP(std::vector) +NODE_ATTR_SET_IMP(std::vector) +NODE_ATTR_SET_IMP(std::vector) +NODE_ATTR_SET_IMP(std::vector) +NODE_ATTR_SET_IMP(std::vector) +NODE_ATTR_SET_IMP(OpBytes) +NODE_ATTR_SET_IMP(std::vector>) +NODE_ATTR_SET_IMP(std::vector) +NODE_ATTR_SET_IMP(ge::DataType) + +graphStatus GNode::SetAttr(const AscendString &name, AttrValue &attr_value) const { + const char* ascend_name = name.GetString(); + if (ascend_name == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error."); + return GRAPH_PARAM_INVALID; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid."); + return GRAPH_FAILED; + } + + std::string node_name = ascend_name; + Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); + (void)op.SetAttr(node_name, std::move(attr_value)); + return GRAPH_SUCCESS; +} + +graphStatus GNode::SetAttr(const AscendString &name, AscendString &attr_value) const { + const char* ascend_name = name.GetString(); + if (ascend_name == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error."); + return GRAPH_PARAM_INVALID; + } + + const char* ascend_attr_value = attr_value.GetString(); + if (ascend_attr_value == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr value ascend string error."); + return GRAPH_PARAM_INVALID; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid."); + return GRAPH_FAILED; + } + std::string node_name = ascend_name; + std::string node_attr_value = ascend_attr_value; + Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); + (void)op.SetAttr(node_name, node_attr_value); + + return GRAPH_SUCCESS; +} + +graphStatus GNode::SetAttr(const AscendString &name, std::vector &attr_values) const { + const char* ascend_name = name.GetString(); + if (ascend_name == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error."); + return GRAPH_PARAM_INVALID; + } + + for (auto &attr_val : attr_values) { + const char* ascend_attr_value = attr_val.GetString(); + if (ascend_attr_value == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr val error."); + return GRAPH_PARAM_INVALID; + } + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid."); + return GRAPH_FAILED; + } + vector node_attr_vals; + for (auto attr_val : attr_values) { + if (attr_val.GetString() != nullptr) { + std::string node_attr_val = attr_val.GetString(); + node_attr_vals.emplace_back(node_attr_val); + } + } + std::string node_name = ascend_name; + Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); + (void)op.SetAttr(node_name, node_attr_vals); + + return GRAPH_SUCCESS; +} + +graphStatus GNode::GetAttr(const AscendString &name, AscendString &attr_value) const { + const char* ascend_name = name.GetString(); + if (ascend_name == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error."); + return GRAPH_PARAM_INVALID; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + std::string node_name = ascend_name; + Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); + std::string op_name; + if (op.GetAttr(node_name, op_name) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + AscendString attr_value_get(op_name.c_str()); + attr_value = attr_value_get; + + return GRAPH_SUCCESS; +} + +graphStatus GNode::GetAttr(const AscendString &name, std::vector &attr_values) const { + const char* ascend_name = name.GetString(); + if (ascend_name == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error."); + return GRAPH_PARAM_INVALID; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + std::string node_name = ascend_name; + Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); + vector attr_names; + if (op.GetAttr(node_name, attr_names) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + for (auto &attr_name : attr_names) { + AscendString ascend_attr_name(attr_name.c_str()); + attr_values.push_back(ascend_attr_name); + } + + return GRAPH_SUCCESS; +} + +bool GNode::HasAttr(const AscendString &name) { + const char* ascend_name = name.GetString(); + if (ascend_name == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "HasAttr: ascend string error."); + return false; + } + + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "HasAttr: node impl is nullptr."); + return false; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "HasAttr: the node shared ptr is not valid."); + return false; + } + + OpDescPtr op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); + return false; + } + std::string attr_name = ascend_name; + if (!op_desc->HasAttr(attr_name)) { + GELOGE(GRAPH_FAILED, "Node[%s] has no attr name[%s]", node_ptr->GetName().c_str(), attr_name.c_str()); + return false; + } + + return true; +} + +graphStatus GNode::GetSubgraph(uint32_t index, GraphPtr &graph) const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetSubgraph: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetSubgraph: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + ComputeGraphPtr compute_graph_ptr = NodeUtils::GetSubgraph(*node_ptr, index); + if (compute_graph_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed from node[%s].", index, node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + graph = GraphUtils::CreateGraphPtrFromComputeGraph(compute_graph_ptr); + if (graph == nullptr) { + GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed from node[%s].", index, node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +graphStatus GNode::GetALLSubgraphs(std::vector &graph_list) const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetALLSubgraphs: node impl is nullptr."); + return GRAPH_FAILED; + } + + std::shared_ptr node_ptr = impl_->node_ptr_.lock(); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetALLSubgraphs: the node shared ptr is not valid."); + return GRAPH_FAILED; + } + + std::vector sub_graphs = NodeUtils::GetAllSubgraphs(*node_ptr); + if (sub_graphs.empty()) { + GELOGE(GRAPH_FAILED, "GetALLSubgraphs: get all subgraphs failed from node[%s].", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + for (auto &sub_graph : sub_graphs) { + if (sub_graph == nullptr) { + GELOGE(GRAPH_FAILED, "Get subgraph failed from node[%s].", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + GraphPtr graph = GraphUtils::CreateGraphPtrFromComputeGraph(sub_graph); + if (graph == nullptr) { + GELOGE(GRAPH_FAILED, "Subgraph create compute graph failed from node[%s].", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + graph_list.emplace_back(graph); + } + + if (graph_list.empty()) { + GELOGW("Node[%s] has no subgraph.", node_ptr->GetName().c_str()); + } + + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/metadef/graph/graph.cc b/metadef/graph/graph.cc new file mode 100644 index 00000000..6b60696c --- /dev/null +++ b/metadef/graph/graph.cc @@ -0,0 +1,810 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "external/graph/graph.h" +#include +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/debug/ge_op_types.h" +#include "graph/model.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/node_adapter.h" +#include "graph/utils/node_utils.h" + +using std::map; +using std::pair; +using std::string; +using std::vector; + +namespace ge { +class GraphImpl { + public: + friend class GraphUtils; + GraphImpl(const GraphImpl &) = delete; + GraphImpl &operator=(const GraphImpl &) = delete; + + explicit GraphImpl(const std::string &name) : name_(name) {} + + ~GraphImpl() { + if (IsValid()) { + if (compute_graph_ != nullptr) { + GraphUtils::BreakConnect(compute_graph_->GetAllNodesInfo()); + } + } + for (const auto &it : op_list_) { + Operator op = it.second; + op.BreakConnect(); + } + } + + graphStatus SetInputs(const std::vector &inputs) { + compute_graph_ = GraphUtils::CreateGraphFromOperator(name_, inputs); + GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "Build Graph failed."); + GE_CHK_BOOL_RET_STATUS(inputs.size() != 0, GRAPH_FAILED, "set input NULL."); + compute_graph_->SetInputSize(static_cast(inputs.size())); + return GRAPH_SUCCESS; + } + + graphStatus SetOutputs(const std::vector &outputs) { + if (compute_graph_ == nullptr) { + GELOGE(GRAPH_FAILED, "set ComputeGraph failed."); + return GRAPH_FAILED; + } + if (outputs.empty()) { + GELOGW("set outputs size is 0."); + return GRAPH_SUCCESS; + } + + // Construct special output node + std::vector>> output_indexs; + for (size_t i = 0; i < outputs.size(); ++i) { + output_indexs.emplace_back(outputs[i], std::vector{}); + } + + graphStatus ret = SetOutputs(output_indexs); + return ret; + } + + graphStatus SetOutputs(const std::vector>> &output_indexs) { + if (compute_graph_ == nullptr) { + GELOGE(GRAPH_FAILED, "set ComputeGraph failed."); + return GRAPH_FAILED; + } + if (output_indexs.empty()) { + GELOGW("set outputs size is 0."); + return GRAPH_SUCCESS; + } + + // Construct special output node + std::vector> output_nodes; + for (const auto &item : output_indexs) { + const Operator &output = item.first; + const vector &indexs = item.second; + ge::NodePtr node = compute_graph_->FindNode(output.GetName()); + if (node == nullptr) { + GELOGW("user designated out_node [%s] not exist in graph, will ignored!", output.GetName().c_str()); + continue; + } + + ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue); + size_t out_size = tmp_op_ptr->GetOutputsSize(); + if (indexs.empty()) { + for (size_t i = 0; i < out_size; ++i) { + output_name_ += output.GetName() + ":" + std::to_string(i) + ";"; + output_nodes.emplace_back(node, i); + } + } else { + for (size_t i = 0; i < indexs.size(); ++i) { + if (indexs[i] >= out_size) { + GELOGW("index[%zu] is not belong to out_node[%s]", indexs[i], output.GetName().c_str()); + } else { + output_name_ += output.GetName() + ":" + std::to_string(i) + ";"; + output_nodes.emplace_back(node, indexs[i]); + } + } + } + } + + // Del last ";" + if (!output_name_.empty()) { + output_name_ = output_name_.substr(0, output_name_.length() - 1); + } + compute_graph_->SetUserDefOutput(output_name_); + compute_graph_->SetOutputSize(static_cast(output_indexs.size())); + compute_graph_->SetGraphOutNodesInfo(output_nodes); + return GRAPH_SUCCESS; + } + + graphStatus SetOutputs(const std::vector> &outputs) { + GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild."); + GE_CHK_BOOL_EXEC_INFO(outputs.size() != 0, return GRAPH_SUCCESS, "set outputs size is 0."); + + // Construct specified output + std::vector> output_nodes; + for (auto item : outputs) { + ge::NodePtr node = compute_graph_->FindNode(item.first.GetName()); + if (node == nullptr) { + GELOGE(GRAPH_FAILED, " Warning, user designated out_node (%s) not exist in graph, this out_node ignored!", + item.first.GetName().c_str()); + return GRAPH_FAILED; + } + ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue); + size_t out_size = tmp_op_ptr->GetOutputsSize(); + + if (item.second.empty()) { + for (size_t i = 0; i < out_size; ++i) { + output_name_ += item.first.GetName() + ":" + std::to_string(i) + ";"; + output_nodes.push_back(std::make_pair(node, i)); + } + } else { + int32_t index = tmp_op_ptr->GetOutputIndexByName(item.second); + if (index < 0) { + GELOGE(GRAPH_FAILED, + " Warning, user designated out_node (%s):(%s) not exist in graph, this out_node ignored!", + item.first.GetName().c_str(), item.second.c_str()); + return GRAPH_FAILED; + } + output_name_ += item.first.GetName() + ":" + std::to_string(index) + ";"; + output_nodes.push_back(std::make_pair(node, index)); + } + } + // Del last ";" + if (!output_name_.empty()) { + output_name_ = output_name_.substr(0, output_name_.length() - 1); + } + compute_graph_->SetOutputSize(static_cast(outputs.size())); + compute_graph_->SetGraphOutNodesInfo(output_nodes); + GELOGI("********************SetOutputs Success***********************"); + GE_IF_BOOL_EXEC(!output_name_.empty(), GELOGI(" NetOutputs: (%s)", output_name_.c_str())); + + return GRAPH_SUCCESS; + } + + graphStatus SetTargets(const std::vector &targets) { + GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild."); + GE_CHK_BOOL_EXEC_INFO(targets.size() != 0, return GRAPH_SUCCESS, "set targets size is 0."); + + std::vector target_nodes; + for (auto item : targets) { + ge::NodePtr node = compute_graph_->FindNode(item.GetName()); + if (node == nullptr) { + GELOGW(" Warning, user designated target_node (%s) not exist in graph, this target_node ignored!", + item.GetName().c_str()); + continue; + } + target_nodes.push_back(node); + } + compute_graph_->SetGraphTargetNodesInfo(target_nodes); + return GRAPH_SUCCESS; + } + bool IsValid() const { return (compute_graph_ != nullptr); } + + graphStatus AddOp(const ge::Operator &op) { + std::pair::iterator, bool> ret; + ret = op_list_.emplace(std::pair(op.GetName(), op)); + GE_CHK_BOOL_RET_STATUS(ret.second != false, GRAPH_FAILED, "the op have added before, op name:%s.", + op.GetName().c_str()); + return GRAPH_SUCCESS; + } + + graphStatus GetAllOpName(std::vector &op_name) const { + for (const auto &it : op_list_) { + op_name.push_back(it.second.GetName()); + } + return GRAPH_SUCCESS; + } + + graphStatus FindOpByName(const string &name, ge::Operator &op) const { + auto it = op_list_.find(name); + GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str()); + op = it->second; + return GRAPH_SUCCESS; + } + + graphStatus FindOpByType(const string &type, std::vector &ops) const { + for (auto &op : op_list_) { + auto op_type = op.second.GetOpType(); + if (op_type == type) { + ops.push_back(op.second); + continue; + } + if (op_type == ge::FRAMEWORKOP) { + op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type); + if (op_type == type) { + ops.push_back(op.second); + } + } + } + return GRAPH_SUCCESS; + } + + void SetNeedIteration(bool need_iteration) { + if (compute_graph_ == nullptr) { + GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null."); + return; + } + compute_graph_->SetNeedIteration(need_iteration); + } + + const std::string &GetName() const { + return name_; + } + + ComputeGraphPtr GetComputeGraph() const { + return compute_graph_; + } + + graphStatus RemoveEdge(NodePtr &src_node_ptr, const int32_t src_port_index, + NodePtr &dst_node_ptr, const int32_t dst_port_index) { + GE_CHECK_NOTNULL(src_node_ptr); + GE_CHECK_NOTNULL(dst_node_ptr); + + graphStatus res = GRAPH_FAILED; + if ((src_port_index == -1) && (dst_port_index == -1)) { + if (src_node_ptr->GetOutControlAnchor() == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveEdge: src node[%s] out control anchor is null.", src_node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + res = GraphUtils::RemoveEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor()); + if (res != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "RemoveEdge: remove control edge between [%s] and [%s]failed.", + src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } + + if (src_node_ptr->GetOutDataAnchor(src_port_index) == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveEdge: src node[%s] out data anchor[%d] is null.", + src_node_ptr->GetName().c_str(), src_port_index); + return GRAPH_FAILED; + } + + if (src_port_index != -1 && dst_port_index == -1) { + res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index), dst_node_ptr->GetInControlAnchor()); + if (res != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "RemoveEdge: remove data-control edge between [%s] and [%s]failed.", + src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } + + res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index), + dst_node_ptr->GetInDataAnchor(dst_port_index)); + if (res != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "RemoveEdge: remove data edge between [%s] and [%s] failed.", + src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; + } + + private: + std::string name_; + std::string output_name_; + std::map op_list_; + ComputeGraphPtr compute_graph_{nullptr}; +}; + +Graph::Graph(const std::string &name) { + impl_ = ComGraphMakeShared(name); + if (impl_ == nullptr) { + GELOGW("GraphImpl make shared failed, impl_ is nullptr"); + } +} + +Graph::Graph(const char *name) { + if (name != nullptr) { + std::string graph_name = name; + impl_ = ComGraphMakeShared(graph_name); + if (impl_ == nullptr) { + GELOGW("GraphImpl make shared failed, impl_ is nullptr."); + } + } else { + GELOGW("Graph name is nullptr."); + } +} + +graphStatus Graph::AddOp(const ge::Operator &op) { + GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, "AddOp failed: graph can not be used, impl is nullptr."); + return impl_->AddOp(op); +} + +graphStatus Graph::GetAllOpName(std::vector &op_name) const { + GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, + "GetAllOpName failed: graph can not be used, impl is nullptr."); + return impl_->GetAllOpName(op_name); +} + +graphStatus Graph::GetAllOpName(std::vector &names) const { + GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, + "GetAllOpName failed: graph can not be used, impl is nullptr."); + std::vector op_names; + if (impl_->GetAllOpName(op_names) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get all op name failed."); + return GRAPH_FAILED; + } + + for (auto &op_name : op_names) { + names.emplace_back(op_name.c_str()); + } + + return GRAPH_SUCCESS; +} + +graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const { + Operator op_find_op_def("NULL"); + op = op_find_op_def; + GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, + "FindOpByName failed: graph can not be used, impl is nullptr."); + return impl_->FindOpByName(name, op); +} + +graphStatus Graph::FindOpByName(const char *name, Operator &op) const { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "FindOpByName: name is nullptr."); + return GRAPH_FAILED; + } + Operator op_find_op_def("NULL"); + op = op_find_op_def; + GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, + "FindOpByName failed: graph can not be used, impl is nullptr."); + std::string op_name = name; + return impl_->FindOpByName(op_name, op); +} + +graphStatus Graph::FindOpByType(const string &type, std::vector &ops) const { + GE_CHECK_NOTNULL(impl_); + return impl_->FindOpByType(type, ops); +} + +graphStatus Graph::FindOpByType(const char *type, std::vector &ops) const { + if (type == nullptr) { + GELOGE(GRAPH_FAILED, "FindOpByType: name is nullptr."); + return GRAPH_FAILED; + } + GE_CHECK_NOTNULL(impl_); + std::string op_type = type; + return impl_->FindOpByType(op_type, ops); +} + +Graph &Graph::SetInputs(const vector &inputs) { + GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.") + GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0."); + (void)impl_->SetInputs(inputs); + return *this; +} + +Graph &Graph::SetOutputs(const vector &outputs) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr."); + return *this; + } + (void)impl_->SetOutputs(outputs); + return *this; +} + +Graph &Graph::SetOutputs(const std::vector>> &output_indexs) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr."); + return *this; + } + (void)impl_->SetOutputs(output_indexs); + return *this; +} + +Graph &Graph::SetOutputs(const std::vector> &outputs) { + GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetOutputs failed: graph can not be used, impl is nullptr.") + (void)impl_->SetOutputs(outputs); + return *this; +} + +Graph &Graph::SetOutputs(const std::vector> &outputs) { + GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetOutputs failed: graph can not be used, impl is nullptr.") + vector> graph_outputs; + for (auto &item : outputs) { + const char *name = item.second.GetString(); + if (name != nullptr) { + string output_name = name; + graph_outputs.emplace_back((std::pair(item.first, name))); + } else { + GELOGW("Output name is nullptr."); + } + } + + (void)impl_->SetOutputs(graph_outputs); + return *this; +} + +Graph &Graph::SetTargets(const vector &targets) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "SetTargets failed: graph can not be used, impl is nullptr."); + return *this; + } + (void)impl_->SetTargets(targets); + return *this; +} + +bool Graph::IsValid() const { + if (impl_ == nullptr) { + return false; + } + return impl_->IsValid(); +} + +void Graph::SetNeedIteration(bool need_iteration) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "Set need iteration failed, as impl is null."); + return; + } + impl_->SetNeedIteration(need_iteration); +} + +std::vector Graph::GetAllNodes() const { + std::vector graph_nodes; + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetAllNodes: graph can not be used, impl is nullptr."); + return graph_nodes; + } + + ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); + if (compute_graph_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetAllNodes: compute graph ptr is nullptr."); + return graph_nodes; + } + + for (auto &node : compute_graph_ptr->GetAllNodes()) { + GNode gnode = NodeAdapter::Node2GNode(node); + graph_nodes.emplace_back(gnode); + } + + return graph_nodes; +} + +std::vector Graph::GetDirectNode() const { + std::vector graph_nodes; + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetDirectNode: graph can not be used, impl is nullptr."); + return graph_nodes; + } + ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); + if (compute_graph_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "GetDirectNode: compute graph ptr is nullptr."); + return graph_nodes; + } + + for (auto &node : compute_graph_ptr->GetDirectNode()) { + GNode gnode = NodeAdapter::Node2GNode(node); + graph_nodes.emplace_back(gnode); + } + + return graph_nodes; +} + +graphStatus Graph::RemoveNode(GNode &node) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveNode: graph can not be used, impl is nullptr."); + return GRAPH_FAILED; + } + + NodePtr node_ptr = NodeAdapter::GNode2Node(node); + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveNode: gnode to node failed."); + return GRAPH_FAILED; + } + + if (node_ptr->GetOwnerComputeGraph() == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveNode: node[%s] is invalid.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); + if (compute_graph_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveNde: compute graph ptr is nullptr."); + return GRAPH_FAILED; + } + + ge::NodeUtils::UnlinkAll(*node_ptr); + if (GraphUtils::RemoveNodeWithoutRelink(compute_graph_ptr, node_ptr) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "RemoveNode: remove node[%s] failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + node_ptr->SetAnyOwnerComputeGraph(nullptr); + + return GRAPH_SUCCESS; +} + +graphStatus Graph::RemoveEdge(GNode &src_node, const int32_t src_port_index, + GNode &dst_node, const int32_t dst_port_index) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveEdge: graph can not be used, impl is nullptr."); + return GRAPH_FAILED; + } + + if ((src_port_index == -1) && (dst_port_index != -1)) { + GELOGE(GRAPH_FAILED, "RemoveEdge:src control anchor link to dst data anchor not exists."); + return GRAPH_FAILED; + } + + NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node); + if (src_node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveEdge: src gnode to node failed."); + return GRAPH_FAILED; + } + + NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node); + if (dst_node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveEdge: dst gnode to node failed."); + return GRAPH_FAILED; + } + + if (src_node_ptr->GetOwnerComputeGraph() == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveEdge: src node[%s] is invalid.", src_node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) { + GELOGE(GRAPH_FAILED, "RemoveEdge: dst node[%s] is invalid.", dst_node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + if (impl_->RemoveEdge(src_node_ptr, src_port_index, dst_node_ptr, dst_port_index) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "RemoveEdge: remove edge failed."); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +GNode Graph::AddNodeByOp(const Operator &op) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "AddNodeByOp: graph can not be used, impl is nullptr."); + return GNode(); + } + + std::shared_ptr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddNodeByOp: get op desc from op[%s] failed.", op.GetName().c_str()); + return GNode(); + } + + ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); + if (compute_graph_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "AddNodeByOp: compute graph ptr is nullptr."); + return GNode(); + } + + NodePtr node_ptr = compute_graph_ptr->AddNode(op_desc); + GNode gnode = NodeAdapter::Node2GNode(node_ptr); + + return gnode; +} + +graphStatus Graph::AddDataEdge(GNode &src_node, const int32_t src_port_index, + GNode &dst_node, const int32_t dst_port_index) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "AddDataEdge: graph can not be used, impl is nullptr."); + return GRAPH_FAILED; + } + + NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node); + if (src_node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "AddDataEdge: src gnode to node failed."); + return GRAPH_FAILED; + } + + NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node); + if (dst_node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "AddDataEdge: dst gnode to node failed."); + return GRAPH_FAILED; + } + + if (src_node_ptr->GetOwnerComputeGraph() == nullptr) { + GELOGE(GRAPH_FAILED, "AddDataEdge: src node[%s] is invalid.", src_node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) { + GELOGE(GRAPH_FAILED, "AddDataEdge: dst node[%s] is invalid.", dst_node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + graphStatus res = GraphUtils::AddEdge(src_node_ptr->GetOutDataAnchor(src_port_index), + dst_node_ptr->GetInDataAnchor(dst_port_index)); + if (res != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "AddDataEdge: Add data edge failed."); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +graphStatus Graph::AddControlEdge (GNode &src_node, GNode &dst_node) { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "AddControlEdge: graph can not be used, impl is nullptr."); + return GRAPH_FAILED; + } + + NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node); + if (src_node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "AddControlEdge: src gnode to node failed."); + return GRAPH_FAILED; + } + + NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node); + if (dst_node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "AddControlEdge: dst gnode to node failed."); + return GRAPH_FAILED; + } + + if (src_node_ptr->GetOwnerComputeGraph() == nullptr) { + GELOGE(GRAPH_FAILED, "AddControlEdge: src node[%s] is invalid.", src_node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) { + GELOGE(GRAPH_FAILED, "AddControlEdge: dst node[%s] is invalid.", dst_node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + graphStatus res = GraphUtils::AddEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor()); + if (res != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "AddControlEdge: Add control edge failed."); + return GRAPH_FAILED; + } + + return SUCCESS; +} + +GraphPtr Graph::ConstructFromInputs(const std::vector &inputs, const AscendString &name) { + const char* ascend_name = name.GetString(); + if (ascend_name == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "ConstructFromInputs: ascend string error."); + return nullptr; + } + + if (inputs.empty()) { + GELOGE(GRAPH_FAILED, "ConstructFromInputs: inputs size can not be 0."); + return nullptr; + } + + std::string graph_name = ascend_name; + ComputeGraphPtr compute_graph = GraphUtils::CreateGraphFromOperator(graph_name, inputs); + if (compute_graph == nullptr) { + GELOGE(GRAPH_FAILED, "ConstructFromInputs: create compute graph failed."); + return nullptr; + } + + compute_graph->SetInputSize(static_cast(inputs.size())); + GraphPtr graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(compute_graph); + if (graph_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "ConstructFromInputs: create graph from compute graph failed."); + return nullptr; + } + + return graph_ptr; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) { + GE_CHK_BOOL_EXEC_NOLOG(graph.IsValid(), return nullptr); + return graph.impl_->compute_graph_; +} + +graphStatus Graph::SaveToFile(const string &file_name) const { + Model model = Model(); + model.SetGraph(*this); + return model.SaveToFile(file_name); +} + +graphStatus Graph::SaveToFile(const char *file_name) const { + if (file_name == nullptr) { + GELOGE(GRAPH_FAILED, "SaveToFile: file name is nullptr."); + return GRAPH_FAILED; + } + + Model model = Model(); + model.SetGraph(*this); + std::string file = file_name; + return model.SaveToFile(file); +} + +graphStatus Graph::LoadFromFile(const string &file_name) { + Model model = Model(); + graphStatus ret = model.LoadFromFile(file_name); + if (ret != GRAPH_SUCCESS) { + return ret; + } + *this = model.GetGraph(); + return GRAPH_SUCCESS; +} + +graphStatus Graph::LoadFromFile(const char *file_name) { + if (file_name == nullptr) { + GELOGE(GRAPH_FAILED, "SaveToFile: file name is nullptr."); + return GRAPH_FAILED; + } + + Model model = Model(); + std::string file = file_name; + graphStatus ret = model.LoadFromFile(file); + if (ret != GRAPH_SUCCESS) { + return ret; + } + *this = model.GetGraph(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +const std::string &Graph::GetName() const { + return impl_->GetName(); +} + +graphStatus Graph::GetName(AscendString &name) const { + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "GetName: impl is nullptr."); + return GRAPH_FAILED; + } + std::string graph_name = impl_->GetName(); + name = AscendString(graph_name.c_str()); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph +GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { + GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph("")); + + auto name = compute_graph->GetName(); + auto graph = Graph(name); + + GE_CHK_BOOL_EXEC_NOLOG(graph.impl_ != nullptr, return graph); + graph.impl_->compute_graph_ = compute_graph; + + return graph; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GraphPtr +GraphUtils::CreateGraphPtrFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { + GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return nullptr); + + auto name = compute_graph->GetName(); + auto graph = ComGraphMakeShared(name); + GE_CHK_BOOL_EXEC_NOLOG(graph != nullptr, return nullptr); + GE_CHK_BOOL_EXEC_NOLOG(graph->impl_ != nullptr, return nullptr); + + graph->impl_->compute_graph_ = compute_graph; + + return graph; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { + GE_CHECK_NOTNULL(graph.impl_); + GE_CHECK_NOTNULL(graph.impl_->compute_graph_); + + graph.impl_->op_list_.clear(); + for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) { + graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node); + } + return SUCCESS; +} +} // namespace ge diff --git a/metadef/graph/graph.mk b/metadef/graph/graph.mk new file mode 100644 index 00000000..14f49889 --- /dev/null +++ b/metadef/graph/graph.mk @@ -0,0 +1,323 @@ +LOCAL_PATH := $(call my-dir) +include $(LOCAL_PATH)/stub/Makefile +COMMON_LOCAL_SRC_FILES := \ + ./proto/om.proto \ + ./proto/ge_ir.proto \ + ./proto/ge_onnx.proto \ + ./proto/insert_op.proto \ + ./proto/task.proto \ + ./proto/fwk_adapter.proto \ + ./proto/op_mapping_info.proto \ + ./proto/dump_task.proto \ + ./anchor.cc \ + ./ge_attr_value.cc \ + ./attr_value.cc \ + ./buffer.cc \ + ./compute_graph.cc \ + ./ascend_string.cc \ + ./gnode.cc \ + ./graph.cc \ + ./inference_context.cc \ + ./shape_refiner.cc \ + ./format_refiner.cc \ + ./ref_relation.cc \ + ./model.cc \ + ./model_serialize.cc \ + ./node.cc \ + ./op_desc.cc \ + ./operator.cc \ + ./operator_factory.cc \ + ./operator_factory_impl.cc \ + ./ge_attr_define.cc \ + ./ge_tensor.cc \ + ./detail/attributes_holder.cc \ + ./utils/anchor_utils.cc \ + ./utils/tuning_utils.cc \ + ./utils/graph_utils.cc \ + ./utils/ge_ir_utils.cc \ + ./utils/op_desc_utils.cc \ + ./utils/type_utils.cc \ + ./utils/tensor_utils.cc \ + ./tensor.cc \ + ./debug/graph_debug.cc \ + ./opsproto/opsproto_manager.cc \ + ../ops/op_imp.cpp \ + option/ge_context.cc \ + option/ge_local_context.cc \ + ./runtime_inference_context.cc \ + ./utils/node_utils.cc \ + ../third_party/transformer/src/axis_util.cpp \ + ../third_party/transformer/src/transfer_shape_according_to_format.cpp \ + ./utils/transformer_utils.cc \ + + +COMMON_LOCAL_C_INCLUDES := \ + proto/om.proto \ + proto/ge_ir.proto \ + proto_inner/ge_onnx.proto \ + proto/insert_op.proto \ + proto/task.proto \ + proto/fwk_adapter.proto \ + proto/op_mapping_info.proto \ + proto/dump_task.proto \ + inc \ + metadef/inc \ + graphengine/inc \ + inc/external \ + metadef/inc/external \ + graphengine/inc/external \ + metadef/inc/external/graph \ + metadef/inc/graph \ + metadef/inc/common \ + metadef \ + metadef/graph \ + third_party/protobuf/include \ + $(TOPDIR)metadef/third_party \ + $(TOPDIR)metadef/third_party/transformer/inc \ + libc_sec/include \ + ops/built-in/op_proto/inc \ + cann/ops/built-in/op_proto/inc \ + + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -Dgoogle=ascend_private -Wno-deprecated-declarations +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libascend_protobuf \ + libslog \ + liberror_manager \ + +LOCAL_STATIC_LIBRARIES := \ + libmmpa \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_SHARED_LIBRARY) + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := stub/libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -Wno-deprecated-declarations +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/attr_value.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/ascend_string.cc \ + ../../out/graph/lib64/stub/gnode.cc \ + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_SHARED_LIBRARY) + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := fwk_stub/libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -Wno-deprecated-declarations +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/attr_value.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/inference_context.cc \ + ../../out/graph/lib64/stub/ascend_string.cc \ + ../../out/graph/lib64/stub/gnode.cc \ + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_SHARED_LIBRARY) + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := libgraph + +LOCAL_CFLAGS += -O2 -Dgoogle=ascend_private -Wno-deprecated-declarations + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libascend_protobuf \ + libslog \ + liberror_manager \ + +LOCAL_STATIC_LIBRARIES := \ + libmmpa \ + +LOCAL_LDFLAGS := -lrt -ldl + +ifeq ($(device_os),android) +LOCAL_LDFLAGS := -ldl +endif + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_SHARED_LIBRARY) + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := stub/libgraph + +LOCAL_CFLAGS += -O2 + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/attr_value.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/ascend_string.cc \ + ../../out/graph/lib64/stub/gnode.cc \ + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +ifeq ($(device_os),android) +LOCAL_LDFLAGS := -ldl +endif + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_SHARED_LIBRARY) + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := fwk_stub/libgraph + +LOCAL_CFLAGS += -O2 + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/attr_value.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/inference_context.cc \ + ../../out/graph/lib64/stub/ascend_string.cc \ + ../../out/graph/lib64/stub/gnode.cc \ + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +ifeq ($(device_os),android) +LOCAL_LDFLAGS := -ldl +endif + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_SHARED_LIBRARY) + +# compile for ut/st +include $(CLEAR_VARS) +LOCAL_MODULE := libgraph + +LOCAL_CFLAGS += -Dgoogle=ascend_private + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libascend_protobuf \ + libslog \ + liberror_manager \ + +LOCAL_STATIC_LIBRARIES := \ + libmmpa \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_LLT_SHARED_LIBRARY) + + +#compiler for host static lib +include $(CLEAR_VARS) +LOCAL_MODULE := libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -Dgoogle=ascend_private +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_STATIC_LIBRARIES := \ + libascend_protobuf \ + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libslog \ + liberror_manager \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_STATIC_LIBRARY) + +#compiler for device static lib +include $(CLEAR_VARS) +LOCAL_MODULE := libgraph + +LOCAL_CFLAGS += -O2 -Dgoogle=ascend_private + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_STATIC_LIBRARIES := \ + libascend_protobuf \ + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libslog \ + liberror_manager \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_STATIC_LIBRARY) diff --git a/metadef/graph/inference_context.cc b/metadef/graph/inference_context.cc new file mode 100644 index 00000000..575384e4 --- /dev/null +++ b/metadef/graph/inference_context.cc @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "external/graph/inference_context.h" +#include "debug/ge_util.h" + +namespace ge { +class ShapeAndTypeImpl { + public: + ShapeAndTypeImpl() = default; + ~ShapeAndTypeImpl() = default; + + ShapeAndTypeImpl(const Shape &shape, DataType data_type) : shape_(shape), data_type_(data_type) {} + + Shape shape_; + DataType data_type_ = DT_UNDEFINED; +}; + +class InferenceContextImpl { + public: + InferenceContextImpl() = default; + ~InferenceContextImpl() = default; + + // For deliver to op in pair, help to support dynamic shape + std::vector marks_; + std::vector> input_handle_shapes_and_types_; + std::vector> output_handle_shapes_and_types_; +}; + +ShapeAndType::ShapeAndType() { shape_and_type_impl_ = ComGraphMakeShared(); } + +ShapeAndType::ShapeAndType(const Shape &shape, DataType data_type) { + shape_and_type_impl_ = ComGraphMakeShared(shape, data_type); +} + +void ShapeAndType::SetShape(const Shape &shape) { + if (shape_and_type_impl_ != nullptr) { + shape_and_type_impl_->shape_ = shape; + } +} + +void ShapeAndType::SetType(DataType data_type) { + if (shape_and_type_impl_ != nullptr) { + shape_and_type_impl_->data_type_ = data_type; + } +} + +Shape ShapeAndType::GetShape() const { + if (shape_and_type_impl_ != nullptr) { + return shape_and_type_impl_->shape_; + } + return Shape(); +} + +DataType ShapeAndType::GetDataType() const { + if (shape_and_type_impl_ != nullptr) { + return shape_and_type_impl_->data_type_; + } + return DT_UNDEFINED; +} + +InferenceContext::InferenceContext(std::unique_ptr &impl) { + inference_context_impl_ = std::move(impl); +} + +std::unique_ptr InferenceContext::Create() { + std::unique_ptr impl = + std::unique_ptr(new (std::nothrow) InferenceContextImpl()); + if (impl == nullptr) { + return nullptr; + } + + return std::unique_ptr(new (std::nothrow) InferenceContext(impl)); +} + +void InferenceContext::SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types) { + inference_context_impl_->input_handle_shapes_and_types_.swap(shapes_and_types); +} + +const std::vector> &InferenceContext::GetInputHandleShapesAndTypes() const { + return inference_context_impl_->input_handle_shapes_and_types_; +} + +const std::vector> &InferenceContext::GetOutputHandleShapesAndTypes() const { + return inference_context_impl_->output_handle_shapes_and_types_; +} + +void InferenceContext::SetOutputHandleShapesAndTypes(const std::vector> &shapes_and_types) { + inference_context_impl_->output_handle_shapes_and_types_ = shapes_and_types; +} + +void InferenceContext::SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types) { + inference_context_impl_->output_handle_shapes_and_types_.swap(shapes_and_types); +} + +void InferenceContext::SetMarks(const std::vector &marks) { inference_context_impl_->marks_ = marks; } + +void InferenceContext::SetMarks(const std::vector &marks) { + std::vector impl_marks; + for (const auto &mark : marks) { + if (mark.GetString() != nullptr) { + impl_marks.emplace_back(mark.GetString()); + } + } + inference_context_impl_->marks_ = impl_marks; +} + +const std::vector &InferenceContext::GetMarks() const { return inference_context_impl_->marks_; } + +void InferenceContext::GetMarks(std::vector &marks) const { + std::vector str_marks = inference_context_impl_->marks_; + for (auto &str_mark : str_marks) { + marks.emplace_back(str_mark.c_str()); + } +} +} // namespace ge diff --git a/metadef/graph/model.cc b/metadef/graph/model.cc new file mode 100644 index 00000000..534e3590 --- /dev/null +++ b/metadef/graph/model.cc @@ -0,0 +1,192 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/model.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "debug/ge_attr_define.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/model_serialize.h" +#include "mmpa/mmpa_api.h" +#include "utils/attr_utils.h" +#include "utils/ge_ir_utils.h" +#include "proto/ge_ir.pb.h" + +using google::protobuf::io::FileInputStream; +using google::protobuf::io::FileOutputStream; +using google::protobuf::io::ZeroCopyInputStream; + +namespace { +const int DEFAULT_VERSION = 1; +const int ACCESS_PERMISSION_BITS = 0400; +} // namespace + +namespace ge { +void Model::Init() { + (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); + (void)AttrUtils::SetInt(this, ATTR_MODEL_P2P_MEMORY_SIZE, 0); + (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); + (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); + (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); + (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); + (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); + version_ = 0; +} + +Model::Model() { + attrs_.InitDefault(); + Init(); +} + +Model::Model(const string &name, const string &custom_version) + : name_(name), version_(DEFAULT_VERSION), platform_version_(custom_version) { + attrs_.InitDefault(); + Init(); +} + +string Model::GetName() const { return name_; } + +void Model::SetName(const string &name) { name_ = name; } + +uint32_t Model::GetVersion() const { return version_; } + +string Model::GetPlatformVersion() const { return platform_version_; } + +void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; } + +Graph Model::GetGraph() const { return graph_; } + +graphStatus Model::Save(Buffer &buffer, bool is_dump) const { + ModelSerialize serialize; + buffer = serialize.SerializeModel(*this, is_dump); + return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +void Model::SetAttr(const ProtoAttrMapHelper &attrs) { attrs_ = attrs; } + +graphStatus Model::Load(const uint8_t *data, size_t len, Model &model) { + ModelSerialize serialize; + model = serialize.UnserializeModel(data, len); + return model.IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +graphStatus Model::SaveToFile(const string &file_name) const { + Buffer buffer; + if ((*this).Save(buffer) != GRAPH_SUCCESS) { + GE_LOGE("save to file fail."); + return GRAPH_FAILED; + } + // Write file + ge::proto::ModelDef ge_proto; + if (buffer.GetData() != nullptr) { + std::string str((const char *)buffer.GetData(), buffer.GetSize()); + if (!ge_proto.ParseFromString(str)) { + return GRAPH_FAILED; + } + char real_path[MMPA_MAX_PATH] = {0x00}; + if (strlen(file_name.c_str()) >= MMPA_MAX_PATH) { + return GRAPH_FAILED; + } + INT32 result = mmRealPath(file_name.c_str(), real_path, MMPA_MAX_PATH); + if (result != EN_OK) { + GELOGI("file %s does not exit, it will be created.", file_name.c_str()); + } + int fd = mmOpen2(real_path, M_WRONLY | M_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); + if (fd < 0) { + GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno)); + return GRAPH_FAILED; + } + bool ret = ge_proto.SerializeToFileDescriptor(fd); + if (!ret) { + GELOGE(GRAPH_FAILED, "SerializeToFileDescriptor failed"); + if (close(fd) != 0) { + GELOGE(GRAPH_FAILED, "close file descriptor fail."); + return GRAPH_FAILED; + } + return GRAPH_FAILED; + } + if (close(fd) != 0) { + GELOGE(GRAPH_FAILED, "close file descriptor fail."); + return GRAPH_FAILED; + } + if (!ret) { + GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed"); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +graphStatus Model::Load(ge::proto::ModelDef &model_def) { + ModelSerialize serialize; + *this = serialize.UnserializeModel(model_def); + return this->IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +bool Model::IsValid() const { return graph_.IsValid(); } + +graphStatus Model::LoadFromFile(const string &file_name) { + char real_path[MMPA_MAX_PATH] = {0x00}; + if (strlen(file_name.c_str()) >= MMPA_MAX_PATH) { + return GRAPH_FAILED; + } + INT32 result = mmRealPath(file_name.c_str(), real_path, MMPA_MAX_PATH); + if (result != EN_OK) { + GELOGE(GRAPH_FAILED, "file %s does not exit, can not load.", file_name.c_str()); + return GRAPH_FAILED; + } + int fd = mmOpen(real_path, M_RDONLY); + if (fd < 0) { + GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno)); + return GRAPH_FAILED; + } + + ge::proto::ModelDef model_def; + bool ret = model_def.ParseFromFileDescriptor(fd); + if (!ret) { + GELOGE(GRAPH_FAILED, "ParseFromFileDescriptor failed"); + if (mmClose(fd) != 0) { + GELOGE(GRAPH_FAILED, "close file descriptor fail."); + return GRAPH_FAILED; + } + return GRAPH_FAILED; + } + if (mmClose(fd) != 0) { + GELOGE(GRAPH_FAILED, "close file descriptor fail."); + return GRAPH_FAILED; + } + if (!ret) { + GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed"); + return GRAPH_FAILED; + } + return Load(model_def); +} + +ProtoAttrMapHelper Model::MutableAttrMap() { return attrs_; } + +ConstProtoAttrMapHelper Model::GetAttrMap() const { + return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); +} +} // namespace ge diff --git a/metadef/graph/model_serialize.cc b/metadef/graph/model_serialize.cc new file mode 100644 index 00000000..0c1fa636 --- /dev/null +++ b/metadef/graph/model_serialize.cc @@ -0,0 +1,767 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/model_serialize.h" +#include + +#include +#include + +#include "debug/ge_attr_define.h" +#include "debug/ge_log.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/detail/model_serialize_imp.h" +#include "proto/ge_ir.pb.h" +#include "utils/graph_utils.h" +#include "debug/ge_op_types.h" + +using std::map; +using std::string; + +namespace ge { +bool ModelSerializeImp::ParseNodeIndex(const string &node_index, string &node_name, int32_t &index) { + auto sep = node_index.rfind(":"); + if (sep == string::npos) { + GELOGW("separator is not found in node_index."); + return false; + } + node_name = node_index.substr(0, sep); + auto index_str = node_index.substr(sep + 1); + index = static_cast(std::strtol(index_str.c_str(), nullptr, 10)); + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeTensor(const ConstGeTensorPtr &tensor, + proto::TensorDef *tensor_proto) { + GE_CHK_BOOL_EXEC(tensor != nullptr, return false, "tensor is null."); + GE_CHK_BOOL_EXEC(tensor_proto != nullptr, return false, "tensor_proto is null."); + + if (tensor->tensor_def_.GetProtoMsg() != nullptr) { + *tensor_proto = *tensor->tensor_def_.GetProtoMsg(); + return true; + } + return false; +} + +bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_proto) { + GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is null."); + GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); + + op_def_proto->clear_input(); + // Inputs + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + if (in_data_anchor != nullptr) { + auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) { + op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + + std::to_string(peer_out_anchor->GetIdx())); + } else { + op_def_proto->add_input(""); + } + } + } + // Control edge + auto control_anchor = node->GetInControlAnchor(); + if (control_anchor != nullptr) { + auto peer_out_anchors = control_anchor->GetPeerOutControlAnchors(); + for (const auto &peer_out_anchor : peer_out_anchors) { + if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) { + op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":-1"); + } + } + } + return true; +} + +bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { + GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); + GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); + if (op_desc->op_def_.GetProtoMsg() != nullptr) { + *op_def_proto = *op_desc->op_def_.GetProtoMsg(); + //Delete unnecessary attr + if (is_dump) { + auto attr = op_def_proto->mutable_attr(); + attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF); + attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF); + attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF); + GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP), + attr->erase(ATTR_NAME_WEIGHTS)); + } + op_def_proto->clear_input_desc(); + op_def_proto->clear_output_desc(); + // Input descs + if (op_desc->GetAllInputsSize() > 0) { + auto size = static_cast(op_desc->GetAllInputsSize()); + for (uint32_t i = 0; i < size; i++) { + auto tensor_desc = op_desc->GetInputDescPtrDfault(i); + if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { + *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); + } + } + } + // Output descs + if (op_desc->GetOutputsSize() > 0) { + auto size = static_cast(op_desc->GetOutputsSize()); + for (uint32_t i = 0; i < size; i++) { + auto tensor_desc = op_desc->GetOutputDescPtr(i); + if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { + *op_def_proto->add_output_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); + } + } + } + + op_def_proto->set_id(op_desc->GetId()); + for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { + op_def_proto->add_subgraph_name(name); + } + OpDescToAttrDef(op_desc, op_def_proto); + } + return true; +} + +void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { + proto::AttrDef key_in; + proto::AttrDef value_in; + auto op_desc_attr = op_def_proto->mutable_attr(); + if (!op_desc->input_name_idx_.empty()) { + for (auto &item : op_desc->input_name_idx_) { + key_in.mutable_list()->add_s(item.first); + value_in.mutable_list()->add_i(item.second); + } + op_desc_attr->insert({"_input_name_key", key_in}); + op_desc_attr->insert({"_input_name_value", value_in}); + } + proto::AttrDef key_out; + proto::AttrDef value_out; + if (!op_desc->output_name_idx_.empty()) { + for (auto &item : op_desc->output_name_idx_) { + key_out.mutable_list()->add_s(item.first); + value_out.mutable_list()->add_i(item.second); + } + op_desc_attr->insert({"_output_name_key", key_out}); + op_desc_attr->insert({"_output_name_value", value_out}); + } + proto::AttrDef opt_input; + if (!op_desc->optional_input_names_.empty()) { + for (auto &item : op_desc->optional_input_names_) { + opt_input.mutable_list()->add_s(item); + } + op_desc_attr->insert({"_opt_input", opt_input}); + } +} + +bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { + if (node == nullptr || op_def_proto == nullptr) { + GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); + return false; + } + if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) { + GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); + return false; + } + if (SerializeEdge(node, op_def_proto)) { + return true; + } else { + return false; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, + proto::GraphDef *graph_proto, + bool is_dump) { + if (graph == nullptr || graph_proto == nullptr) { + GELOGE(GRAPH_FAILED, "Input para Invalid"); + return false; + } + graph_proto->set_name(graph->GetName()); + // Inputs + for (const auto &input : graph->GetInputNodes()) { + if (input != nullptr) { + graph_proto->add_input(input->GetName() + ":0"); + } + } + // Outputs + for (const auto &output : graph->GetGraphOutNodesInfo()) { + if (output.first != nullptr) { + graph_proto->add_output(output.first->GetName() + ":" + std::to_string(output.second)); + GELOGI("Add output to graph proto, node name:%s, index:%ld", output.first->GetName().c_str(), output.second); + } + } + if (graph->attrs_.GetProtoMsg() != nullptr) { + *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); + } + for (const auto &node : graph->GetDirectNode()) { + if (!SerializeNode(node, graph_proto->add_op(), is_dump)) { + if (node->GetOpDesc() != nullptr) { + GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); + } + return false; + } + } + return true; +} + +bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) { + if (model_proto == nullptr) { + GELOGE(GRAPH_FAILED, "model_proto para Invalid"); + return false; + } + model_proto->set_name(model.GetName()); + model_proto->set_custom_version(model.GetPlatformVersion()); + model_proto->set_version(model.GetVersion()); + if (model.attrs_.GetProtoMsg()) { + *model_proto->mutable_attr() = *model.attrs_.GetProtoMsg(); + } + auto &graph = model.graph_; + auto compute_graph = GraphUtils::GetComputeGraph(graph); + if (compute_graph == nullptr) { + GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); + return false; + } + if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) { + GELOGE(GRAPH_FAILED, "SerializeGraph fail"); + return false; + } + + for (auto subgraph : compute_graph->GetAllSubgraphs()) { + if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) { + GELOGE(GRAPH_FAILED, "Serialize subgraph failed"); + return false; + } + } + + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor( + GeTensorPtr &tensor, proto::TensorDef &tensor_proto) { + tensor = std::shared_ptr(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto)); + if (tensor == nullptr) { + GELOGE(GRAPH_FAILED, "tensor is nullptr"); + return false; + } else { + return true; + } +} + +void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, + std::vector &key_in, + std::vector &key_out, + std::vector &value_in, + std::vector &value_out, + std::vector &opt_input) { + if (!key_in.empty()) { + if (key_in.size() != value_in.size()) { + GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", + key_out.size(), value_in.size()); + } else { + for (uint32_t i = 0; i < key_in.size(); ++i) { + op_desc->input_name_idx_.insert(std::pair(key_in.at(i), value_in.at(i))); + } + } + } + if (!key_out.empty()) { + if (key_out.size() != value_out.size()) { + GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", + key_out.size(), value_out.size()); + } else { + for (uint32_t i = 0; i < key_out.size(); ++i) { + op_desc->output_name_idx_.insert(std::pair(key_out.at(i), value_out.at(i))); + } + } + } + if (!opt_input.empty()) { + for (const auto &i : opt_input) { + op_desc->optional_input_names_.insert(i); + } + } +} + + +bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) { + std::vector opt_input; + std::vector key_in; + std::vector value_in; + if (op_def_proto.attr().count("_opt_input") > 0) { + auto &name_list = op_def_proto.attr().at("_opt_input").list(); + for (const auto &item_s : name_list.s()) { + opt_input.push_back(item_s); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_opt_input"); + } + if (op_def_proto.attr().count("_input_name_key") > 0) { + auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list(); + for (const auto &item_s : output_name_key_list.s()) { + key_in.push_back(item_s); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_input_name_key"); + } + if (op_def_proto.attr().count("_input_name_value") > 0) { + auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list(); + for (const auto &item_i : input_name_value_list.i()) { + value_in.push_back(static_cast(item_i)); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_input_name_value"); + } + std::vector key_out; + std::vector value_out; + if (op_def_proto.attr().count("_output_name_key") > 0) { + auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list(); + for (const auto &item_s : output_name_key_list.s()) { + key_out.push_back(item_s); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_output_name_key"); + } + if (op_def_proto.attr().count("_output_name_value") > 0) { + auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list(); + for (const auto &item_i : output_name_value_list.i()) { + value_out.push_back(static_cast(item_i)); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_output_name_value"); + } + + op_desc = std::shared_ptr(new (std::nothrow) OpDesc(protobuf_owner_, &op_def_proto)); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr."); + + // Input tensor + for (auto &input_desc : *op_def_proto.mutable_input_desc()) { + std::shared_ptr temp_value = + std::shared_ptr(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc)); + GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); + op_desc->inputs_desc_.push_back(temp_value); + } + // Output tensor + for (auto &output_desc : *op_def_proto.mutable_output_desc()) { + std::shared_ptr temp_value = + std::shared_ptr(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc)); + GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); + op_desc->outputs_desc_.push_back(temp_value); + } + + op_desc->SetId(op_def_proto.id()); + uint32_t graph_index = 0; + for (const std::string &name : op_def_proto.subgraph_name()) { + op_desc->AddSubgraphName(name); + op_desc->SetSubgraphInstanceName(graph_index++, name); + } + + // insert name index by key and value + AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input); + + return true; +} + +bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) { + GE_RT_FALSE_CHECK_NOTNULL(graph); + OpDescPtr op_desc = nullptr; + if (!UnserializeOpDesc(op_desc, op_def_proto)) { + GELOGW("UnserializeOpDesc error."); + } + + NodePtr node = graph->AddNode(op_desc, op_desc->GetId()); + GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr."); + + // Inputs + int dst_index = 0; + for (const auto &input : op_def_proto.input()) { + string node_name; + int32_t index = 0; + if (ParseNodeIndex(input, node_name, index)) { + node_input_node_names_.push_back(NodeNameNodeReq{node_name, index, node, dst_index, op_def_proto.name()}); + } + if (index >= 0) { + dst_index++; + } + } + node_map_[op_def_proto.name()] = node; + return true; +} + +bool ModelSerializeImp::HandleNodeNameRef() { + // Edges + for (auto &item : node_input_node_names_) { + auto src_node_it = node_map_.find(item.src_node_name); + if (src_node_it == node_map_.end()) { + GELOGE(GRAPH_FAILED, "cannot find node %s", item.src_node_name.c_str()); + return false; + } + GE_IF_BOOL_EXEC(src_node_it->second == nullptr || item.dst_node == nullptr, continue); + if (item.src_out_index >= 0) { + auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index); + auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); + if (src_anchor == nullptr || dst_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "get anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, + item.dst_node_name.c_str(), item.dst_in_index); + return false; + } + GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 + } else { + // Control edge + auto src_anchor = src_node_it->second->GetOutControlAnchor(); + auto dst_anchor = item.dst_node->GetInControlAnchor(); + if (src_anchor != nullptr && dst_anchor != nullptr) { + GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 + } + } + } + // Graph input + for (auto &item : graph_input_node_names_) { + auto node_it = node_map_.find(item.node_name); + if (node_it == node_map_.end()) { + GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str()); + return false; + } + GE_IF_BOOL_EXEC(item.graph == nullptr, continue); + auto ret = item.graph->AddInputNode(node_it->second); + if (ret == nullptr) { + return false; + } + } + // Graph output + for (auto &item : graph_output_node_names_) { + auto node_it = node_map_.find(item.node_name); + if (node_it == node_map_.end()) { + GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str()); + return false; + } + + GE_IF_BOOL_EXEC(item.graph == nullptr, continue); + auto ret = item.graph->AddOutputNodeByIndex(node_it->second, item.index); + GELOGI("node name:%s, item.index:%ld", node_it->second->GetName().c_str(), item.index); + if (ret == nullptr) { + GELOGE(GRAPH_FAILED, "AddOutputNode failed."); + return false; + } + } + node_input_node_names_.clear(); + graph_input_node_names_.clear(); + graph_output_node_names_.clear(); + node_map_.clear(); + return true; +} + +bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map &subgraphs) { + std::queue all_graphs; + all_graphs.emplace(compute_graph); + while (!all_graphs.empty()) { + ComputeGraphPtr graph = all_graphs.front(); + all_graphs.pop(); + + for (const NodePtr &node : graph->GetDirectNode()) { + const OpDescPtr op_desc = node->GetOpDesc(); + for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { + auto it = subgraphs.find(name); + if (it == subgraphs.end()) { + GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", + op_desc->GetName().c_str(), name.c_str(), subgraphs.size()); + return false; + } + + ComputeGraphPtr &subgraph = it->second; + subgraph->SetParentGraph(graph); + subgraph->SetParentNode(node); + compute_graph->AddSubgraph(subgraph->GetName(), subgraph); + all_graphs.emplace(subgraph); + } + } + } + + return true; +} + +bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) { + model.name_ = model_proto.name(); + model.version_ = model_proto.version(); + model.platform_version_ = model_proto.custom_version(); + model.attrs_ = ProtoAttrMapHelper(protobuf_owner_, model_proto.mutable_attr()); + + auto &graphs_proto = *model_proto.mutable_graph(); + if (!graphs_proto.empty()) { + auto &graph_proto = graphs_proto[0]; + ComputeGraphPtr compute_graph_ptr; + if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) { + model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); + } + + // 0 is main graph, following is subgraph. + map subgraphs; + for (int idx = 1; idx < graphs_proto.size(); ++idx) { + ComputeGraphPtr subgraph; + ModelSerializeImp impl; + if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) { + GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed"); + return false; + } + + if (!impl.HandleNodeNameRef()) { + GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); + return false; + } + + subgraphs[subgraph->GetName()] = subgraph; + } + + if (!RebuildOwnership(compute_graph_ptr, subgraphs)) { + GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed"); + return false; + } + } + + if (!HandleNodeNameRef()) { + GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); + return false; + } + return true; +} + +bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) { + graph = ComGraphMakeShared(graph_proto.name()); + if (graph == nullptr) { + GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); + return false; + } + + // Inputs + for (auto input : graph_proto.input()) { + string node_name; + int32_t index; + if (ParseNodeIndex(input, node_name, index)) { + graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); + } + } + // Outputs + for (auto output : graph_proto.output()) { + string node_name; + int32_t index; + if (ParseNodeIndex(output, node_name, index)) { + graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); + } + } + graph->attrs_ = ProtoAttrMapHelper(protobuf_owner_, graph_proto.mutable_attr()); + for (auto &op_def_proto : *graph_proto.mutable_op()) { + if (!UnserializeNode(graph, op_def_proto)) { + GELOGE(GRAPH_FAILED, "UnserializeNode fail"); + return false; + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph, + proto::GraphDef &graph_proto) { + if (!UnserializeGraphWithoutEdge(graph, graph_proto)) { + GELOGW("UnserializeGraphWithoutEdge fail"); + } + if (!HandleNodeNameRef()) { + GELOGE(GRAPH_FAILED, "Link Anchor or set graph input or output fail"); + return false; + } + return true; +} + +bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::Message *proto) { + GE_CHK_BOOL_EXEC(data != nullptr, return false, "data is null."); + GE_CHK_BOOL_EXEC(proto != nullptr, return false, "proto is null."); + + google::protobuf::io::CodedInputStream coded_stream(data, len); + // 2048M -1 + coded_stream.SetTotalBytesLimit(INT32_MAX, -1); + if (!proto->ParseFromCodedStream(&coded_stream)) { + GELOGE(GRAPH_FAILED, "ReadProtoFromBinaryFile failed len %zu", len); + return false; + } + return true; +} + +Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) { + proto::ModelDef model_def; + ModelSerializeImp imp; + if (!imp.SerializeModel(model, &model_def, is_dump)) { + return Buffer(); + } +#if !defined(__ANDROID__) && !defined(ANDROID) + Buffer buffer(model_def.ByteSizeLong()); +#else + Buffer buffer(model_def.ByteSize()); +#endif + GE_CHK_BOOL_ONLY_LOG(buffer.GetSize() != 0, "get size failed"); + GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); + auto ret = model_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); + if (ret != true) { + GELOGW("serialize to array fail."); + } + return buffer; +} + +size_t ModelSerialize::GetSerializeModelSize(const Model &model) { + proto::ModelDef model_def; + ModelSerializeImp imp; + if (!imp.SerializeModel(model, &model_def)) { + return 0; + } +#if !defined(__ANDROID__) && !defined(ANDROID) + return model_def.ByteSizeLong(); +#else + return model_def.ByteSize(); +#endif +} + +Model ModelSerialize::UnserializeModel(const uint8_t *data, size_t len) { + if (data == nullptr) { + GELOGE(GRAPH_FAILED, "data is nullptr"); + return Model(); + } + + std::shared_ptr model_proto_ptr; + model_proto_ptr = ComGraphMakeShared(); + if (model_proto_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed"); + return Model(); + } + + auto &model_proto = *model_proto_ptr; + if (!ReadProtoFromBinaryFile(data, len, &model_proto)) { + GELOGE(GRAPH_FAILED, "ParseFromArray fail"); + return Model(); + } + + Model model; + ModelSerializeImp imp; + imp.SetProtobufOwner(model_proto_ptr); + if (!imp.UnserializeModel(model, model_proto)) { + GELOGE(GRAPH_FAILED, "Unserialize Model fail"); + return Model(); + } + return model; +} + +Model ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def) { + std::shared_ptr model_def_ptr = ComGraphMakeShared(model_def); + GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, return Model(), "mode_def make shared failed"); + + ModelSerializeImp imp; + imp.SetProtobufOwner(model_def_ptr); + Model model; + if (!imp.UnserializeModel(model, *model_def_ptr)) { + GELOGE(GRAPH_FAILED, "Unserialize Model fail"); + return Model(); + } + return model; +} + +Buffer ModelSerialize::SerializeGraph(const ComputeGraphPtr &graph) { + proto::GraphDef graph_def; + ModelSerializeImp imp; + if (!imp.SerializeGraph(graph, &graph_def)) { + return Buffer(); + } +#if !defined(__ANDROID__) && !defined(ANDROID) + Buffer buffer(graph_def.ByteSizeLong()); +#else + Buffer buffer(graph_def.ByteSize()); +#endif + GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed"); + GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); + auto ret = graph_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); + if (ret != true) { + GE_LOGE("serialize to array fail."); + } + + return buffer; +} + +ComputeGraphPtr ModelSerialize::UnserializeGraph(const uint8_t *data, size_t len) { + if (data == nullptr) { + GELOGE(GRAPH_FAILED, "data is nullptr"); + return nullptr; + } + + std::shared_ptr graph_proto_ptr; + graph_proto_ptr = ComGraphMakeShared(); + if (graph_proto_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); + return nullptr; + } + proto::GraphDef &graph_proto = *graph_proto_ptr; + if (!ReadProtoFromBinaryFile(data, len, &graph_proto)) { + GELOGE(GRAPH_FAILED, "ParseFromArray fail"); + return nullptr; + } + + ComputeGraphPtr graph; + ModelSerializeImp imp; + imp.SetProtobufOwner(graph_proto_ptr); + if (!imp.UnserializeGraph(graph, graph_proto)) { + return nullptr; + } + return graph; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer ModelSerialize::SerializeOpDesc(const ConstOpDescPtr &op_desc) { + proto::OpDef op_def; + ModelSerializeImp imp; + if (!imp.SerializeOpDesc(op_desc, &op_def)) { + return Buffer(); + } +#if !defined(__ANDROID__) && !defined(ANDROID) + Buffer buffer(op_def.ByteSizeLong()); +#else + Buffer buffer(op_def.ByteSize()); +#endif + GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed"); + GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); + auto ret = op_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); + if (ret != true) { + GE_LOGE("serialize to array fail."); + } + + return buffer; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr ModelSerialize::UnserializeOpDesc(const uint8_t *data, + size_t len) { + if (data == nullptr) { + GELOGE(GRAPH_FAILED, "data is nullptr"); + return nullptr; + } + + std::shared_ptr op_def_ptr; + op_def_ptr = ComGraphMakeShared(); + if (op_def_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); + return nullptr; + } + proto::OpDef &op_def = *op_def_ptr; + if (!ReadProtoFromBinaryFile(data, len, &op_def)) { + GELOGE(GRAPH_FAILED, "ParseFromArray fail"); + return nullptr; + } + + OpDescPtr op_desc; + ModelSerializeImp imp; + imp.SetProtobufOwner(op_def_ptr); + if (!imp.UnserializeOpDesc(op_desc, op_def)) { + GELOGW("UnserializeOpDesc error."); + } + return op_desc; +} +} // namespace ge diff --git a/metadef/graph/module.mk b/metadef/graph/module.mk new file mode 100644 index 00000000..1e00b7fc --- /dev/null +++ b/metadef/graph/module.mk @@ -0,0 +1,3 @@ +LOCAL_PATH := $(call my-dir) + +include $(LOCAL_PATH)/graph.mk diff --git a/metadef/graph/node.cc b/metadef/graph/node.cc new file mode 100644 index 00000000..af84edfd --- /dev/null +++ b/metadef/graph/node.cc @@ -0,0 +1,883 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/node.h" +#include +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "external/graph/operator_factory.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_tensor.h" +#include "graph/operator_factory_impl.h" +#include "graph/shape_refiner.h" +#include "utils/ge_ir_utils.h" +#include "utils/node_utils.h" +#include "utils/op_desc_utils.h" +#include "common/util/error_manager/error_manager.h" + +using std::string; +using std::vector; + +namespace ge { +Node::Node(const OpDescPtr &op, const ComputeGraphPtr &owner_graph) + : op_(op), + owner_graph_(owner_graph), + in_data_anchors_(), + out_data_anchors_(), + in_control_anchor_(nullptr), + out_control_anchor_(nullptr), + attrs_(), + has_init_(false) { + anchor_status_updated_ = false; +} + +Node::~Node() { + for (const auto &in_data_anchor : in_data_anchors_) { + if (in_data_anchor != nullptr) { + in_data_anchor->UnlinkAll(); + } + } + for (const auto &out_data_anchor : out_data_anchors_) { + if (out_data_anchor != nullptr) { + out_data_anchor->UnlinkAll(); + } + } + if (in_control_anchor_ != nullptr) { + in_control_anchor_->UnlinkAll(); + } + if (out_control_anchor_ != nullptr) { + out_control_anchor_->UnlinkAll(); + } +} + +graphStatus Node::Init() { + if (has_init_) { + return GRAPH_SUCCESS; + } + GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); + size_t size = op_->GetAllInputsSize(); + for (size_t i = 0; i < size; i++) { + std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), i); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + in_data_anchors_.push_back(anchor); + } + size = op_->GetOutputsSize(); + for (size_t i = 0; i < size; i++) { + std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), i); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Current out_data_anchor is null, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + out_data_anchors_.push_back(anchor); + } + in_control_anchor_ = ComGraphMakeShared(shared_from_this(), -1); + out_control_anchor_ = ComGraphMakeShared(shared_from_this(), -1); + if (in_control_anchor_ == nullptr || out_control_anchor_ == nullptr) { + GELOGE(GRAPH_FAILED, "Current in_control_anchor or out_control_anchor is null, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + has_init_ = true; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetName() const { + GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr"); + return op_->GetName(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetType() const { + GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr"); + return op_->GetType(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAttrsAreEqual(const Node &r_node) const { + const auto &attr_map = this->attrs_; + const auto &r_attr_map = r_node.attrs_; + // 1.Verify node's map size + if (attr_map.size() != r_attr_map.size()) { + GELOGE(GRAPH_FAILED, "Size of node's attr map verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + // 2.Verify node's map key, verify values is temporarily not implemented + for (const auto &it : attr_map) { + if (r_attr_map.count(it.first) == 0) { + GELOGE(GRAPH_FAILED, "Key of node's attr map verify failed, node name: %s key name: %s.", this->GetName().c_str(), + it.first.c_str()); + return false; + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeMembersAreEqual(const Node &r_node) const { + return ((((this->op_ != nullptr) && (r_node.op_ != nullptr) && (IsEqual(*(this->op_), *(r_node.op_), "node.op_"))) || + ((this->op_ == nullptr) && (r_node.op_ == nullptr))) && + IsEqual(this->has_init_, r_node.has_init_, "node.has_init_") && + IsEqual(this->anchor_status_updated_, r_node.anchor_status_updated_, "node.anchor_status_updated_") && + IsEqual(this->send_event_id_list_, r_node.send_event_id_list_, "node.send_event_id_list_") && + IsEqual(this->recv_event_id_list_, r_node.recv_event_id_list_, "node.recv_event_id_list_")); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(const AnchorPtr &left_anchor, + const AnchorPtr &right_anchor, + size_t i) const { + GE_IF_BOOL_EXEC(left_anchor == nullptr, GELOGE(GRAPH_FAILED, "left_anchor is null."); return false); + GE_IF_BOOL_EXEC(right_anchor == nullptr, GELOGE(GRAPH_FAILED, "right_anchor is null."); return false); + + const auto anchor_peer_size = left_anchor->GetPeerAnchors().size(); + const auto right_anchor_peer_size = right_anchor->GetPeerAnchors().size(); + // Firstly, verify anchor's peer anchors size equal or not + if (anchor_peer_size != right_anchor_peer_size) { + GELOGE(GRAPH_FAILED, + "Size of anchor's peer anchors verify failed, node name: %s " + "anchor_peer_size [%zu] is different form [%zu] at index [%zu].", + this->GetName().c_str(), anchor_peer_size, right_anchor_peer_size, i); + return false; + } + // Secondly, verify anchor's peer anchor owner node equal or not + for (size_t j = 0; j < anchor_peer_size; j++) { + const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); + const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); + if (peer_node == nullptr || r_peer_node == nullptr) { + GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", + this->GetName().c_str(), i, j); + return false; + } + // Determine the connection relationship by linking the node's name + if (peer_node->GetName() != r_peer_node->GetName()) { + GELOGE(GRAPH_FAILED, + "anchor's peer node name verify failed, node name: %s index[%zu]" + "peer node name %s is different from %s at index [%zu].", + this->GetName().c_str(), i, peer_node->GetName().c_str(), r_peer_node->GetName().c_str(), j); + return false; + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeInConnectsAreEqual(const Node &r_node) const { + // 1.Verify all in data and control anchors size + const auto in_data_anchor_size = this->GetAllInDataAnchors().size(); + const auto r_in_data_anchor_size = r_node.GetAllInDataAnchors().size(); + if (in_data_anchor_size != r_in_data_anchor_size) { + GELOGE(GRAPH_FAILED, "Size of node's in data anchors verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + const auto l_in_anchors = this->GetAllInAnchors(); + const auto r_in_anchors = r_node.GetAllInAnchors(); + // Data anchors size equal, all anchors size not equal, means control anchor size not equal + const auto in_control_anchor_size = l_in_anchors.size() - in_data_anchor_size; + const auto r_in_control_anchor_size = r_in_anchors.size() - r_in_data_anchor_size; + if (in_control_anchor_size != r_in_control_anchor_size) { + GELOGE(GRAPH_FAILED, "Size of node's in control anchors verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + // 2.Verify all in data and control anchors connect info + for (size_t i = 0; i < this->GetAllInAnchors().size(); i++) { + // Verify data anchors + if (i < in_data_anchor_size) { + const auto &in_anchor = l_in_anchors.at(i); + const auto &r_in_anchor = r_in_anchors.at(i); + if (!(NodeAnchorIsEqual(in_anchor, r_in_anchor, i))) { + GELOGE(GRAPH_FAILED, "Node's in data control anchor verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + } else { + // Verify control anchors + const auto &in_control_anchor = l_in_anchors.at(i); + const auto &r_in_control_anchor = r_in_anchors.at(i); + if (!(NodeAnchorIsEqual(in_control_anchor, r_in_control_anchor, i - in_data_anchor_size))) { + GELOGE(GRAPH_FAILED, "Node's in control anchor verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeOutConnectsAreEqual(const Node &r_node) const { + // 1.Verify all out data and control anchors size + const auto l_out_data_anchors = this->GetAllOutDataAnchors(); + const auto r_out_data_anchors = r_node.GetAllOutDataAnchors(); + const auto out_data_anchor_size = l_out_data_anchors.size(); + const auto r_out_data_anchor_size = r_out_data_anchors.size(); + if (out_data_anchor_size != r_out_data_anchor_size) { + GELOGE(GRAPH_FAILED, "Size of node's out data anchors verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + const auto l_out_anchors = this->GetAllOutAnchors(); + const auto r_out_anchors = r_node.GetAllOutAnchors(); + // Data anchors size equal, all anchors size not equal, means control anchor size not equal + const auto out_control_anchor_size = l_out_anchors.size() - out_data_anchor_size; + const auto r_out_control_anchor_size = r_out_anchors.size() - r_out_data_anchor_size; + if (out_control_anchor_size != r_out_control_anchor_size) { + GELOGE(GRAPH_FAILED, "Size of node's out control anchors verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + + // 2.Verify all out data and control anchors connect info + for (size_t i = 0; i < this->GetAllOutAnchors().size(); i++) { + // Verify data anchors + if (i < out_data_anchor_size) { + const auto &out_anchor = l_out_data_anchors.at(i); + const auto &r_out_anchor = r_out_data_anchors.at(i); + if (!(NodeAnchorIsEqual(out_anchor, r_out_anchor, i))) { + GELOGE(GRAPH_FAILED, "Node's out data control anchor verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + } else { + // Verify control anchors + const auto &out_control_anchor = l_out_anchors.at(i); + const auto &r_out_control_anchor = r_out_anchors.at(i); + if (!(NodeAnchorIsEqual(out_control_anchor, r_out_control_anchor, i - out_data_anchor_size))) { + GELOGE(GRAPH_FAILED, "Node's out control anchor verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::operator==(const Node &r_node) const { + return (NodeMembersAreEqual(r_node) && NodeAttrsAreEqual(r_node) && NodeInConnectsAreEqual(r_node) && + NodeOutConnectsAreEqual(r_node)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const NodePtr &input_node) { + // This function is deprecated, please use other two overloaded functions + GE_CHECK_NOTNULL(input_node); + // Input_node ---> this + auto out_anchors = input_node->GetAllOutDataAnchors(); + if (out_anchors.size() != 1) { + GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size()); + return GRAPH_PARAM_INVALID; + } + GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); + auto op_desc = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + if (op_->AddInputDesc(op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "add input desc failed."); + return GRAPH_FAILED; + } + std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); + return GRAPH_FAILED; + } + in_data_anchors_.push_back(anchor); + (void) out_anchors.at(0)->LinkTo(in_data_anchors_.back()); + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const uint32_t &index, + NodePtr input_node) { + GE_CHECK_NOTNULL(input_node); + // Input_node ---> this + auto out_anchors = input_node->GetAllOutDataAnchors(); + if (out_anchors.size() != 1) { + GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size()); + return GRAPH_PARAM_INVALID; + } + + GE_CHECK_NOTNULL(op_); + auto op_desc = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + if (op_->AddInputDesc(index, op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "add input desc failed."); + return GRAPH_FAILED; + } + + if (index < GetAllInDataAnchors().size()) { + (void) out_anchors.at(0)->LinkTo(in_data_anchors_[index]); + } else { + std::shared_ptr + anchor = ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); + return GRAPH_FAILED; + } + in_data_anchors_.push_back(anchor); + (void) out_anchors.at(0)->LinkTo(in_data_anchors_.back()); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFromForParse(const NodePtr &input_node) { + // This function is used for ParseWeights. + GE_CHECK_NOTNULL(input_node); + // Input_node ---> this + auto out_anchors = input_node->GetAllOutDataAnchors(); + if (out_anchors.size() != 1) { + GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size()); + return GRAPH_PARAM_INVALID; + } + + std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, make anchor failed", out_anchors.size()); + return GRAPH_FAILED; + } + in_data_anchors_.push_back(anchor); + (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const string &name, NodePtr input_node) { + GE_CHECK_NOTNULL(input_node); + // Input_node ---> this + auto out_anchors = input_node->GetAllOutDataAnchors(); + if (out_anchors.size() != 1) { + GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size()); + return GRAPH_PARAM_INVALID; + } + + GE_CHECK_NOTNULL(op_); + auto input_op_desc = input_node->GetOpDesc(); + GE_CHECK_NOTNULL(input_op_desc); + auto index = op_->GetInputIndexByName(name); + if (index != -1) { + if (index >= static_cast(in_data_anchors_.size())) { + GELOGE(GRAPH_FAILED, "op %s get input name %s 's index %d is illegal.", + op_->GetName().c_str(), name.c_str(), index); + return GRAPH_FAILED; + } + (void) out_anchors.at(0)->LinkTo(in_data_anchors_[index]); + } else { + std::shared_ptr + anchor = ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "in_data_anchors_size is:%zu, malloc shared_ptr failed.", in_data_anchors_.size()); + return GRAPH_FAILED; + } + in_data_anchors_.push_back(anchor); + (void) out_anchors.at(0)->LinkTo(in_data_anchors_.back()); + } + if (op_->AddInputDesc(name, input_op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "add input desc failed."); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr Node::GetOwnerComputeGraph() const { + return owner_graph_.lock(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::SetOwnerComputeGraph(const ComputeGraphPtr &graph) { + if (graph == nullptr) { + return GRAPH_PARAM_INVALID; + } + owner_graph_ = graph; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::SetAnyOwnerComputeGraph(const ComputeGraphPtr &graph) { + owner_graph_ = graph; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllInDataAnchors() const { + return Vistor(shared_from_this(), in_data_anchors_); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllOutDataAnchors() const { + return Vistor(shared_from_this(), out_data_anchors_); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllInDataAnchorsSize() const { + return in_data_anchors_.size(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllOutDataAnchorsSize() const { + return out_data_anchors_.size(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllInAnchors() const { + std::vector vec; + // Push back in_data_anchors_ + for (const auto &in_anchor_iter : Vistor(shared_from_this(), in_data_anchors_)) { + auto in_anchor = Anchor::DynamicAnchorCast(in_anchor_iter); + if (in_anchor != nullptr) { + vec.push_back(in_anchor); + } + } + // Push back in_control_anchor_ + if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || + (in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { + auto in_anchor = Anchor::DynamicAnchorCast(in_control_anchor_); + if (in_anchor != nullptr) { + vec.push_back(in_anchor); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllOutAnchors() const { + std::vector vec; + // Push back out_data_anchors_ + for (const auto &out_anchor_iter : Vistor(shared_from_this(), out_data_anchors_)) { + auto out_anchor = Anchor::DynamicAnchorCast(out_anchor_iter); + if (out_anchor != nullptr) { + vec.push_back(out_anchor); + } + } + // Push back out_control_anchor_ + if (out_control_anchor_->GetPeerInControlAnchors().size() > 0 || + out_control_anchor_->GetPeerInDataAnchors().size() > 0) { + auto out_anchor = Anchor::DynamicAnchorCast(out_control_anchor_); + if (out_anchor != nullptr) { + vec.push_back(out_anchor); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { + if (idx < 0 || idx >= static_cast(in_data_anchors_.size())) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19019", {"opname", "index", "anchorname", "optype"}, + {GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); + GELOGE(GRAPH_FAILED, + "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", + GetName().c_str(), idx, GetType().c_str()); + return nullptr; + } else { + return in_data_anchors_[idx]; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { + // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ + if (idx < -1 || idx >= static_cast(in_data_anchors_.size())) { + GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); + return nullptr; + } else { + // Return control anchor + if (idx == -1) { + auto in_anchor = Anchor::DynamicAnchorCast(in_control_anchor_); + return in_anchor; + } + // Return data anchor + return in_data_anchors_[idx]; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { + // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ + if (idx < -1 || idx >= static_cast(out_data_anchors_.size())) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19019", {"opname", "index", "anchorname", "optype"}, + {GetName().c_str(), std::to_string(idx), "out_anchor", GetType().c_str(), }); + GELOGE(GRAPH_FAILED, + "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); + return nullptr; + } else { + // Return control anchor + if (idx == -1) { + auto out_anchor = Anchor::DynamicAnchorCast(out_control_anchor_); + return out_anchor; + } + // Return data anchor + return out_data_anchors_[idx]; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { + if (idx < 0 || idx >= static_cast(out_data_anchors_.size())) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19019", {"opname", "index", "anchorname", "optype"}, + {GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); + GELOGE(GRAPH_FAILED, + "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", + GetName().c_str(), idx, GetType().c_str()); + return nullptr; + } else { + return out_data_anchors_[idx]; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchorPtr Node::GetInControlAnchor() const { + return in_control_anchor_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchorPtr Node::GetOutControlAnchor() const { + return out_control_anchor_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInNodes() const { + std::vector vec; + for (const auto &in_anchor : in_data_anchors_) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + continue; + } + auto node = out_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + if (in_control_anchor_ != nullptr) { + if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { + return Node::Vistor(shared_from_this(), vec); + } + + auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); + for (const auto &out_anchor : peer_out_anchors) { + GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr"); + auto node = out_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + + auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); + for (const auto &out_control_anchor : peer_out_control_anchors) { + GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, + "in_control_anchor_ peer out control anchors is nullptr"); + auto node = out_control_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::IsAllInNodesSeen( + std::unordered_set &nodes_seen) const { + for (const auto &in_anchor : in_data_anchors_) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + continue; + } + auto node = out_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { + continue; + } + if (nodes_seen.count(node.get()) == 0) { + return false; + } + } + + if (in_control_anchor_ != nullptr) { + if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { + return true; + } + auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); + for (const auto &out_control_anchor : peer_out_control_anchors) { + GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, "out_control_anchor is nullptr"); + auto node = out_control_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { + continue; + } + if (nodes_seen.count(node.get()) == 0) { + return false; + } + } + } + + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInDataNodes() const { + std::vector vec; + for (const auto &in_anchor : in_data_anchors_) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); + auto anchor_ptr = in_anchor->GetPeerOutAnchor(); + if (anchor_ptr == nullptr) { + continue; + } + auto node = anchor_ptr->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInControlNodes() const { + std::vector vec; + if (in_control_anchor_ != nullptr) { + for (const auto &in_anchor : in_control_anchor_->GetPeerOutControlAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerOutControlAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutNodes() const { + std::vector vec; + for (const auto &out_anchor : out_data_anchors_) { + GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); + for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC((peer_in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr"); + auto node = peer_in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + if (out_control_anchor_ != nullptr) { + auto peer_in_control_anchors = out_control_anchor_->GetPeerInControlAnchors(); + for (const auto &in_control_anchor : peer_in_control_anchors) { + GE_CHK_BOOL_EXEC(in_control_anchor != nullptr, continue, + "out_control_anchor_ peer in control anchors is nullptr"); + auto node = in_control_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInAllNodes() const { + std::vector vec; + for (const auto &in_node : GetInDataNodes()) { + vec.push_back(in_node); + } + for (const auto &in_control_node : GetInControlNodes()) { + vec.push_back(in_control_node); + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutDataNodes() const { + std::vector vec; + for (const auto &out_anchor : out_data_anchors_) { + GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); + for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetOutDataNodesSize() const { + uint32_t out_nums = 0; + for (const auto &out_anchor : out_data_anchors_) { + GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); + out_nums += out_anchor->GetPeerInDataNodesSize(); + } + return out_nums; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutControlNodes() const { + std::vector vec; + + for (const auto &out_anchor : out_data_anchors_) { + GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); + for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInControlAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + + if (out_control_anchor_ != nullptr) { + for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + + return Node::Vistor(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutAllNodes() const { + std::vector vec; + for (const auto &out_anchor : out_data_anchors_) { + GE_CHK_BOOL_EXEC((out_anchor != nullptr), { continue; }, "out_data_anchors_ is nullptr"); + for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC((in_anchor != nullptr), { continue; }, "GetPeerInDataAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + + if (out_control_anchor_ != nullptr) { + for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); + auto node = in_anchor->GetOwnerNode(); + GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); + vec.push_back(node); + } + } + return Node::Vistor(shared_from_this(), vec); +} + +graphStatus Node::InferShapeAndType() const { + Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); + graphStatus ret = ShapeRefiner::InferShapeAndType(shared_from_this(), op); + return ret; +} + +graphStatus Node::InferOriginFormat() const { + Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); + // Get infer func and execute + GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); + return op_->CallInferFormatFunc(op); +} +graphStatus Node::Verify() const { + const string data_type = "Data"; + const string aipp_data_type = "AippData"; + const string const_type = "Const"; + const string const_type_train = "Constant"; + const string variable_type = "Variable"; + bool is_unknown_graph = GetOwnerComputeGraph()->GetGraphUnknownFlag(); + GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); + + if (!is_unknown_graph) { + for (const auto &in_anchor_ptr : GetAllInDataAnchors()) { + GE_IF_BOOL_EXEC(in_anchor_ptr == nullptr, GELOGW("in anchor ptr is null"); + continue); + bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type || + op_->GetType() == const_type || op_->GetType() == variable_type || op_->GetType() == const_type_train || + op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || op_->MutableInputDesc(in_anchor_ptr->GetIdx()) == nullptr || + in_anchor_ptr->GetPeerAnchors().size() > 0; + if (!valid_anchor) { + ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"}, + {GetName(), std::to_string(in_anchor_ptr->GetIdx())}); + GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); + return GRAPH_FAILED; + } + } + } + + string frameworkop_type = "FrameworkOp"; + bool need_update_name = op_->GetType() != frameworkop_type && !is_unknown_graph; + if (need_update_name) { + auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_->GetType()); + if (node_op.IsEmpty()) { + GELOGW("get op from OperatorFactory fail. opType: %s", op_->GetType().c_str()); + } else { + GELOGD("get op from OperatorFactory success. opType: %s", op_->GetType().c_str()); + auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + if (temp_op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "temp op desc is null"); + return GRAPH_FAILED; + } + if (!op_->UpdateInputName(temp_op_desc->GetAllInputName())) { + GELOGW("Verify UpdateInputName failed"); + } + if (!op_->UpdateOutputName(temp_op_desc->GetAllOutputName())) { + GELOGW("Verify UpdateOutputName failed"); + } + } + node_op.BreakConnect(); + } + GE_IF_BOOL_EXEC(is_unknown_graph, return GRAPH_SUCCESS;); + if (op_->CommonVerify() == GRAPH_SUCCESS) { + Operator op_proxy = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); + auto verify_func = op_->GetVerifyFunc(); + if (verify_func == nullptr) { + verify_func = OperatorFactoryImpl::GetVerifyFunc(GetType()); + } + if (verify_func != nullptr) { + return (graphStatus)verify_func(op_proxy); + } + return GRAPH_SUCCESS; + } else { + GELOGE(GRAPH_FAILED, "%s Verify failed.", op_->GetType().c_str()); + return GRAPH_FAILED; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr Node::GetOpDesc() const { return op_; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(const OpDescPtr &op_desc) { + GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return GRAPH_PARAM_INVALID, "Param OpDesc is nullptr"); + GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, + "Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), + op_desc->GetInputsSize()); + + GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, + "Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), + op_desc->GetOutputsSize()); + op_ = op_desc; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor> +Node::GetInDataNodesAndAnchors() const { + std::vector> vec; + for (const auto &p : in_data_anchors_) { + if (p == nullptr) { + GELOGW("indata anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); + continue; + } + auto anchor_ptr = p->GetPeerOutAnchor(); + if (anchor_ptr == nullptr) { + continue; + } + auto node = anchor_ptr->GetOwnerNode(); + if (node == nullptr) { + GELOGW("src node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); + continue; + } + vec.push_back(std::make_pair(node, anchor_ptr)); + } + return Node::Vistor>(shared_from_this(), vec); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor> +Node::GetOutDataNodesAndAnchors() const { + std::vector> vec; + for (const auto &p : out_data_anchors_) { + if (p == nullptr) { + GELOGW("out data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); + continue; + } + for (const auto &in_anchor : p->GetPeerInDataAnchors()) { + if (in_anchor == nullptr) { + GELOGW("dst in data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); + continue; + } + auto node = in_anchor->GetOwnerNode(); + if (node == nullptr) { + GELOGW("dst node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); + continue; + } + vec.push_back(std::make_pair(node, in_anchor)); + } + } + return Node::Vistor>(shared_from_this(), vec); +} +} // namespace ge diff --git a/metadef/graph/op_desc.cc b/metadef/graph/op_desc.cc new file mode 100644 index 00000000..fa492a19 --- /dev/null +++ b/metadef/graph/op_desc.cc @@ -0,0 +1,1461 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/op_desc.h" +#include "debug/ge_attr_define.h" +#include "debug/ge_util.h" +#include "external/graph/operator.h" +#include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" +#include "graph/common_error_codes.h" +#include "graph/ge_attr_value.h" +#include "graph/ge_tensor.h" +#include "graph/operator_factory_impl.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/ge_ir_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/transformer_utils.h" +#include "proto/ge_ir.pb.h" + +using std::make_pair; +using std::shared_ptr; +using std::string; +using std::vector; + +/*lint -save -e521 -e681 -e732 -e737*/ +namespace ge { +const std::string ATTR_NAME_ID = "id"; + +const std::string ATTR_NAME_STREAM_ID = "stream_id"; + +const std::string ATTR_NAME_INPUT_NAME = "input_name"; + +const std::string ATTR_NAME_SRC_NAME = "src_name"; + +const std::string ATTR_NAME_SRC_INDEX = "src_index"; + +const std::string ATTR_NAME_INPUT = "input"; + +const std::string ATTR_NAME_OUTPUT = "output"; + +const std::string ATTR_NAME_INPUT_DESC = "input_desc"; + +const std::string ATTR_NAME_OUTPUT_DESC = "output_desc"; + +const std::string ATTR_NAME_DST_NAME = "dst_name"; + +const std::string ATTR_NAME_DST_INDEX = "dst_index"; + +const std::string ATTR_NAME_WORKSPACE = "workspace"; + +const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; + +const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; + +const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; + +const std::string ATTR_NAME_OP_KERNEL_LIB_NAME = "_ge_attr_op_kernel_lib_name"; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { + op_def_.InitDefault(); + if (op_def_.GetProtoMsg() != nullptr) { + op_def_.GetProtoMsg()->set_has_out_attr(true); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::~OpDesc() {} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const std::string &name, const std::string &type) { + op_def_.InitDefault(); + if (op_def_.GetProtoMsg() != nullptr) { + op_def_.GetProtoMsg()->set_has_out_attr(true); + } + SetName(name); + SetType(type); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const ProtoMsgOwner &proto_msg_owner, + ge::proto::OpDef *op_def) + : op_def_(proto_msg_owner, op_def) { + if (op_def != nullptr && !op_def->has_out_attr()) { + op_def->set_has_out_attr(true); + + int64_t id = 0; + (void)AttrUtils::GetInt(this, ATTR_NAME_ID, id); + op_def->set_id(id); + + int64_t stream_id = 0; + (void)AttrUtils::GetInt(this, ATTR_NAME_STREAM_ID, stream_id); + op_def->set_stream_id(stream_id); + + vector input_name; + (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME, input_name); + for (auto &item : input_name) { + op_def->add_input_name(item); + } + vector src_name; + (void)AttrUtils::GetListStr(this, ATTR_NAME_SRC_NAME, src_name); + for (auto &item : src_name) { + op_def->add_src_name(item); + } + vector src_index; + (void)AttrUtils::GetListInt(this, ATTR_NAME_SRC_INDEX, src_index); + for (auto &item : src_index) { + op_def->add_src_index(item); + } + vector input; + (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT, input); + for (auto &item : input) { + op_def->add_input_i(item); + } + vector output; + (void)AttrUtils::GetListInt(this, ATTR_NAME_OUTPUT, output); + for (auto &item : output) { + op_def->add_output_i(item); + } + vector dst_name; + (void)AttrUtils::GetListStr(this, ATTR_NAME_DST_NAME, dst_name); + for (auto &item : dst_name) { + op_def->add_dst_name(item); + } + vector dst_index; + (void)AttrUtils::GetListInt(this, ATTR_NAME_DST_INDEX, dst_index); + for (auto &item : dst_index) { + op_def->add_dst_index(item); + } + vector workspace; + (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE, workspace); + for (auto &item : workspace) { + op_def->add_workspace(item); + } + vector workspace_bytes; + (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE_BYTES, workspace_bytes); + for (auto &item : workspace_bytes) { + op_def->add_workspace_bytes(item); + } + vector is_input_const; + (void)AttrUtils::GetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const); + for (auto item : is_input_const) { + op_def->add_is_input_const(item); + } + auto input_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_INPUT_DESC].mutable_list(); + if (input_desc_mutable_list != nullptr) { + *op_def->mutable_input_desc() = *(input_desc_mutable_list->mutable_td()); + } + auto output_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_OUTPUT_DESC].mutable_list(); + if (output_desc_mutable_list != nullptr) { + *op_def->mutable_output_desc() = *(output_desc_mutable_list->mutable_td()); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetName() const { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->name(); + } + return ""; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetName(const std::string &name) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_name(name); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetType() const { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->type(); + } + return ""; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetType(const string &type) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_type(type); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddInputDesc(const ge::GeTensorDesc &input_desc) { + int index = static_cast(inputs_desc_.size()); + return AddInputDesc("__input" + std::to_string(index), input_desc); +} + +graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc) { + graphStatus ret = GRAPH_SUCCESS; + if (index < inputs_desc_.size()) { + // InputsDesc[index] is exist, then update it + ret = UpdateInputDesc(index, input_desc); + } else { + // InputDesc[index] is not exist, then add it + ret = AddInputDesc(input_desc); + } + return ret; +} + +graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { + if (input_name_idx_.find(name) != input_name_idx_.end()) { + GELOGI("input %s is exist, update it", name.c_str()); + graphStatus ret = UpdateInputDesc(name, input_desc); + return ret; + } else { + int index = static_cast(inputs_desc_.size()); + std::shared_ptr in_desc = ComGraphMakeShared(input_desc); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddInputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + inputs_desc_.push_back(in_desc); + (void)input_name_idx_.insert(make_pair(name, index)); + if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) { + register_input_name_.push_back(name); + } + + return GRAPH_SUCCESS; + } +} + +graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) { + for (unsigned int i = 0; i < num; i++) { + string input_name = name + std::to_string(i); + GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, + "Add input tensor_desc is existed. name[%s]", input_name.c_str()); + + std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + if (index > inputs_desc_.size()) { + GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); + return GRAPH_FAILED; + } + + (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc); + + // Update index in input_name_idx + for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { + if (it->second >= (index + i)) { + it->second += 1; + } + } + + (void)input_name_idx_.insert(make_pair(input_name, i + index)); + } + + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddOutputDescMiddle(const string &name, const unsigned int num, size_t index) { + for (unsigned int i = 0; i < num; i++) { + string output_name = name + std::to_string(i); + GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED, + "Add input tensor_desc is existed. name[%s]", output_name.c_str()); + + std::shared_ptr out_desc = ComGraphMakeShared(GeTensorDesc()); + if (out_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + if (index > outputs_desc_.size()) { + GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); + return GRAPH_FAILED; + } + + (void)outputs_desc_.insert(outputs_desc_.begin() + index + i, out_desc); + + // Update index in input_name_idx + for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) { + if (it->second >= (index + i)) { + it->second += 1; + } + } + + (void)output_name_idx_.insert(make_pair(output_name, i + index)); + } + + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { + for (unsigned int i = 0; i < num; i++) { + string input_name = name + std::to_string(i); + GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, + "Add input tensor_desc is existed. name[%s]", input_name.c_str()); + + std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); + + // Update index in input_name_idx + for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { + it->second += 1; + } + + (void)input_name_idx_.insert(make_pair(input_name, 0)); + } + + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int num) { + for (unsigned int i = 0; i < num; i++) { + string output_name = name + std::to_string(i); + GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED, + "Add output tensor_desc is existed. name[%s]", output_name.c_str()); + + std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddOutputDescForward failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + (void)outputs_desc_.insert(outputs_desc_.begin(), in_desc); + + // Update index in output_name_idx + for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) { + it->second += 1; + } + (void)output_name_idx_.insert(make_pair(output_name, 0)); + } + + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { + if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; + (void)optional_input_names_.insert(name); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { + if (index >= inputs_desc_.size()) { + GELOGW("The index is invalid. index[%u]", index); + return GRAPH_FAILED; + } + + inputs_desc_[index] = ComGraphMakeShared(tensor_Desc); + if (inputs_desc_[index] == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { + return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && + IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && + IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") && + IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && + IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { + const auto &op_def = this->op_def_.GetProtoMsg(); + const auto &r_op_def = r_op_desc.op_def_.GetProtoMsg(); + if ((op_def != nullptr) && (r_op_def != nullptr)) { + // Message OpDef in ge_ir.proto + return ( + IsEqual(op_def->name(), r_op_def->name(), "OpDef_.name()") && + IsEqual(op_def->type(), r_op_def->type(), "OpDef_.type()") && + IsEqual(ToString(op_def->input()), ToString(r_op_def->input()), "OpDef_.input()") && + IsEqual(op_def->has_out_attr(), r_op_def->has_out_attr(), "OpDef_.has_out_attr()") && + IsEqual(op_def->stream_id(), r_op_def->stream_id(), "OpDef_.stream_id()") && + IsEqual(ToString(op_def->input_name()), ToString(r_op_def->input_name()), "OpDef_.input_name()") && + IsEqual(ToString(op_def->src_name()), ToString(r_op_def->src_name()), "OpDef_.src_name()") && + IsEqual(ToString(op_def->dst_name()), ToString(r_op_def->dst_name()), "OpDef_.dst_name()") && + IsEqual(ToString(op_def->src_index()), ToString(r_op_def->src_index()), "OpDef_.src_index()") && + IsEqual(ToString(op_def->dst_index()), ToString(r_op_def->dst_index()), "OpDef_.dst_index()") && + IsEqual(ToString(op_def->input_i()), ToString(r_op_def->input_i()), "OpDef_.input_i()") && + IsEqual(ToString(op_def->output_i()), ToString(r_op_def->output_i()), "OpDef_.output_i()") && + IsEqual(ToString(op_def->workspace()), ToString(r_op_def->workspace()), "OpDef_.workspace()") && + IsEqual(ToString(op_def->workspace_bytes()), ToString(r_op_def->workspace_bytes()), + "OpDef_.workspace_bytes()") && + IsEqual(ToString(op_def->is_input_const()), ToString(r_op_def->is_input_const()), "OpDef_.is_input_const()")); + } else { + return ((op_def == nullptr) && (r_op_def == nullptr)); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) + const { + // 1.Verify inputs and outputs desc size + const auto inputs_desc_size = this->inputs_desc_.size(); + const auto r_inputs_desc_size = r_op_desc.inputs_desc_.size(); + if (inputs_desc_size != r_inputs_desc_size) { + GELOGE(GRAPH_FAILED, "Size of OpDesc's inputs desc verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + const auto outputs_desc_size = this->outputs_desc_.size(); + const auto r_outputs_desc_size = r_op_desc.outputs_desc_.size(); + if (outputs_desc_size != r_outputs_desc_size) { + GELOGE(GRAPH_FAILED, "Size of OpDesc's outputs desc verify failed, node name: %s.", this->GetName().c_str()); + return false; + } + // 2.Verify all inputs desc equal + for (uint32_t i = 0; i < inputs_desc_size; i++) { + const auto &in_ge_tensor_desc = this->GetInputDesc(i); + const auto &r_in_ge_tensor_desc = r_op_desc.GetInputDesc(i); + // Determine the connection relationship by GeTensorDesc + if (!(in_ge_tensor_desc == r_in_ge_tensor_desc)) { + GELOGE(GRAPH_FAILED, "Link info of OpDesc's inputs desc verify failed, OpDesc name: %s.", + this->GetName().c_str()); + return false; + } + } + // 3.Verify all outputs desc equal + for (uint32_t i = 0; i < outputs_desc_size; i++) { + const auto &out_ge_tensor_desc = this->GetOutputDesc(i); + const auto &r_out_ge_tensor_desc = r_op_desc.GetOutputDesc(i); + if (!(out_ge_tensor_desc == r_out_ge_tensor_desc)) { + GELOGE(GRAPH_FAILED, "Link info of OpDesc's outputs desc verify failed, OpDesc name: %s.", + this->GetName().c_str()); + return false; + } + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpDesc &r_op_desc) const { + return (OpDescAttrsAreEqual(r_op_desc) && OpDescMembersAreEqual(r_op_desc) && + OpDescGenTensorDescsAreEqual(r_op_desc)); +} + +graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { + auto it = input_name_idx_.find(name); + if (it == input_name_idx_.end()) { + GELOGW("Cann't find the input desc. name[%s]", name.c_str()); + return GRAPH_FAILED; + } + if (it->second >= inputs_desc_.size()) { + GELOGE(GRAPH_FAILED, "[%d] more than size of inputs_desc_", it->second); + return GRAPH_FAILED; + } + GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); + return GRAPH_FAILED); + inputs_desc_[it->second] = ComGraphMakeShared(tensor_Desc); + if (inputs_desc_[it->second] == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +bool OpDesc::InputIsSet(const string &name) const { + auto it = input_name_idx_.find(name); + if (it != input_name_idx_.end()) { + GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); + auto tensor_desc = inputs_desc_[it->second]; + GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); + auto dims = tensor_desc->GetShape().GetDims(); + if (dims.size() > 0) { + return true; + } + } + return false; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS_NOLOG(index < inputs_desc_.size(), GeTensorDesc()); + return *(inputs_desc_[index].get()); +} + +GeTensorDesc OpDesc::GetInputDesc(const string &name) const { + auto it = input_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc()); + GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); + return *(inputs_desc_[it->second].get()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); + if (inputs_desc_[index] == nullptr) { + return nullptr; + } + if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) { + GELOGW("input desc is invalid"); + return nullptr; + } + return inputs_desc_[index]; +} + +GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const { + auto input_name_idx = GetAllInputName(); + auto it = input_name_idx.find(name); + if (it == input_name_idx.end()) { + GELOGW("Failed to get [%s] input desc", name.c_str()); + return nullptr; + } + return MutableInputDesc(it->second); +} + +GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputNames() const { + vector names; + if (input_name_idx_.empty()) { + return OpDesc::Vistor(shared_from_this(), names); + } + for (std::pair input : input_name_idx_) { + names.push_back(input.first); + } + return OpDesc::Vistor(shared_from_this(), names); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpKernelLibName(const std::string &name) { + op_kernel_lib_name_ = name; + auto ret = AttrUtils::SetStr(this, ATTR_NAME_OP_KERNEL_LIB_NAME, name); + if (ret != true) { + GELOGE(GRAPH_FAILED, "set op kernel lib name failed."); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpKernelLibName() const { + if (!op_kernel_lib_name_.empty()) { + return op_kernel_lib_name_; + } + string op_kernel_lib_name; + (void)AttrUtils::GetStr(this, ATTR_NAME_OP_KERNEL_LIB_NAME, op_kernel_lib_name); + return op_kernel_lib_name; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(const std::string &name) { + engine_name_ = name; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDesc() const { + vector temp{}; + for (const auto &it : inputs_desc_) { + if (it->IsValid() == GRAPH_SUCCESS) { + temp.push_back(*it); + } else { + GELOGW("this inputDesc is InValid, it won't be return"); + continue; + } + } + return OpDesc::Vistor(shared_from_this(), temp); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDescPtr() const { + vector temp{}; + for (const auto &it : inputs_desc_) { + if (it->IsValid() == GRAPH_SUCCESS) { + temp.push_back(it); + } else { + GELOGW("this inputDesc is InValid, it won't be return"); + continue; + } + } + return OpDesc::Vistor(shared_from_this(), temp); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() const { + // Just return valid inputs size.InValid desc is set in default OPTION_INPUT register. + size_t size = 0; + for (const auto &it : inputs_desc_) { + if (it->IsValid() == GRAPH_SUCCESS) { + size++; + } + } + return size; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { + int index = static_cast(outputs_desc_.size()); + return AddOutputDesc("__output" + std::to_string(index), output_desc); +} + +graphStatus OpDesc::AddOutputDesc(const string &name, const ge::GeTensorDesc &output_desc) { + GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(name) == output_name_idx_.end()), GRAPH_FAILED, + "Add output tensor_Desc is existed. name[%s]", name.c_str()); + int index = static_cast(outputs_desc_.size()); + + std::shared_ptr tensor = ComGraphMakeShared(output_desc); + if (tensor == nullptr) { + GELOGE(GRAPH_FAILED, "AddOutputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + outputs_desc_.push_back(tensor); + (void)output_name_idx_.insert(make_pair(name, index)); + if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) { + register_output_name_.push_back(name); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDesc::UpdateOutputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { + GE_CHK_BOOL_RET_STATUS((index < outputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); + + outputs_desc_[index] = ComGraphMakeShared(tensor_Desc); + if (outputs_desc_[index] == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::UpdateOutputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { + auto it = output_name_idx_.find(name); + if (it == output_name_idx_.end()) { + GELOGW("Cann't find the output desc. name[%s]", name.c_str()); + return GRAPH_FAILED; + } + GE_IF_BOOL_EXEC(it->second >= outputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); + return GRAPH_FAILED); + outputs_desc_[it->second] = ComGraphMakeShared(tensor_Desc); + if (outputs_desc_[it->second] == nullptr) { + GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetOutputDesc(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS_NOLOG(index < outputs_desc_.size(), GeTensorDesc()); + return *(outputs_desc_[index].get()); +} + +GeTensorDesc OpDesc::GetOutputDesc(const string &name) const { + auto it = output_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), GeTensorDesc()); + GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < outputs_desc_.size(), GeTensorDesc()); + return *(outputs_desc_[it->second].get()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS(index < outputs_desc_.size(), nullptr, "Cann't find the output desc %u", index); + return outputs_desc_[index]; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const { + auto it = output_name_idx_.find(name); + if (it == output_name_idx_.end()) { + GELOGW("Failed to get [%s] output desc", name.c_str()); + return nullptr; + } + return MutableOutputDesc(it->second); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { + return static_cast(outputs_desc_.size()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllOutputsDesc() const { + vector temp{}; + for (const auto &it : outputs_desc_) { + temp.push_back(*it); + } + return OpDesc::Vistor(shared_from_this(), temp); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllOutputsDescPtr() const { + return OpDesc::Vistor(shared_from_this(), outputs_desc_); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetOutputsSize() const { return outputs_desc_.size(); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetOutputDescPtr(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(outputs_desc_.size()), nullptr); + return outputs_desc_[index]; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(inputs_desc_.size()), nullptr); + if (inputs_desc_[index] == nullptr) { + return nullptr; + } + if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) { + GELOGW("inputsDesc[%u] is InValid", index); + return nullptr; + } else { + return inputs_desc_[static_cast(index)]; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr +OpDesc::GetInputDescPtrDfault(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS_NOLOG((index) < (uint32_t)(inputs_desc_.size()), nullptr); + return inputs_desc_[(int32_t)index]; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const { + auto it = input_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr()); + return inputs_desc_[it->second]; +} + +graphStatus OpDesc::AddRegisterInputName(const std::string &name) { + if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) { + register_input_name_.push_back(name); + } + + return GRAPH_SUCCESS; +} + +vector OpDesc::GetRegisterInputName() const { + return register_input_name_; +} + +graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { + if (is_push_back) { + for (unsigned int i = 0; i < num; i++) { + if (AddInputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) + return GRAPH_FAILED; + } + } else { + if (AddInputDescForward(name, num) != GRAPH_SUCCESS) + return GRAPH_FAILED; + } + if (AddRegisterInputName(name) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) { + if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::AddRegisterOutputName(const string &name) { + if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) { + register_output_name_.push_back(name); + } + + return GRAPH_SUCCESS; +} + +vector OpDesc::GetRegisterOutputName() const { + return register_output_name_; +} + +graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { + if (is_push_back) { + for (unsigned int i = 0; i < num; i++) { + if (AddOutputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) + return GRAPH_FAILED; + } + } else { + if (AddOutputDescForward(name, num) != GRAPH_SUCCESS) + return GRAPH_FAILED; + } + + if (AddRegisterOutputName(name) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +bool OpDesc::IsOptionalInput(const string &name) const { + return optional_input_names_.find(name) != optional_input_names_.end(); +} + +bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } + +std::map OpDesc::GetAllInputName() const { return input_name_idx_; } + +std::map OpDesc::GetAllOutputName() { return output_name_idx_; } + +std::map& OpDesc::MutableAllInputName() { return input_name_idx_; } + +std::map& OpDesc::MutableAllOutputName() { return output_name_idx_; } + +bool OpDesc::UpdateInputName(std::map input_name_idx) { + bool ret = true; + // Use inputDesc_.size() to contain the InValid OptionInput.GetInputsSize() will remove default OptionInput name. + auto input_map_size = inputs_desc_.size(); + auto factory_map_size = input_name_idx.size(); + // It indicates that some inputs have no optionalname. + // The redundant optionalname of factory needs to be deleted and then assigned + if (input_map_size < factory_map_size) { + GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, + factory_map_size); + for (auto it = input_name_idx.begin(); it != input_name_idx.end();) { + if (it->second >= input_map_size) { + it = input_name_idx.erase(it); + } else { + ++it; + } + } + if (input_name_idx.size() == input_map_size) { + GELOGI("UpdateInputName"); + input_name_idx_ = input_name_idx; + } else { + ret = false; + GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); + } + } else if (input_map_size == factory_map_size) { + input_name_idx_ = input_name_idx; + } else { + ret = false; + GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); + } + return ret; +} + +bool OpDesc::UpdateOutputName(std::map output_name_idx) { + size_t output_map_size = GetAllOutputsDescSize(); + size_t factory_map_size = output_name_idx.size(); + if (output_map_size < factory_map_size) { + GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size, + factory_map_size); + for (auto it = output_name_idx.begin(); it != output_name_idx.end();) { + if (it->second >= output_map_size) { + it = output_name_idx.erase(it); + } else { + ++it; + } + } + if (output_name_idx.size() == output_map_size) { + GELOGI("UpdateoutputName"); + output_name_idx_ = output_name_idx; + return true; + } + } else if (output_map_size == factory_map_size) { + output_name_idx_ = output_name_idx; + return true; + } else { + GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size); + return false; + } + GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size); + return false; +} + +std::function OpDesc::GetInferFunc() const { return infer_func_; } + +std::function OpDesc::GetVerifyFunc() const { return verifier_func_; } + +void OpDesc::AddInferFunc(const std::function &func) { infer_func_ = func; } + +std::function OpDesc::GetInferFormatFunc() const { return infer_format_func_; } + +void OpDesc::AddInferFormatFunc(const std::function &func) { infer_format_func_ = func; } + +void OpDesc::AddVerifierFunc(const std::function &func) { verifier_func_ = func; } + +graphStatus OpDesc::InferShapeAndType() { + if (infer_func_ == nullptr) { + infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType()); + if (infer_func_ == nullptr) { + GELOGW("%s does not have inferfunc_.", GetName().c_str()); + /// The infoshape function has not been added for each operator in the current operator information library. + /// No infoshape added operator skips the call + /// and directly uses the shape information passed down by the upper framework + return GRAPH_SUCCESS; + } + } + Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this()); + graphStatus ret = (graphStatus)infer_func_(op_proxy); + op_proxy.BreakConnect(); + return ret; +} + +graphStatus OpDesc::DefaultInferFormat() { + ge::Format first_none_nd_format = FORMAT_ND; + auto input_descs = GetAllInputsDescPtr(); + auto output_descs = GetAllOutputsDescPtr(); + // Overall input and output,get the first non-nd format + for (const auto &input_desc : input_descs) { + Format origin_format = input_desc->GetOriginFormat(); + if (origin_format != FORMAT_ND) { + first_none_nd_format = origin_format; + break; + } + } + for (const auto &output_desc : output_descs) { + Format origin_format = output_desc->GetOriginFormat(); + if (origin_format != FORMAT_ND) { + first_none_nd_format = origin_format; + break; + } + } + // Refresh all input output format + GELOGD("Default infer format.node[%s], first none nod format is:%d", GetName().c_str(), first_none_nd_format); + + for (const auto &input_desc : input_descs) { + Format origin_format = input_desc->GetOriginFormat(); + GELOGD("Default infer format[in].node[%s].origin format is:%d", GetName().c_str(), origin_format); + if (origin_format == FORMAT_ND) { + input_desc->SetOriginFormat(first_none_nd_format); + input_desc->SetFormat(first_none_nd_format); + } + } + for (const auto &output_desc : output_descs) { + Format origin_format = output_desc->GetOriginFormat(); + GELOGD("Default infer format[out].node[%s].origin format is:%d", GetName().c_str(), origin_format); + if (origin_format == FORMAT_ND) { + output_desc->SetOriginFormat(first_none_nd_format); + output_desc->SetFormat(first_none_nd_format); + } + } + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::OpVerify() { + if (verifier_func_ == nullptr) { + verifier_func_ = OperatorFactoryImpl::GetVerifyFunc(GetType()); + } + if (verifier_func_ != nullptr) { + Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this()); + graphStatus ret = (graphStatus)verifier_func_(op_proxy); + op_proxy.BreakConnect(); + return ret; + } + return GRAPH_SUCCESS; +} + +graphStatus OpDesc::CommonVerify() const { + for (const string &iname : GetAllInputNames()) { + // Checking shape of all inputs + vector ishape = GetInputDescPtr(iname)->GetShape().GetDims(); + for (int64_t dim : ishape) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(dim < -2, + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); + return GRAPH_FAILED, + "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), iname.c_str()); + } + } + // Check all attributes defined + const auto &all_attributes = GetAllAttrs(); + for (const auto &name : GetAllAttrNames()) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(all_attributes.find(name) == all_attributes.end(), + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {GetName(), "attribute " + name, "is empty"}); + return GRAPH_FAILED, + "operator attribute %s is empty.", name.c_str()); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { + auto it = input_name_idx_.begin(); + for (; it != input_name_idx_.end(); ++it) { + if (it->second == index) { + break; + } + } + GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); + return it->first; +} + +int OpDesc::GetInputIndexByName(const string &name) const { + auto it_find = input_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); + return static_cast(it_find->second); +} + +int OpDesc::GetValidInputIndexByName(const string &name) const { + map valid_input_name_idx{}; + uint32_t j = 0; + for (size_t i = 0; i < GetAllInputsSize(); i++) { + if (MutableInputDesc(static_cast(i)) != nullptr) { + auto valid_name = GetInputNameByIndex(static_cast(i)); + GE_CHK_BOOL_RET_STATUS_NOLOG(!valid_name.empty(), -1); + valid_input_name_idx.insert({valid_name, j}); + j++; + } + } + auto it_find = valid_input_name_idx.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != valid_input_name_idx.end(), -1); + return static_cast(it_find->second); +} + +string OpDesc::GetValidInputNameByIndex(uint32_t index) const { + map valid_input_name_idx{}; + uint32_t j = 0; + for (size_t i = 0; i < GetAllInputsSize(); i++) { + if (MutableInputDesc(static_cast(i)) != nullptr) { + auto valid_name = GetInputNameByIndex(static_cast(i)); + GE_CHK_BOOL_RET_STATUS_NOLOG(!valid_name.empty(), ""); + valid_input_name_idx.insert({valid_name, j}); + j++; + } + } + auto it = valid_input_name_idx.begin(); + for (; it != valid_input_name_idx.end(); ++it) { + if (it->second == index) { + break; + } + } + GE_CHK_BOOL_RET_STATUS_NOLOG(it != valid_input_name_idx.end(), ""); + return it->first; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetOutputNameByIndex(uint32_t index) const { + auto it = output_name_idx_.begin(); + for (; it != output_name_idx_.end(); ++it) { + if (it->second == index) { + break; + } + } + GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), ""); + return it->first; +} + +int OpDesc::GetOutputIndexByName(const string &name) const { + auto it_find = output_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != output_name_idx_.end(), -1); + return static_cast(it_find->second); +} + +ProtoAttrMapHelper OpDesc::MutableAttrMap() { + if (op_def_.GetProtoMsg() == nullptr) { + GELOGE(GRAPH_FAILED, "op def get proto msg failed"); + return GeIrProtoHelper(); + } + return ProtoAttrMapHelper(op_def_.GetProtoOwner(), op_def_.GetProtoMsg()->mutable_attr()); +} + +ConstProtoAttrMapHelper OpDesc::GetAttrMap() const { + return ConstProtoAttrMapHelper(op_def_.GetProtoOwner(), &op_def_.GetProtoMsg()->attr()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetId(int64_t id) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_id(id); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetId() const { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->id(); + } + return 0; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetStreamId(int64_t stream_id) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->set_stream_id(stream_id); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetStreamId() const { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + return proto_msg->stream_id(); + } + return 0; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputName(const vector &input_name) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_input_name(); + for (auto &item : input_name) { + proto_msg->add_input_name(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetInputName() const { + vector input_name; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->input_name()) { + input_name.push_back(item); + } + } + return input_name; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcName(const vector &src_name) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_src_name(); + for (auto &item : src_name) { + proto_msg->add_src_name(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetSrcName() const { + vector src_name; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->src_name()) { + src_name.push_back(item); + } + } + return src_name; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcIndex(const vector &src_index) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_src_index(); + for (auto &item : src_index) { + proto_msg->add_src_index(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetSrcIndex() const { + vector src_index; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->src_index()) { + src_index.push_back(item); + } + } + return src_index; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputOffset(const vector &input) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_input_i(); + for (auto &item : input) { + proto_msg->add_input_i(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetInputOffset() const { + vector input; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->input_i()) { + input.push_back(item); + } + } + return input; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOutputOffset(const vector &output) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_output_i(); + for (auto &item : output) { + proto_msg->add_output_i(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOutputOffset() const { + vector output; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->output_i()) { + output.push_back(item); + } + } + return output; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstName(const vector &dst_name) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_dst_name(); + for (auto &item : dst_name) { + proto_msg->add_dst_name(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetDstName() const { + vector dst_name; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->dst_name()) { + dst_name.push_back(item); + } + } + return dst_name; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector &depend_names) { + auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); + if (ret != true) { + GELOGE(GRAPH_FAILED, "set op_infer_depends fail."); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOpInferDepends() const { + vector depend_names; + (void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); + return depend_names; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector &dst_index) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_dst_index(); + for (auto &item : dst_index) { + proto_msg->add_dst_index(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetDstIndex() const { + vector dst_index; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->dst_index()) { + dst_index.push_back(item); + } + } + return dst_index; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspace(const vector &workspace) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_workspace(); + for (auto &item : workspace) { + proto_msg->add_workspace(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetWorkspace() const { + vector workspace; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->workspace()) { + workspace.push_back(item); + } + } + return workspace; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspaceBytes(const vector &workspace_bytes) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_workspace_bytes(); + for (auto &item : workspace_bytes) { + proto_msg->add_workspace_bytes(item); + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetWorkspaceBytes() const { + vector workspace_bytes; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto &item : proto_msg->workspace_bytes()) { + workspace_bytes.push_back(item); + } + } + return workspace_bytes; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetIsInputConst(const vector &is_input_const) { + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + proto_msg->clear_is_input_const(); + for (auto item : is_input_const) { + proto_msg->add_is_input_const(item); + } + } + // If comes from ME,which is_input_const exist as attrs, outside no need to check GE_TRAIN flag + auto ret = AttrUtils::SetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const); + if (ret != true) { + GELOGE(GRAPH_FAILED, "set is_input_const fail."); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetIsInputConst() const { + vector is_input_const; + auto proto_msg = op_def_.GetProtoMsg(); + if (proto_msg != nullptr) { + for (auto item : proto_msg->is_input_const()) { + is_input_const.push_back(item); + } + } + return is_input_const; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, + const int &index) { + if (input_name_idx_.find(name) != input_name_idx_.end()) { + GELOGI("Restore input name index is existed. name[%s]", name.c_str()); + } + (void)input_name_idx_.insert(make_pair(name, index)); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreOutputNameIdx(const string &name, + const int &index) { + if (output_name_idx_.find(name) != output_name_idx_.end()) { + GELOGI("Restore output name index is existed. name[%s]", name.c_str()); + } + (void)output_name_idx_.insert(make_pair(name, index)); + return GRAPH_SUCCESS; +} +graphStatus OpDesc::CallInferFunc(Operator &op) { + if (infer_func_ == nullptr) { + infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType()); + if (infer_func_ == nullptr) { + GELOGW("%s does not have infer func.", GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + } + std::unique_ptr transformer(new(std::nothrow) NodeShapeTransUtils(shared_from_this())); + if (transformer == nullptr) { + GELOGE(GRAPH_FAILED, "Memory alloc failed"); + return GRAPH_FAILED; + } + if (!transformer->CatchFormatAndShape()) { + GELOGE(GRAPH_FAILED, "catch format and shape info failed!"); + return GRAPH_FAILED; + } + graphStatus graph_status = (graphStatus)infer_func_(op); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "%s call infer func. ret: %u", GetName().c_str(), graph_status); + return GRAPH_FAILED; + } + if (!transformer->UpdateFormatAndShape()) { + GELOGE(GRAPH_FAILED, "catch format and shape info failed!"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +graphStatus OpDesc::CallInferFormatFunc(Operator &op) { + if (infer_format_func_ == nullptr) { + infer_format_func_ = OperatorFactoryImpl::GetInferFormatFunc(GetType()); + if (infer_format_func_ == nullptr) { + return DefaultInferFormat(); + } + } + return (graphStatus)infer_format_func_(op); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const { + if (static_cast(index) >= subgraph_instance_names_.size()) { + return ""; + } + return subgraph_instance_names_.at(index); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector &OpDesc::GetSubgraphInstanceNames() + const { + return subgraph_instance_names_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { + for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { + if (*iter == name) { + *iter = ""; + return; + } + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { + GELOGI("Add subgraph name is %s", name.c_str()); + auto iter = subgraph_names_to_index_.find(name); + if (iter != subgraph_names_to_index_.end()) { + GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); + return GRAPH_FAILED; + } + auto size = subgraph_names_to_index_.size(); + subgraph_names_to_index_[name] = size; + subgraph_instance_names_.resize(size + 1); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map & OpDesc::GetSubgraphNameIndexes() + const { + return subgraph_names_to_index_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus OpDesc::SetSubgraphInstanceName(uint32_t index, const std::string &name) { + GELOGI("Add sub graph instans name is %s, index is %u", name.c_str(), index); + if (index >= subgraph_instance_names_.size()) { + GE_LOGE("The index %u exceeds the max instance coutn %zu", index, subgraph_instance_names_.size()); + return GRAPH_PARAM_INVALID; + } + subgraph_instance_names_[index] = name; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +void OpDesc::RegisterSubgraphIrName(const string &name, SubgraphType type) { + subgraph_ir_names_to_type_[name] = type; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +const std::map &OpDesc::GetSubgraphIrNames() const { + return subgraph_ir_names_to_type_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +SubgraphType OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { + auto iter = subgraph_ir_names_to_type_.find(name); + if (iter == subgraph_ir_names_to_type_.end()) { + return kSubgraphTypeEnd; + } + return iter->second; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus OpDesc::GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const { + for (size_t idx = 0; idx < subgraph_instance_names_.size(); ++idx) { + if (subgraph_instance_names_[idx] != instance_name) { // find subgraph index. + continue; + } + + for (auto name_to_index : subgraph_names_to_index_) { + if (name_to_index.second != idx) { // find subgraph name. + continue; + } + + subgraph_name = name_to_index.first; + return GRAPH_SUCCESS; + } + } + + return GRAPH_PARAM_INVALID; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::InferDataSlice() { + if (infer_data_slice_func_ == nullptr) { + infer_data_slice_func_ = OperatorFactoryImpl::GetInferDataSliceFunc(GetType()); + if (infer_data_slice_func_ == nullptr) { + GELOGW("%s does not have infer data slice func.", GetName().c_str()); + return NO_DEPENDENCE_FUNC; + } + } + Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this()); + graphStatus ret = (graphStatus)infer_data_slice_func_(op_proxy); + op_proxy.BreakConnect(); + return ret; +} +} // namespace ge diff --git a/metadef/graph/operator.cc b/metadef/graph/operator.cc new file mode 100644 index 00000000..06267d84 --- /dev/null +++ b/metadef/graph/operator.cc @@ -0,0 +1,2292 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "external/graph/operator.h" +#include "external/graph/operator_factory.h" +#include +#include +#include +#include +#include +#include "array_ops.h" +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "external/graph/attr_value.h" +#include "external/graph/types.h" +#include "framework/common/debug/ge_log.h" +#include "graph/compute_graph.h" +#include "graph/ge_attr_value.h" +#include "graph/ge_context.h" +#include "graph/ge_tensor.h" +#include "graph/node.h" +#include "graph/op_desc.h" +#include "graph/runtime_inference_context.h" +#include "graph/usr_types.h" +#include "graph/utils/node_utils.h" +#include "graph/debug/ge_attr_define.h" +#include "utils/graph_utils.h" +#include "utils/op_desc_utils.h" +#include "utils/tensor_adapter.h" +#include "utils/tensor_utils.h" +#include "utils/type_utils.h" +#include +#include +#include +#include + +using std::enable_shared_from_this; +using std::make_pair; +using std::shared_ptr; +using std::string; +using std::to_string; +using std::vector; + +/*lint -save -e529 -e728*/ +namespace ge { +/*lint -e446 -e732*/ +/*lint -e665*/ +class OpIO { + public: + OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} + + ~OpIO() = default; + + string GetName() const { return name_; } + + int GetIndex() const { return index_; } + + OperatorImplPtr GetOwner() const { return owner_; } + + bool operator==(const OpIO &r_value) const { + return (this->name_ == r_value.GetName()) && (this->index_ == r_value.GetIndex()) && + (this->GetOwner() == r_value.GetOwner()); + } + + private: + string name_; + int index_; + std::shared_ptr owner_; +}; + +class TensorTypeImpl { + public: + TensorTypeImpl() = default; + ~TensorTypeImpl() = default; + + std::vector dt_vec_; +}; + +TensorType::TensorType(DataType dt) { + tensor_type_impl_ = ComGraphMakeShared(); + if (tensor_type_impl_ != nullptr) { + tensor_type_impl_->dt_vec_.push_back(dt); + } +} + +TensorType::TensorType(const std::initializer_list &types) { + tensor_type_impl_ = ComGraphMakeShared(); + if (tensor_type_impl_ != nullptr) { + tensor_type_impl_->dt_vec_ = types; + } +} + +class OperatorImpl : public std::enable_shared_from_this { + friend class GraphBuilderImpl; + friend class OpDescUtils; + + public: + explicit OperatorImpl(const string &name, const string &type) : op_desc_(ComGraphMakeShared(name, type)) { + if (op_desc_ == nullptr) { + GELOGW("OpDesc make shared failed"); + } + } + explicit OperatorImpl(const OpDescPtr &op_desc) : op_desc_(op_desc) {} + explicit OperatorImpl(ge::ConstNodePtr node) : node_(std::move(node)) { + if (node_ != nullptr && node_->GetOpDesc() != nullptr) { + op_desc_ = node_->GetOpDesc(); + } + } + ~OperatorImpl() {} + void SetInputImpl(const string &dst_name, const ge::Operator &src_oprt) { + GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty"); + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr."); + GE_CHK_BOOL_EXEC(src_oprt.operator_impl_ != nullptr, return, "operator_impl_ is nullptr."); + GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_ != nullptr, return, "op_desc_ is nullptr."); + + auto src_op_impl = src_oprt.GetOperatorImplPtr(); + GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return, "Src impl is null."); + GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return, "Src impl's opdesc is null."); + GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_->GetOutputsSize() == 1, return, + "The source operator[%s] must has one output", + src_oprt.operator_impl_->op_desc_->GetName().c_str()) + + uint32_t src_index = 0; + string src_name = src_op_impl->op_desc_->GetOutputNameByIndex(src_index); + GE_CHK_BOOL_EXEC(!src_name.empty(), return, "Src output's name is empty."); + + OpIO out_handler(src_name, src_index, src_op_impl); + input_link_.insert(std::make_pair(dst_name, out_handler)); + + int dst_index = op_desc_->GetInputIndexByName(dst_name); + GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), + op_desc_->GetName().c_str()); + + bool is_const = false; + if (src_oprt.GetOpType() == CONSTANT) { + is_const = true; + } + auto is_input_const = op_desc_->GetIsInputConst(); + for (int i = static_cast(is_input_const.size()); i <= dst_index; ++i) { + is_input_const.push_back(false); + } + + is_input_const[dst_index] = is_const; + op_desc_->SetIsInputConst(is_input_const); + + OpIO op_dst(dst_name, dst_index, shared_from_this()); + src_op_impl->UpdateLinkMapImpl(src_name, op_dst); + auto output_desc = src_op_impl->GetOutputDesc(src_name); + auto input_desc = op_desc_->GetInputDesc(dst_name); + if (input_desc.GetFormat() == FORMAT_RESERVED) { + output_desc.SetFormat(FORMAT_ND); + } else { + output_desc.SetFormat(input_desc.GetFormat()); + } + // Fix for linking opdesc + if (op_desc_->UpdateInputDesc(dst_name, output_desc) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Update inputdesc failed,dst name is %s, src name is %s", dst_name.c_str(), + src_name.c_str()); + return; + } + } + + void SetInputImpl(const string &dst_name, const ge::OutHandler &out_handler) { + GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty"); + GE_CHK_BOOL_EXEC(out_handler != nullptr, return, "SetInputImpl faild, out_handler is nullptr."); + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr."); + input_link_.insert(std::make_pair(dst_name, *out_handler)); + + string src_name = out_handler->GetName(); + int dst_index = op_desc_->GetInputIndexByName(dst_name); + GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), + op_desc_->GetName().c_str()); + auto out_op_impl = out_handler->GetOwner(); + GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return, + "out_handler invalid. name[%s]", dst_name.c_str()); + bool is_const = false; + if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { + is_const = true; + } + auto is_input_const = op_desc_->GetIsInputConst(); + for (int i = static_cast(is_input_const.size()); i <= dst_index; ++i) { + is_input_const.push_back(false); + } + is_input_const[dst_index] = is_const; + op_desc_->SetIsInputConst(is_input_const); + + OpIO in_handler(dst_name, dst_index, shared_from_this()); + GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed."); + + out_op_impl->UpdateLinkMapImpl(src_name, in_handler); + auto src_output_desc = out_op_impl->GetOutputDesc(src_name); + auto dst_input_desc = op_desc_->GetInputDesc(dst_name); + if (dst_input_desc.GetFormat() == FORMAT_RESERVED) { + src_output_desc.SetFormat(FORMAT_ND); + } else { + src_output_desc.SetFormat(dst_input_desc.GetFormat()); + } + GE_CHK_BOOL_EXEC(op_desc_->UpdateInputDesc(dst_name, src_output_desc) == GRAPH_SUCCESS, return, + "Update input desc failed,dst name is %s,src name is %s", dst_name.c_str(), + src_name.c_str()); // fix for linking opdesc + } + + void AddControlInputImp(const ge::Operator &src_oprt) { + if (src_oprt.operator_impl_ == nullptr) { + GELOGE(FAILED, "Src operator impl is nullptr"); + return; + } + for (auto &input : control_input_link_) { + if (input.lock() == src_oprt.operator_impl_) { + return; + } + } + control_input_link_.push_back(src_oprt.operator_impl_); + src_oprt.operator_impl_->control_output_link_.push_back(shared_from_this()); + } + + graphStatus GetInputImpl(const string &dst_name, ge::OpIO &out_handler) { + auto out = input_link_.find(dst_name); + if (out == input_link_.end()) { + return GRAPH_FAILED; + } + out_handler = out->second; + return GRAPH_SUCCESS; + } + + graphStatus GetInputConstData(const string &dst_name, Tensor &data) { + auto node_ptr = GetNode(); + if (node_ptr != nullptr) { + // For inner compute graph + auto op_desc = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto index = op_desc->GetInputIndexByName(dst_name); + auto in_data_anchor = node_ptr->GetInDataAnchor(index); + GE_CHECK_NOTNULL(in_data_anchor); + auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(out_data_anchor); + auto peer_node = out_data_anchor->GetOwnerNode(); + if (peer_node->GetType() == ENTER || peer_node->GetType() == REFENTER) { + auto enter_in_data_anchor = peer_node->GetInDataAnchor(0); + GE_CHECK_NOTNULL(enter_in_data_anchor); + auto enter_peer_out_data_anchor = enter_in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(enter_peer_out_data_anchor); + peer_node = enter_peer_out_data_anchor->GetOwnerNode(); + } + auto peer_op_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(peer_op_desc); + auto peer_op_type = peer_op_desc->GetType(); + if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { + auto const_op_impl = ComGraphMakeShared(peer_node); + GE_CHECK_NOTNULL(const_op_impl); + Operator const_op(std::move(const_op_impl)); + return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); + } else if (peer_op_type == DATA) { + auto parent_node = NodeUtils::GetParentInput(peer_node); + while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { + parent_node = NodeUtils::GetParentInput(parent_node); + } + if ((parent_node != nullptr) + && ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { + auto const_op_impl = ComGraphMakeShared(parent_node); + GE_CHECK_NOTNULL(const_op_impl); + Operator const_op(std::move(const_op_impl)); + return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); + } + } + // Try get from runtime inference context + auto session_id = std::to_string(GetContext().SessionId()); + RuntimeInferenceContext *runtime_infer_ctx = nullptr; + if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { + GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); + auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); + if (ret == GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + } + } else { + // For outer graph + return GetInputConstDataOut(dst_name, data); + } + auto op_name = GetName(); + GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); + return GRAPH_FAILED; + } + + graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) { + ge::OpIO out_handle("", 0, nullptr); + if (GetInputImpl(dst_name, out_handle) != GRAPH_SUCCESS) { + GELOGE(FAILED, "%s get input impl failed", dst_name.c_str()); + return GRAPH_FAILED; + } + if (out_handle.GetOwner() != nullptr && out_handle.GetOwner()->GetOpDescImpl() != nullptr) { + Operator const_op(out_handle.GetOwner()); + const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); + if (op_desc_impl_type == CONSTANTOP) { + return const_op.GetAttr(op::Constant::name_attr_value(), data); + } else if (op_desc_impl_type == CONSTANT) { + return const_op.GetAttr(op::Const::name_attr_value(), data); + } + } + return GRAPH_FAILED; + } + + bool InputIsSet(const string &name) { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return false, "op_desc_ is nullptr."); + return op_desc_->InputIsSet(name); + } + + string GetName() const { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return string(), "op_desc_ is nullptr."); + return op_desc_->GetName(); + } + + GeTensorDesc GetInputDesc(const string &name) const { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); + return op_desc_->GetInputDesc(name); + } + + GeTensorDesc GetInputDesc(uint32_t index) const { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); + return op_desc_->GetInputDesc(index); + } + + graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GRAPH_FAILED, "op_desc_ is nullptr."); + + return op_desc_->UpdateInputDesc(name, tensor_desc); + } + + OutHandler GetOutput(const string &name) { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); + + int src_index = op_desc_->GetOutputIndexByName(name); + GE_CHK_BOOL_EXEC(src_index >= 0, return nullptr, "Find src index by name failed. name[%s]", name.c_str()); + shared_ptr output_ptr = ComGraphMakeShared(name, src_index, shared_from_this()); + if (output_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "OpIO make shared failed"); + return nullptr; + } + return output_ptr; + } + + OutHandler GetOutput(uint32_t index) { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); + + string name = op_desc_->GetOutputNameByIndex(index); + if (name.empty()) { + GELOGE(GRAPH_FAILED, "Find src name by index failed. index[%u]", index); + return nullptr; + } + shared_ptr output_ptr = ComGraphMakeShared(name, index, shared_from_this()); + if (output_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "OpIO make shared failed"); + return nullptr; + } + return output_ptr; + } + + GeTensorDesc GetOutputDesc(const string &name) const { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); + + return op_desc_->GetOutputDesc(name); + } + + GeTensorDesc GetOutputDesc(uint32_t index) const { + GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); + + return op_desc_->GetOutputDesc(index); + } + + graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc) { + GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); + + auto res = op_desc_->UpdateOutputDesc(name, tensor_desc); + if (res == GRAPH_SUCCESS) { + for (auto ol : output_links_[name]) { + if (ol.GetOwner() == nullptr) { + GELOGW("%s get owner is nullptr", ol.GetName().c_str()); + continue; + } + GE_CHK_BOOL_RET_STATUS(ol.GetOwner()->UpdateInputDesc(ol.GetName(), tensor_desc) == GRAPH_SUCCESS, GRAPH_FAILED, + "Could not update next operator's input %s.", ol.GetName().c_str()); + } + } + return res; + } + + size_t GetInputsSize() const { + GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0); + return op_desc_->GetInputsSize(); + } + + size_t GetOutputsSize() const { + GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0); + return op_desc_->GetOutputsSize(); + } + + graphStatus SetAttr(const string &name, GeAttrValue &&attr_value) { + GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); + return op_desc_->SetAttr(name, std::move(attr_value)); + } + + graphStatus GetAttr(const string &name, GeAttrValue &attr_value) const { + GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); + return op_desc_->GetAttr(name, attr_value); + } + + OpDescPtr GetOpDescImpl() const { return op_desc_; } + + void UpdateLinkMapImpl(const string &src_name, OpIO &op_dst) { + auto it_find = output_links_.find(src_name); + if (it_find == output_links_.end()) { + std::vector dsts{op_dst}; + output_links_.insert(std::make_pair(src_name, dsts)); + } else { + it_find->second.push_back(op_dst); + } + } + + Operator ToOperator() { return Operator(shared_from_this()); } + + static OpDescPtr GetOpDesc(const Operator &oprt) { + GE_IF_BOOL_EXEC(oprt.operator_impl_ == nullptr, return nullptr); + return oprt.operator_impl_->op_desc_; + } + + void ClearOutputLinks() noexcept { output_links_.clear(); } + + void ClearInputLinks() noexcept { input_link_.clear(); } + + ge::ConstNodePtr GetNode() { return node_; } + + void SetInferenceContext(const InferenceContextPtr &inference_context) { inference_context_ = inference_context; } + + InferenceContextPtr GetInferenceContext() const { return inference_context_; } + + void SubgraphRegister(const std::string &ir_name, bool dynamic) { + op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? kDynamic : kStatic); + } + + void SubgraphCountRegister(const std::string &ir_name, uint32_t count) { + if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) { + op_desc_->AddSubgraphName(ir_name); + subgraph_names_to_builders_[ir_name] = nullptr; + } else { + for (uint32_t i = 0; i < count; ++i) { + string key_name = ir_name + std::to_string(i); + op_desc_->AddSubgraphName(key_name); + subgraph_names_to_builders_[key_name] = nullptr; + } + } + } + + void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { + string key_name = ir_name; + if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { + key_name += std::to_string(index); + } + + auto it = subgraph_names_to_builders_.find(key_name); + if (it == subgraph_names_to_builders_.end()) { + GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index); + return; + } + it->second = builder; + } + + SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, uint32_t index) const { + string key_name = ir_name; + if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { + key_name += std::to_string(index); + } + + return GetSubgraphBuilder(key_name); + } + + SubgraphBuilder GetSubgraphBuilder(const std::string &name) const { + auto iter = subgraph_names_to_builders_.find(name); + if (iter == subgraph_names_to_builders_.end()) { + GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s", name.c_str()); + return nullptr; + } + + return iter->second; + } + + std::vector GetSubgraphNames() const { + std::vector names; + for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) { + names.emplace_back(subgraph_name_to_type.first); + } + return names; + } + + size_t GetSubgraphNamesCount() const { + return op_desc_->GetSubgraphIrNames().size(); + } + + OpDescPtr op_desc_ = nullptr; + + private: + ge::ConstNodePtr node_{nullptr}; + ge::InferenceContextPtr inference_context_; + std::map> output_links_{}; + std::map input_link_{}; + std::vector> control_input_link_{}; + std::vector> control_output_link_{}; + std::map subgraph_names_to_builders_; +}; + +// Used to manage OperatorImpl instances created by ge api. +class OperatorKeeper { + private: + OperatorKeeper() = default; + ~OperatorKeeper() { + for (const auto &iter : operators_) { + if (iter) { + iter->ClearInputLinks(); + iter->ClearOutputLinks(); + } + } + } + std::set operators_; + std::mutex mutex_; + + public: + static OperatorKeeper &GetInstance() { + static OperatorKeeper instance; + return instance; + } + void CheckInOperator(const OperatorImplPtr &op_impl) { + if (op_impl) { + std::lock_guard lock(mutex_); + operators_.insert(op_impl); + } + } + void CheckOutOperator(const OperatorImplPtr &op_impl) { + if (op_impl) { + std::lock_guard lock(mutex_); + operators_.erase(op_impl); + } + } +}; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromNode(ge::ConstNodePtr node_ptr) { + ge::OperatorImplPtr operator_impl_ptr = ComGraphMakeShared(node_ptr); + if (operator_impl_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); + return Operator("default"); + } + return operator_impl_ptr->ToOperator(); +} + +Operator::Operator(const std::string &type) { + static uint32_t index = 0; + string name = type + "_" + std::to_string(index++); + operator_impl_ = ComGraphMakeShared(name, type); + if (operator_impl_ == nullptr) { + GELOGW("OperatorImpl make shared failed"); + } + OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); +} + +Operator::Operator(const char *type) { + if (type != nullptr) { + std::string op_type = type; + static uint32_t index = 0; + string name = op_type + "_" + std::to_string(index++); + operator_impl_ = ComGraphMakeShared(name, op_type); + if (operator_impl_ == nullptr) { + GELOGW("OperatorImpl make shared failed"); + } + OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); + } else { + GELOGW("Operator type is nullptr."); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromOpDesc(OpDescPtr op_desc) { + shared_ptr operator_impl_ptr; + operator_impl_ptr = ComGraphMakeShared(op_desc); + if (operator_impl_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); + return Operator("default"); + } + OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr); + return operator_impl_ptr->ToOperator(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) { + return OperatorImpl::GetOpDesc(oprt); +} + +GE_FUNC_HOST_VISIBILITY Operator::Operator(const string &name, const string &type) { + operator_impl_ = ComGraphMakeShared(name, type); + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); + return; + } + OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); +} + +GE_FUNC_HOST_VISIBILITY Operator::Operator(const AscendString &name, const AscendString &type) { + if ((name.GetString() != nullptr) && (type.GetString() != nullptr)) { + string op_name = name.GetString(); + string op_type = type.GetString(); + operator_impl_ = ComGraphMakeShared(op_name, op_type); + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); + return; + } + OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); + } else { + GELOGW("Operator input parameter is nullptr."); + } +} + +GE_FUNC_HOST_VISIBILITY Operator::Operator(const char *name, const char *type) { + if ((name != nullptr) && (type != nullptr)) { + string op_name = name; + string op_type = type; + operator_impl_ = ComGraphMakeShared(op_name, op_type); + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); + return; + } + OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); + } else { + GELOGW("Operator input parameter is nullptr."); + } +} + +Operator::Operator(ge::OperatorImplPtr &&op_impl) { operator_impl_ = std::move(op_impl); } + +bool Operator::IsEmpty() const { + if (operator_impl_ == nullptr) { + return true; + } + return false; +} + +string Operator::GetName() const { + if (operator_impl_ != nullptr) { + return operator_impl_->GetName(); + } + return ""; +} + +graphStatus Operator::GetName(AscendString &name) const { + if (operator_impl_ != nullptr) { + string op_name = operator_impl_->GetName(); + name = op_name.c_str(); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const string &dst_name, const ge::Operator &src_oprt) { + // Describe the connection relationship between operators, no create action + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); + operator_impl_->SetInputImpl(dst_name, src_oprt); + return *this; +} + +GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const char *dst_name, const ge::Operator &src_oprt) { + GE_CHK_BOOL_EXEC(dst_name != nullptr, return *this, "Operator dst name is nullptr."); + // Describe the connection relationship between operators, no create action + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "Operator impl is nullptr."); + std::string dst_op_name = dst_name; + operator_impl_->SetInputImpl(dst_op_name, src_oprt); + return *this; +} + +Operator &Operator::SetInput(const string &dst_name, const ge::OutHandler &out_handler) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); + operator_impl_->SetInputImpl(dst_name, out_handler); + return *this; +} + +Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, const std::string &name) { + auto out_handler = src_oprt.GetOutput(name); + GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); + (void)SetInput(dst_name, out_handler); + return *this; +} + +Operator &Operator::SetInput(const char *dst_name, const ge::Operator &src_oprt, const char *name) { + GE_CHK_BOOL_EXEC(dst_name != nullptr, return *this, "Dst name is nullptr."); + GE_CHK_BOOL_EXEC(name != nullptr, return *this, "Name is nullptr."); + std::string op_name = name; + std::string dst_op_name = dst_name; + auto out_handler = src_oprt.GetOutput(op_name); + GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "Out_handler is nullptr."); + (void)SetInput(dst_op_name, out_handler); + return *this; +} + +Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) { + auto out_handler = src_oprt.GetOutput(index); + GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); + (void)SetInput(dst_name, out_handler); + return *this; +} + +Operator &Operator::SetInput(const char *dst_name, const ge::Operator &src_oprt, uint32_t index) { + GE_CHK_BOOL_EXEC(dst_name != nullptr, return *this, "Dst name is nullptr."); + auto out_handler = src_oprt.GetOutput(index); + GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); + std::string op_dst_name = dst_name; + (void)SetInput(dst_name, out_handler); + return *this; +} + +Operator &Operator::AddControlInput(const Operator &src_oprt) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr."); + return *this; + } + operator_impl_->AddControlInputImp(src_oprt); + return *this; +} + +graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { + GE_CHECK_NOTNULL(operator_impl_); + graphStatus ret = operator_impl_->GetInputConstData(dst_name, data); + if (ret != GRAPH_SUCCESS) { + GELOGW("%s get input const data failed", dst_name.c_str()); + return ret; + } + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetInputConstData(const char *dst_name, Tensor &data) const { + GE_CHECK_NOTNULL(dst_name); + GE_CHECK_NOTNULL(operator_impl_); + std::string op_dst_name = dst_name; + graphStatus ret = operator_impl_->GetInputConstData(op_dst_name, data); + if (ret != GRAPH_SUCCESS) { + GELOGW("%s get input const data failed", op_dst_name.c_str()); + return ret; + } + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { + GE_CHECK_NOTNULL(operator_impl_); + if (operator_impl_->GetInputConstDataOut(dst_name, data) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "%s get input const data out failed", dst_name.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +std::shared_ptr Operator::GetNode() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); + return operator_impl_->GetNode(); +} + +TensorDesc Operator::GetInputDesc(const std::string &name) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); +} + +TensorDesc Operator::GetInputDescByName(const char *name) const { + GE_CHK_BOOL_EXEC(name != nullptr, return TensorDesc(), "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "Operator impl is nullptr."); + std::string op_name = name; + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(op_name)); +} + +void Operator::SetInferenceContext(const InferenceContextPtr &inference_context) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + operator_impl_->SetInferenceContext(inference_context); +} + +InferenceContextPtr Operator::GetInferenceContext() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); + return operator_impl_->GetInferenceContext(); +} + +TensorDesc Operator::GetInputDesc(uint32_t index) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); +} + +graphStatus Operator::TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + auto check = operator_impl_->InputIsSet(name); + if (check) + tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); + return check ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +graphStatus Operator::TryGetInputDesc(const char *name, TensorDesc &tensor_desc) const { + GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr."); + std::string op_name = name; + auto check = operator_impl_->InputIsSet(op_name); + if (check) + tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(op_name)); + return check ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +graphStatus Operator::UpdateInputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + return operator_impl_->UpdateInputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +graphStatus Operator::UpdateInputDesc(const char *name, const ge::TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr."); + std::string op_name = name; + return operator_impl_->UpdateInputDesc(op_name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +OutHandler Operator::GetOutput(const string &name) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); + return operator_impl_->GetOutput(name); +} + +OutHandler Operator::GetOutput(uint32_t index) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); + return operator_impl_->GetOutput(index); +} + +TensorDesc Operator::GetOutputDesc(const std::string &name) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name)); +} + +TensorDesc Operator::GetOutputDescByName(const char *name) const { + GE_CHK_BOOL_EXEC(name != nullptr, return TensorDesc(), "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "Operator impl is nullptr."); + std::string op_name = name; + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(op_name)); +} + +TensorDesc Operator::GetOutputDesc(uint32_t index) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(index)); +} + +graphStatus Operator::UpdateOutputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + return operator_impl_->UpdateOutputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +graphStatus Operator::UpdateOutputDesc(const char *name, const ge::TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr."); + std::string op_name = name; + return operator_impl_->UpdateOutputDesc(op_name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +TensorDesc Operator::GetDynamicInputDesc(const string &name, uint32_t index) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name + std::to_string(index))); +} + +TensorDesc Operator::GetDynamicInputDesc(const char *name, uint32_t index) const { + GE_CHK_BOOL_EXEC(name != nullptr, return TensorDesc(), "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "Operator impl is nullptr."); + std::string op_name = name; + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(op_name + std::to_string(index))); +} + +graphStatus Operator::UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + return operator_impl_->UpdateInputDesc(name + std::to_string(index), + TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +graphStatus Operator::UpdateDynamicInputDesc(const char *name, uint32_t index, const TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr."); + std::string op_name = name; + return operator_impl_->UpdateInputDesc(op_name + std::to_string(index), + TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +TensorDesc Operator::GetDynamicOutputDesc(const string &name, uint32_t index) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name + std::to_string(index))); +} + +TensorDesc Operator::GetDynamicOutputDesc(const char *name, uint32_t index) const { + GE_CHK_BOOL_EXEC(name != nullptr, return TensorDesc(), "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "Operator impl is nullptr."); + std::string op_name = name; + return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(op_name + std::to_string(index))); +} + +graphStatus Operator::UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + return operator_impl_->UpdateOutputDesc(name + std::to_string(index), + TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +graphStatus Operator::UpdateDynamicOutputDesc(const char *name, uint32_t index, const TensorDesc &tensor_desc) { + GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr."); + std::string op_name = name; + return operator_impl_->UpdateOutputDesc(op_name + std::to_string(index), + TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); +} + +graphStatus Operator::InferShapeAndType() { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); + + return operator_impl_->GetOpDescImpl()->CallInferFunc(*this); +} + +graphStatus Operator::VerifyAllAttr(bool disable_common_verifier) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); + + if (!disable_common_verifier && (graphStatus)Operator::VerifyAll() == GRAPH_FAILED) { + return GRAPH_FAILED; + } else { + return (graphStatus)operator_impl_->GetOpDescImpl()->OpVerify(); + } +} + +GE_FUNC_HOST_VISIBILITY size_t Operator::GetInputsSize() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr"); + return operator_impl_->GetInputsSize(); +} + +GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr"); + return operator_impl_->GetOutputsSize(); +} + +// According to op get the attrs name and type +namespace { +const std::map kAttrTypesMap = { + {GeAttrValue::VT_NONE, "VT_STRING"}, + {GeAttrValue::VT_STRING, "VT_STRING"}, + {GeAttrValue::VT_FLOAT, "VT_FLOAT"}, + {GeAttrValue::VT_BOOL, "VT_BOOL"}, + {GeAttrValue::VT_INT, "VT_INT"}, + {GeAttrValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"}, + {GeAttrValue::VT_TENSOR, "VT_TENSOR"}, + {GeAttrValue::VT_BYTES, "VT_BYTES"}, + {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, + {GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"}, + {GeAttrValue::VT_LIST_LIST_INT, "VT_LIST_LIST_INT"}, + {GeAttrValue::VT_DATA_TYPE, "VT_DATA_TYPE"}, + {GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"}, + {GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"}, + {GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"}, + {GeAttrValue::VT_LIST_BOOL, "VT_LIST_BOOL"}, + {GeAttrValue::VT_LIST_INT, "VT_LIST_INT"}, + {GeAttrValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"}, + {GeAttrValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"}, + {GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"}, + {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, + {GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"}, + {GeAttrValue::VT_LIST_DATA_TYPE, "VT_LIST_DATA_TYPE"}, +}; +} // namespace +const std::map Operator::GetAllAttrNamesAndTypes() const { + std::map attr_types; + + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return attr_types, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return attr_types, "GetOpDescImpl is nullptr."); + std::map attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs(); + + map::iterator iter; + for (iter = attr_map.begin(); iter != attr_map.end(); ++iter) { + string name = iter->first; + GeAttrValue attr_value = iter->second; + + GeAttrValue::ValueType type = attr_value.GetValueType(); + + auto iter2 = kAttrTypesMap.find(type); + if (iter2 != kAttrTypesMap.end()) { + attr_types[name] = iter2->second; + } + } + + return attr_types; +} + +graphStatus Operator::GetAllAttrNamesAndTypes(std::map &attr_name_types) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); + std::map attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs(); + + map::iterator iter; + for (iter = attr_map.begin(); iter != attr_map.end(); ++iter) { + string name = iter->first; + GeAttrValue attr_value = iter->second; + + GeAttrValue::ValueType type = attr_value.GetValueType(); + + auto iter2 = kAttrTypesMap.find(type); + if (iter2 != kAttrTypesMap.end()) { + AscendString temp(name.c_str()); + attr_name_types[temp] = AscendString(iter2->second.c_str()); + } + } + + return GRAPH_SUCCESS; +} + +void Operator::InputRegister(const string &name) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); +} + +void Operator::OptionalInputRegister(const string &name) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + // [No need to verify return value] + (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, + GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); +} + +void Operator::InferFuncRegister(const std::function &func) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + // [No need to verify return value] + (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); +} + +void Operator::InferFormatFuncRegister(const std::function &func) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + // [No need to verify return value] + (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); +} + +void Operator::VerifierFuncRegister(const std::function &func) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + // [No need to verify return value] + (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); +} + +void Operator::OutputRegister(const string &name) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + // [No need to verify return value] + (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); +} + +void Operator::DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return, + "set int failed"); + (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); +} + +void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) { + GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr."); + operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index); +} + +int Operator::GetDynamicInputNum(const string &name) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); + int num = 0; + GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return num, + "Get %s int failed", name.c_str()); + return num; +} + +int Operator::GetDynamicInputNum(const char *name) const { + GE_CHK_BOOL_EXEC(name != nullptr, return 0, "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "Operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); + string op_name = name; + int num = 0; + GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(op_name), num), return num, + "Get %s int failed", op_name.c_str()); + return num; +} + +void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return, + "Set %s int failed", name.c_str()); + (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); +} + +int Operator::GetDynamicOutputNum(const string &name) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); + int num = 0; + GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num, + "Get %s int failed", name.c_str()); + return num; +} + +int Operator::GetDynamicOutputNum(const char *name) const { + GE_CHK_BOOL_EXEC(name != nullptr, return 0, "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "Operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); + std::string op_name = name; + int num = 0; + GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(op_name), num), return num, + "Get %s int failed", op_name.c_str()); + return num; +} + +void Operator::RequiredAttrRegister(const string &name) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); + operator_impl_->GetOpDescImpl()->AddRequiredAttr(name); +} + +graphStatus Operator::VerifyAll() { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); + + // Check all inputs defined + for (const string &iname : operator_impl_->GetOpDescImpl()->GetAllInputNames()) { + GE_CHK_BOOL_RET_STATUS(operator_impl_->GetOpDescImpl()->IsOptionalInput(iname) || operator_impl_->InputIsSet(iname), + GRAPH_FAILED, "operator input %s is not linked.", iname.c_str()); + vector ishape = operator_impl_->GetOpDescImpl()->GetInputDesc(iname).GetShape().GetDims(); + for (int64_t dim : ishape) { + GE_CHK_BOOL_RET_STATUS(dim > 0, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", + iname.c_str()); + } + } + // Check all attributes defined + const auto all_attributes = operator_impl_->GetOpDescImpl()->GetAllAttrs(); + for (const auto &name : operator_impl_->GetOpDescImpl()->GetAllAttrNames()) { + GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, + "operator attribute %s is empty.", name.c_str()); + } + + return GRAPH_SUCCESS; +} + +string Operator::GetOpType() const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return "Data", "operator impl is nullptr."); + return OperatorImpl::GetOpDesc(*this)->GetType(); +} + +graphStatus Operator::GetOpType(AscendString &type) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr."); + std::string op_type = OperatorImpl::GetOpDesc(*this)->GetType(); + type = op_type.c_str(); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt) { + string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index); + return SetInput(dynamic_dst_name, src_oprt); +} + +Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt, + const std::string &name) { + string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index); + return SetInput(dynamic_dst_name, src_oprt, name); +} + +OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } + +#define OP_ATTR_SET_IMP(ArgType, AttrUtilsFun) \ + Operator &Operator::SetAttr(const string &name, ArgType attr_value) { \ + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ + return *this; \ + } \ + if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ + GELOGW("set attr name %s failed.", name.c_str()); \ + } \ + return *this; \ + } \ + Operator &Operator::SetAttr(const char *name, ArgType attr_value) { \ + if (name == nullptr) { \ + GELOGE(GRAPH_FAILED, "operator attr name is nullptr."); \ + return *this; \ + } \ + std::string op_name = name; \ + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", op_name.c_str()); \ + return *this; \ + } \ + if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { \ + GELOGW("set attr name %s failed.", op_name.c_str()); \ + } \ + return *this; \ + } + +#define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \ + graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \ + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ + return GRAPH_FAILED; \ + } \ + if (!AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ + GELOGW("get attr name %s failed.", name.c_str()); \ + return GRAPH_FAILED; \ + } \ + return GRAPH_SUCCESS; \ + } \ + graphStatus Operator::GetAttr(const char *name, ArgType attr_value) const { \ + if (name == nullptr) { \ + GELOGE(GRAPH_FAILED, "operator attr name is nullptr."); \ + return GRAPH_FAILED; \ + } \ + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name); \ + return GRAPH_FAILED; \ + } \ + std::string op_name = name; \ + if (!AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { \ + GELOGW("get attr name %s failed.", op_name.c_str()); \ + return GRAPH_FAILED; \ + } \ + return GRAPH_SUCCESS; \ + } + +void Operator::BreakConnect() const { + if (operator_impl_ == nullptr) { + GELOGW("operator impl is nullptr."); + return; + } + operator_impl_->ClearInputLinks(); + operator_impl_->ClearOutputLinks(); + OperatorKeeper::GetInstance().CheckOutOperator(operator_impl_); +} + +#define OP_ATTR_REG_IMP(ArgType, AttrUtilsFun) \ + void Operator::AttrRegister(const string &name, ArgType attr_value) { \ + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ + return; \ + } \ + if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ + GELOGW("reg attr name %s failed.", name.c_str()); \ + } \ + } // lint !e665 + +OP_ATTR_SET_IMP(int64_t, Int) +OP_ATTR_SET_IMP(int32_t, Int) +OP_ATTR_SET_IMP(uint32_t, Int) +OP_ATTR_GET_IMP(int64_t &, Int) +OP_ATTR_GET_IMP(int32_t &, Int) +OP_ATTR_GET_IMP(uint32_t &, Int) +OP_ATTR_SET_IMP(const vector &, ListInt) +OP_ATTR_SET_IMP(const vector &, ListInt) +OP_ATTR_SET_IMP(const vector &, ListInt) +OP_ATTR_SET_IMP(std::initializer_list &&, ListInt) +OP_ATTR_GET_IMP(vector &, ListInt) +OP_ATTR_GET_IMP(vector &, ListInt) +OP_ATTR_GET_IMP(vector &, ListInt) +OP_ATTR_GET_IMP(vector> &, ListListInt) +OP_ATTR_SET_IMP(const vector> &, ListListInt) + +OP_ATTR_SET_IMP(float, Float) +OP_ATTR_GET_IMP(float &, Float) +OP_ATTR_SET_IMP(const vector &, ListFloat) +OP_ATTR_GET_IMP(vector &, ListFloat) // lint !e665 + +OP_ATTR_SET_IMP(bool, Bool) +OP_ATTR_GET_IMP(bool &, Bool) +OP_ATTR_SET_IMP(const vector &, ListBool) +OP_ATTR_GET_IMP(vector &, ListBool) // lint !e665 + +OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) +OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) +OP_ATTR_SET_IMP(const vector &, ListNamedAttrs) +OP_ATTR_GET_IMP(vector &, ListNamedAttrs) // lint !e665 + +OP_ATTR_REG_IMP(int64_t, Int) +OP_ATTR_REG_IMP(const vector &, ListInt) +OP_ATTR_REG_IMP(float, Float) +OP_ATTR_REG_IMP(const vector &, ListFloat) +OP_ATTR_REG_IMP(const string &, Str) +OP_ATTR_REG_IMP(const vector &, ListStr) +OP_ATTR_REG_IMP(bool, Bool) +OP_ATTR_REG_IMP(const vector &, ListBool) +OP_ATTR_REG_IMP(const vector> &, ListListInt) +OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) +OP_ATTR_REG_IMP(const vector &, ListNamedAttrs) + +#undef OP_ATTR_SET_IMP +#undef OP_ATTR_GET_IMP +#undef OP_ATTR_REG_IMP + +void Operator::AttrRegister(const string &name, const AscendString &attr_value) { + if (attr_value.GetString() == nullptr) { + GELOGE(GRAPH_FAILED, "Attr %s register param is invalid.", name.c_str()); + return; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str()); + return; + } + std::string str_attr_value = attr_value.GetString(); + if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), name, str_attr_value)) { + GELOGW("Reg attr name %s failed.", name.c_str()); + } +} + +void Operator::AttrRegister(const string &name, const std::vector &attr_value) { + std::vector str_attr_values; + for (auto &val : attr_value) { + if (val.GetString() == nullptr) { + GELOGE(GRAPH_FAILED, "Attr %s register value is invalid.", name.c_str()); + return; + } + str_attr_values.emplace_back(val.GetString()); + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str()); + return; + } + if (!AttrUtils::SetListStr(operator_impl_->GetOpDescImpl(), name, str_attr_values)) { + GELOGW("Reg attr name %s failed.", name.c_str()); + } +} + +Operator &Operator::SetAttr(const string &name, const string &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("Set attr name %s failed.", name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const string &name, string &attr_value) const { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + if (!AttrUtils::GetStr(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("Get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const string &name, const std::vector &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + if (!AttrUtils::SetListStr(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("Set attr name %s failed.", name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const string &name, std::vector &attr_value) const { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + if (!AttrUtils::GetListStr(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("Get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const char *name, const char *attr_value) { + if (name == nullptr || attr_value == nullptr) { + GELOGE(GRAPH_FAILED, "Operator input parameters is nullptr."); + return *this; + } + + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return *this; + } + std::string op_name = name; + std::string op_attr_value = attr_value; + if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_value)) { + GELOGW("Set attr name %s failed.", op_name.c_str()); + } + return *this; +} + +Operator &Operator::SetAttr(const char *name, const AscendString &attr_value) { + if (name == nullptr || attr_value.GetString() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator input parameters is nullptr."); + return *this; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return *this; + } + std::string op_name = name; + std::string op_attr_value = attr_value.GetString(); + if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_value)) { + GELOGW("Set attr name %s failed.", op_name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const char *name, AscendString &attr_value) const { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator input parameters is nullptr."); + return GRAPH_FAILED; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return GRAPH_FAILED; + } + std::string op_name = name; + std::string op_attr_value; + if (!AttrUtils::GetStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_value)) { + GELOGW("Get attr name %s failed.", op_name.c_str()); + return GRAPH_FAILED; + } + attr_value = AscendString(op_attr_value.c_str()); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const char *name, const std::vector &attr_values) { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return *this; + } + std::vector op_attr_values; + for (auto &attr_value : attr_values) { + if (attr_value.GetString() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator ascend string name is nullptr."); + return *this; + } + op_attr_values.emplace_back(attr_value.GetString()); + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return *this; + } + std::string op_name = name; + if (!AttrUtils::SetListStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_values)) { + GELOGW("Set attr name %s failed.", op_name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const char *name, std::vector &attr_value) const { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return GRAPH_FAILED; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return GRAPH_FAILED; + } + std::string op_name = name; + std::vector op_attr_values; + if (!AttrUtils::GetListStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_values)) { + GELOGW("Get attr name %s failed.", op_name.c_str()); + return GRAPH_FAILED; + } + for (auto &op_attr_value : op_attr_values) { + attr_value.emplace_back(AscendString(op_attr_value.c_str())); + } + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const string &name, const Tensor &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + GeTensor tensor = TensorAdapter::AsGeTensor(attr_value); + if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { + GELOGW("set attr name %s failed.", name.c_str()); + } + return *this; +} + +Operator &Operator::SetAttr(const char *name, const Tensor &attr_value) { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return *this; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name); + return *this; + } + std::string op_name = name; + GeTensor tensor = TensorAdapter::AsGeTensor(attr_value); + if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), op_name, tensor)) { + GELOGW("set attr name %s failed.", op_name.c_str()); + } + return *this; +} + +Operator &Operator::SetAttr(const string &name, const vector &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + vector val_list; + for (const auto &item : attr_value) { + auto tensor = TensorAdapter::AsGeTensor(item); + val_list.push_back(tensor); + } + if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { + GELOGW("set attr name %s failed.", name.c_str()); + } + return *this; +} + +Operator &Operator::SetAttr(const char *name, const vector &attr_value) { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return *this; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return *this; + } + std::string op_name = name; + vector val_list; + for (const auto &item : attr_value) { + auto tensor = TensorAdapter::AsGeTensor(item); + val_list.push_back(tensor); + } + if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), op_name, val_list)) { + GELOGW("Set attr name %s failed.", op_name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const string &name, Tensor &attr_value) const { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + ConstGeTensorPtr tensor; + if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { + GELOGW("get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + attr_value = TensorAdapter::GeTensor2Tensor(tensor); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char *name, Tensor &attr_value) const { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return GRAPH_FAILED; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name); + return GRAPH_FAILED; + } + std::string op_name = name; + ConstGeTensorPtr tensor; + if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), op_name, tensor)) { + GELOGW("get attr name %s failed.", op_name.c_str()); + return GRAPH_FAILED; + } + attr_value = TensorAdapter::GeTensor2Tensor(tensor); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const string &name, vector &attr_value) const { + attr_value.clear(); + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + vector val_list; + if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { + GELOGW("get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + for (auto &tensor : val_list) { + attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor)); + } + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char *name, vector &attr_value) const { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return GRAPH_FAILED; + } + attr_value.clear(); + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return GRAPH_FAILED; + } + std::string op_name = name; + vector val_list; + if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), op_name, val_list)) { + GELOGW("get attr name %s failed.", op_name.c_str()); + return GRAPH_FAILED; + } + for (auto &tensor : val_list) { + attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor)); + } + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const string &name, const OpBytes &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, + Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { + GELOGW("set attr name %s failed.", name.c_str()); + } + return *this; +} + +Operator &Operator::SetAttr(const char *name, const OpBytes &attr_value) { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return *this; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return *this; + } + std::string op_name = name; + if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), op_name, + Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { + GELOGW("Set attr name %s failed.", op_name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const string &name, OpBytes &attr_value) const { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + Buffer buffer; + if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, buffer)) { + GELOGW("get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + attr_value.clear(); + if (buffer.data() == nullptr) { + GELOGE(GRAPH_FAILED, "buffer data is null."); + return GRAPH_FAILED; + } + attr_value.assign(buffer.data(), buffer.data() + buffer.size()); + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char *name, OpBytes &attr_value) const { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return GRAPH_FAILED; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return GRAPH_FAILED; + } + std::string op_name = name; + Buffer buffer; + if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), op_name, buffer)) { + GELOGW("Get attr name %s failed.", op_name.c_str()); + return GRAPH_FAILED; + } + attr_value.clear(); + if (buffer.data() == nullptr) { + GELOGE(GRAPH_FAILED, "Buffer data is null."); + return GRAPH_FAILED; + } + attr_value.assign(buffer.data(), buffer.data() + buffer.size()); + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const string &name, ge::AttrValue &&attrValue) { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "Operator impl is nullptr."); + (void)operator_impl_->SetAttr(name, std::move(attrValue.impl->geAttrValue_)); + return *this; +} + +Operator &Operator::SetAttr(const char *name, ge::AttrValue &&attrValue) { + GE_CHK_BOOL_EXEC(name != nullptr, return *this, "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "Operator impl is nullptr."); + std::string op_name = name; + (void)operator_impl_->SetAttr(op_name, std::move(attrValue.impl->geAttrValue_)); + return *this; +} + +graphStatus Operator::GetAttr(const string &name, ge::AttrValue &attrValue) const { + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); + return operator_impl_->GetAttr(name, attrValue.impl->geAttrValue_); +} + +graphStatus Operator::GetAttr(const char *name, ge::AttrValue &attrValue) const { + GE_CHK_BOOL_EXEC(name != nullptr, return GRAPH_FAILED, "Operator name is nullptr."); + GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "Operator impl is nullptr."); + std::string op_name = name; + return operator_impl_->GetAttr(op_name, attrValue.impl->geAttrValue_); +} + +Operator &Operator::SetAttr(const string &name, const std::vector &attr_value) { + if (operator_impl_ == nullptr || !operator_impl_->GetOpDescImpl()) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("set attr name %s failed.", name.c_str()); + } + return *this; +} + +Operator &Operator::SetAttr(const char *name, const std::vector &attr_value) { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return *this; + } + if (operator_impl_ == nullptr || !operator_impl_->GetOpDescImpl()) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return *this; + } + std::string op_name = name; + if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { + GELOGW("Set attr name %s failed.", op_name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const string &name, std::vector &attr_value) const { + attr_value.clear(); + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + if (!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char *name, std::vector &attr_value) const { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return GRAPH_FAILED; + } + attr_value.clear(); + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return GRAPH_FAILED; + } + std::string op_name = name; + if (!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { + GELOGW("Get attr name %s failed.", op_name.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +Operator &Operator::SetAttr(const string &name, const ge::DataType &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return *this; + } + if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("set attr name %s failed.", name.c_str()); + } + return *this; +} + +Operator &Operator::SetAttr(const char *name, const ge::DataType &attr_value) { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return *this; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return *this; + } + std::string op_name = name; + if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { + GELOGW("Set attr name %s failed.", op_name.c_str()); + } + return *this; +} + +graphStatus Operator::GetAttr(const string &name, ge::DataType &attr_value) const { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return GRAPH_FAILED; + } + if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("get attr name %s failed.", name.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus Operator::GetAttr(const char *name, ge::DataType &attr_value) const { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return GRAPH_FAILED; + } + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr, name %s.", name); + return GRAPH_FAILED; + } + std::string op_name = name; + if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { + GELOGW("Get attr name %s failed.", op_name.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +void Operator::AttrRegister(const string &name, const std::vector &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("set attr name %s failed.", name.c_str()); + } +} + +void Operator::AttrRegister(const string &name, const ge::DataType &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { + GELOGW("set attr name %s failed.", name.c_str()); + } +} + +void Operator::AttrRegister(const string &name, const Tensor &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + auto tensor = TensorAdapter::AsGeTensor(attr_value); + if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { + GELOGW("reg attr name %s failed.", name.c_str()); + } +} + +void Operator::AttrRegister(const string &name, const vector &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + vector val_list; + for (const auto &item : attr_value) { + val_list.push_back(TensorAdapter::AsGeTensor(item)); + } + if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { + GELOGW("reg attr name %s failed.", name.c_str()); + } +} + +void Operator::AttrRegister(const string &name, const OpBytes &attr_value) { + if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, + Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { + GELOGW("reg attr name %s failed.", name.c_str()); + } +} + +void Operator::SubgraphRegister(const std::string &name, bool dynamic) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + operator_impl_->SubgraphRegister(name, dynamic ? kDynamic : kStatic); +} + +void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + operator_impl_->SubgraphCountRegister(name, count); +} + +void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", ir_name.c_str()); + return; + } + operator_impl_->SetSubgraphBuilder(ir_name, index, builder); +} + +std::vector Operator::GetSubgraphNames() const { + return operator_impl_->GetSubgraphNames(); +} + +graphStatus Operator::GetSubgraphNames(std::vector &names) const { + std::vector subgraph_names = operator_impl_->GetSubgraphNames(); + for (auto &subgraph_name : subgraph_names) { + names.emplace_back(subgraph_name.c_str()); + } + return GRAPH_SUCCESS; +} + +SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &ir_name, uint32_t index) const { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr."); + return nullptr; + } + return operator_impl_->GetSubgraphBuilder(ir_name, index); +} + +SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const char *ir_name, uint32_t index) const { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "Operator impl is nullptr."); + return nullptr; + } + if (ir_name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return nullptr; + } + std::string op_ir_name = ir_name; + return operator_impl_->GetSubgraphBuilder(op_ir_name, index); +} + +SubgraphBuilder Operator::GetSubgraphBuilder(const string &ir_name) const { + return GetDynamicSubgraphBuilder(ir_name, 0); +} + +SubgraphBuilder Operator::GetSubgraphBuilder(const char *ir_name) const { + std::string graph_ir_name; + if (ir_name != nullptr) { + graph_ir_name = ir_name; + } + return GetDynamicSubgraphBuilder(graph_ir_name, 0); +} + +Graph Operator::GetSubgraphImpl(const string &name) const { + if (operator_impl_ == nullptr) { + GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str()); + return Graph(""); + } + auto op_desc = OpDescUtils::GetOpDescFromOperator(*this); + if (op_desc == nullptr) { + GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str()); + return Graph(""); + } + const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); + auto iter = subgraph_names_to_index.find(name); + if (iter == subgraph_names_to_index.end()) { + GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str()); + return Graph(""); + } + auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); + if (subgraph_instance_name.empty()) { + GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", + name.c_str(), iter->second); + return Graph(""); + } + + auto node = operator_impl_->GetNode(); + if (node == nullptr) { + GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str()); + return Graph(""); + } + auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + if (root_graph == nullptr) { + GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str()); + return Graph(""); + } + auto subgraph = root_graph->GetSubgraph(subgraph_instance_name); + if (subgraph == nullptr) { + GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", + name.c_str(), iter->second, subgraph_instance_name.c_str()); + return Graph(""); + } + return GraphUtils::CreateGraphFromComputeGraph(subgraph); +} + +Graph Operator::GetSubgraph(const string &name) const { + return GetSubgraphImpl(name); +} + +Graph Operator::GetSubgraph(const char *name) const { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Get subgraph failed, name is nullptr."); + return Graph(""); + } + std::string op_name = name; + return GetSubgraphImpl(op_name); +} + +Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const { + return GetSubgraph(name + std::to_string(index)); +} + +Graph Operator::GetDynamicSubgraph(const char *name, uint32_t index) const { + if (name == nullptr) { + GELOGE(GRAPH_FAILED, "Operator name is nullptr."); + return Graph(""); + } + std::string op_name = name; + return GetSubgraph(op_name + std::to_string(index)); +} + +size_t Operator::GetSubgraphNamesCount() const { + if (operator_impl_ == nullptr) { + GE_LOGE("Failed to get subgraph names count, the operator impl is null"); + return 0; + } + return operator_impl_->GetSubgraphNamesCount(); +} + +class GraphBuilderImpl { + public: + explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared(name)) { + if (graph_ == nullptr) { + GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); + return; + } + } + + ~GraphBuilderImpl() {} + + ComputeGraphPtr BuildGraph(const std::vector &inputs) { + std::vector vec_inputs; + for (auto &it : inputs) { + auto src_op_impl = it.operator_impl_; + GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return nullptr, "Operator Impl is null."); + GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return nullptr, "Operator impl's opdesc is null."); + + string type = src_op_impl->op_desc_->GetType(); + auto node_op = ge::OperatorFactory::CreateOperator("node_op", type); + auto tensor_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + node_op.BreakConnect(); + + GE_CHK_BOOL_EXEC(tensor_desc != nullptr, continue, "tensor_desc is null."); + if ((tensor_desc->GetInputsSize() == 0 && tensor_desc->GetOutputsSize() > 0) || type == DATA || + type == VARIABLE || type == INITDATA || type == GETNEXT) { + vec_inputs.push_back(it.operator_impl_); + } else { + GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); + } + } + GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, "User Input do not include operator such as " + "Data, Variable operator or operator that has output but no input."); + auto ret = WalkAllOperators(vec_inputs); + GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); + + ret = AddEdge(); + GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "AddEdge failed."); + + return graph_; + } + + const std::map &GetAllNodesInfo() const { return all_nodes_info_; } + + private: + graphStatus WalkAllOperators(const std::vector &vec_ops) { + GE_CHK_BOOL_EXEC(graph_ != nullptr, return GRAPH_FAILED, "graph_ is null.") + std::queue> que; + que.push(vec_ops); + while (!que.empty()) { + auto vec_tem = que.front(); + que.pop(); + for (const auto &op_impl : vec_tem) { + GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") + GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue, + "This node %s has created.", op_impl->GetName().c_str()) + auto node_ptr = graph_->AddNode(op_impl->op_desc_); + GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); + all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); + + auto &out_links = op_impl->output_links_; + std::vector vec_op_forward{}; + for (const auto &out_link : out_links) { + for (const auto &op_forward : out_link.second) { + vec_op_forward.push_back(op_forward.GetOwner()); + } + } + + auto &out_control_links = op_impl->control_output_link_; + for (const auto &out_link : out_control_links) { + vec_op_forward.push_back(out_link.lock()); + } + que.push(vec_op_forward); + + auto &in_links = op_impl->input_link_; + std::vector vec_op_back_forward{}; + for (const auto &in_link : in_links) { + vec_op_back_forward.push_back(in_link.second.GetOwner()); + } + + auto &in_control_links = op_impl->control_input_link_; + for (const auto &in_link : in_control_links) { + vec_op_back_forward.push_back(in_link.lock()); + } + que.push(vec_op_back_forward); + + if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + } + return MoveSubgraphToRoot(graph_); + } + + graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) { + const string name = node->GetName(); + for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) { + const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first); + if (builder == nullptr) { + GELOGW("Node: %s, Has no builder.", name.c_str()); + continue; + } + + Graph graph = builder(); // Build subgraph from user define builder. + const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph); + GE_CHK_BOOL_EXEC(subgraph != nullptr, return GRAPH_FAILED, "Node: %s, Build graph failed.", name.c_str()); + + subgraph->SetParentNode(node); + subgraph->SetParentGraph(graph_); + if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + + if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second); + return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; + } + + graphStatus MoveSubgraphToRoot(const ComputeGraphPtr &graph) { + const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph); + if (root_graph == nullptr) { + GELOGE(GRAPH_FAILED, "Graph: %s, Find root graph failed.", graph->GetName().c_str()); + return GRAPH_FAILED; + } + + if (root_graph == graph) { + auto subgraphs = graph->GetAllSubgraphs(); + for (auto &subgraph : subgraphs) { + if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + } else { + auto subgraphs = graph->GetAllSubgraphs(); + for (auto &subgraph : subgraphs) { + if (root_graph->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + graph->RemoveSubgraph(subgraph->GetName()); + if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + } + + return GRAPH_SUCCESS; + } + + graphStatus AddEdge() { + for (const auto &node_info : all_nodes_info_) { + auto src_op_impl_ptr = node_info.first; + auto src_node_ptr = node_info.second; + + GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); + auto out_links = src_op_impl_ptr->output_links_; + GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED, + "Src operator impl's op_desc is null."); + auto &op_desc = src_op_impl_ptr->op_desc_; + GE_IF_BOOL_EXEC(op_desc == nullptr, continue); + for (const auto &out : out_links) { + auto src_idx = op_desc->GetOutputIndexByName(out.first); + GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); + + auto src_anchor = src_node_ptr->GetOutDataAnchor(src_idx); + GE_CHK_BOOL_EXEC(src_anchor != nullptr, return GRAPH_FAILED, "GetOutDataAnchor failed."); + + for (const auto &dst_opio : out.second) { + auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); + GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); + + GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); + + auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); + GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); + + auto ret = GraphUtils::AddEdge(src_anchor, dst_anchor); + GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, + "from node[%s][%d] to node[%s][%d]AddEdge failed.", + src_node_ptr->GetName().c_str(), src_anchor->GetIdx(), + dst_node_info->second->GetName().c_str(), dst_anchor->GetIdx()); + } + } + auto out_control_anchor = src_node_ptr->GetOutControlAnchor(); + for (const auto &control_out : src_op_impl_ptr->control_output_link_) { + auto dst_node_info = all_nodes_info_.find(control_out.lock()); + if (dst_node_info == all_nodes_info_.end()) { + GELOGE(GRAPH_FAILED, "Find Dst node failed."); + return GRAPH_FAILED; + } + GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); + auto in_control_anchor = dst_node_info->second->GetInControlAnchor(); + auto ret = GraphUtils::AddEdge(out_control_anchor, in_control_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "AddEdge failed. srcNode %s:%s, dstNode %s:%s", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), dst_node_info->second->GetName().c_str(), + dst_node_info->second->GetType().c_str()); + return ret; + } + } + } + return GRAPH_SUCCESS; + } + + ComputeGraphPtr graph_ = nullptr; + std::map all_nodes_info_{}; +}; + +inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { + for (const auto &graph : compute_graph->GetAllSubgraphs()) { + std::set node_names; + for (auto const &node : graph->GetDirectNode()) { + auto result = node_names.insert(node->GetName()); + if (!result.second) { + GELOGE(GRAPH_FAILED, "graph %s has same name node%s", graph->GetName().c_str(), node->GetName().c_str()); + return true; + } + } + } + + std::set node_names; + for (auto const &node : compute_graph->GetDirectNode()) { + auto result = node_names.insert(node->GetName()); + if (!result.second) { + GELOGE(GRAPH_FAILED, "graph %s has same name node%s", compute_graph->GetName().c_str(), node->GetName().c_str()); + return true; + } + } + return false; +} + +ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector &inputs) { + auto graph_builder_impl = GraphBuilderImpl(name); + ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); + GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Computer graph is nullptr"); + compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); + if (HasSameNameNode(compute_graph)) { + GELOGW("Compute do not allow has same name nodes."); + compute_graph = nullptr; + } + + return compute_graph; +} + +void GraphUtils::BreakConnect(const std::map &all_nodes_infos) { + for (const auto &it : all_nodes_infos) { + OperatorImplPtr op_impl = it.first; + if (op_impl == nullptr) { + GELOGW("operator impl is nullptr."); + continue; + } + op_impl->ClearOutputLinks(); + op_impl->ClearInputLinks(); + OperatorKeeper::GetInstance().CheckOutOperator(op_impl); + } +} +/*lint +e446 +e732*/ +/*lint +e665*/ +} // namespace ge diff --git a/metadef/graph/operator_factory.cc b/metadef/graph/operator_factory.cc new file mode 100644 index 00000000..ec820878 --- /dev/null +++ b/metadef/graph/operator_factory.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/operator_factory_impl.h" +#include "debug/ge_log.h" + +namespace ge { +Operator OperatorFactory::CreateOperator(const std::string &operator_name, const std::string &operator_type) { + return OperatorFactoryImpl::CreateOperator(operator_name, operator_type); +} + +Operator OperatorFactory::CreateOperator(const char *operator_name, const char *operator_type) { + if (operator_name == nullptr || operator_type == nullptr) { + GELOGE(GRAPH_FAILED, "Create Operator input parameter is nullptr."); + return Operator(); + } + std::string op_name = operator_name; + std::string op_type = operator_type; + return OperatorFactoryImpl::CreateOperator(op_name, op_type); +} + +graphStatus OperatorFactory::GetOpsTypeList(std::vector &all_ops) { + return OperatorFactoryImpl::GetOpsTypeList(all_ops); +} + +graphStatus OperatorFactory::GetOpsTypeList(std::vector &all_ops) { + std::vector all_op_types; + if (OperatorFactoryImpl::GetOpsTypeList(all_op_types) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get ops type list failed."); + return GRAPH_FAILED; + } + for (auto &op_type : all_op_types) { + all_ops.emplace_back(op_type.c_str()); + } + return GRAPH_SUCCESS; +} + +bool OperatorFactory::IsExistOp(const string &operator_type) { return OperatorFactoryImpl::IsExistOp(operator_type); } + +bool OperatorFactory::IsExistOp(const char *operator_type) { + if (operator_type == nullptr) { + GELOGE(GRAPH_FAILED, "Operator type is nullptr."); + return false; + } + std::string op_type = operator_type; + return OperatorFactoryImpl::IsExistOp(op_type); +} + +OperatorCreatorRegister::OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator) { + (void)OperatorFactoryImpl::RegisterOperatorCreator(operator_type, op_creator); +} + +OperatorCreatorRegister::OperatorCreatorRegister(const char *operator_type, OpCreatorV2 const &op_creator) { + std::string op_type; + if (operator_type != nullptr) { + op_type = operator_type; + } + (void)OperatorFactoryImpl::RegisterOperatorCreator(op_type, op_creator); +} + +InferShapeFuncRegister::InferShapeFuncRegister(const std::string &operator_type, + const InferShapeFunc &infer_shape_func) { + (void)OperatorFactoryImpl::RegisterInferShapeFunc(operator_type, infer_shape_func); +} + +InferShapeFuncRegister::InferShapeFuncRegister(const char *operator_type, + const InferShapeFunc &infer_shape_func) { + std::string op_type; + if (operator_type != nullptr) { + op_type = operator_type; + } + (void)OperatorFactoryImpl::RegisterInferShapeFunc(op_type, infer_shape_func); +} + +InferFormatFuncRegister::InferFormatFuncRegister(const std::string &operator_type, + const InferFormatFunc &infer_format_func) { + (void)OperatorFactoryImpl::RegisterInferFormatFunc(operator_type, infer_format_func); +} + +InferFormatFuncRegister::InferFormatFuncRegister(const char *operator_type, + const InferFormatFunc &infer_format_func) { + std::string op_type; + if (operator_type != nullptr) { + op_type = operator_type; + } + (void)OperatorFactoryImpl::RegisterInferFormatFunc(op_type, infer_format_func); +} + +VerifyFuncRegister::VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func) { + (void)OperatorFactoryImpl::RegisterVerifyFunc(operator_type, verify_func); +} + +VerifyFuncRegister::VerifyFuncRegister(const char *operator_type, const VerifyFunc &verify_func) { + std::string op_type; + if (operator_type != nullptr) { + op_type = operator_type; + } + (void)OperatorFactoryImpl::RegisterVerifyFunc(op_type, verify_func); +} +} // namespace ge diff --git a/metadef/graph/operator_factory_impl.cc b/metadef/graph/operator_factory_impl.cc new file mode 100644 index 00000000..2393c0d8 --- /dev/null +++ b/metadef/graph/operator_factory_impl.cc @@ -0,0 +1,212 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/operator_factory_impl.h" +#include "debug/ge_log.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +shared_ptr> OperatorFactoryImpl::operator_creators_; +shared_ptr> OperatorFactoryImpl::operator_creators_v2_; +shared_ptr> OperatorFactoryImpl::operator_infershape_funcs_; +shared_ptr> OperatorFactoryImpl::operator_inferformat_funcs_; +shared_ptr> OperatorFactoryImpl::operator_verify_funcs_; +shared_ptr> OperatorFactoryImpl::operator_infer_data_slice_funcs_; + +Operator OperatorFactoryImpl::CreateOperator(const std::string &operator_name, const std::string &operator_type) { + if (operator_creators_v2_ != nullptr) { + auto it_v2 = operator_creators_v2_->find(operator_type); + if (it_v2 != operator_creators_v2_->end()) { + return it_v2->second(operator_name.c_str()); + } else { + GELOGW("No OpProto of [%s] registered by AscendString.", operator_type.c_str()); + } + } + if (operator_creators_ == nullptr) { + return Operator(); + } + auto it = operator_creators_->find(operator_type); + if (it == operator_creators_->end()) { + GELOGW("no OpProto of [%s] registered by string.", operator_type.c_str()); + return Operator(); + } + return it->second(operator_name); +} + +graphStatus OperatorFactoryImpl::GetOpsTypeList(std::vector &all_ops) { + all_ops.clear(); + if (operator_creators_v2_ != nullptr) { + for (auto it_v2 = operator_creators_v2_->begin(); it_v2 != operator_creators_v2_->end(); ++it_v2) { + all_ops.emplace_back(it_v2->first); + } + return GRAPH_SUCCESS; + } else { + GELOGW("Ops not registered by AscendString."); + } + + if (operator_creators_ != nullptr) { + for (auto it = operator_creators_->begin(); it != operator_creators_->end(); ++it) { + all_ops.emplace_back(it->first); + } + } else { + GELOGE(GRAPH_FAILED, "no operator creators found"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +bool OperatorFactoryImpl::IsExistOp(const string &operator_type) { + if (operator_creators_v2_ != nullptr) { + auto it_v2 = operator_creators_v2_->find(operator_type); + if (it_v2 != operator_creators_v2_->end()) { + return true; + } + } + + if (operator_creators_ == nullptr) { + return false; + } + auto it = operator_creators_->find(operator_type); + if (it == operator_creators_->end()) { + return false; + } + return true; +} + +InferShapeFunc OperatorFactoryImpl::GetInferShapeFunc(const std::string &operator_type) { + if (operator_infershape_funcs_ == nullptr) { + return nullptr; + } + auto it = operator_infershape_funcs_->find(operator_type); + if (it == operator_infershape_funcs_->end()) { + return nullptr; + } + return it->second; +} + +InferFormatFunc OperatorFactoryImpl::GetInferFormatFunc(const std::string &operator_type) { + if (operator_inferformat_funcs_ == nullptr) { + GELOGI("operator_inferformat_funcs_ is null"); + return nullptr; + } + auto it = operator_inferformat_funcs_->find(operator_type); + if (it == operator_inferformat_funcs_->end()) { + return nullptr; + } + return it->second; +} + +VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) { + if (operator_verify_funcs_ == nullptr) { + return nullptr; + } + auto it = operator_verify_funcs_->find(operator_type); + if (it == operator_verify_funcs_->end()) { + return nullptr; + } + return it->second; +} + +InferDataSliceFunc OperatorFactoryImpl::GetInferDataSliceFunc(const std::string &operator_type) { + if (operator_infer_data_slice_funcs_ == nullptr) { + return nullptr; + } + auto it = operator_infer_data_slice_funcs_->find(operator_type); + if (it == operator_infer_data_slice_funcs_->end()) { + return nullptr; + } + return it->second; +} + +graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { + if (operator_creators_ == nullptr) { + operator_creators_.reset(new (std::nothrow) std::map()); + } + auto it = operator_creators_->find(operator_type); + if (it != operator_creators_->end()) { + return GRAPH_FAILED; + } + (void)operator_creators_->emplace(operator_type, op_creator); + return GRAPH_SUCCESS; +} + +graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreatorV2 const &op_creator) { + if (operator_creators_v2_ == nullptr) { + operator_creators_v2_.reset(new (std::nothrow) std::map()); + } + auto it = operator_creators_v2_->find(operator_type); + if (it != operator_creators_v2_->end()) { + return GRAPH_FAILED; + } + (void)operator_creators_v2_->emplace(operator_type, op_creator); + return GRAPH_SUCCESS; +} + +graphStatus OperatorFactoryImpl::RegisterInferShapeFunc(const std::string &operator_type, + InferShapeFunc const infer_shape_func) { + if (operator_infershape_funcs_ == nullptr) { + GELOGI("operator_infershape_funcs_ init"); + operator_infershape_funcs_.reset(new (std::nothrow) std::map()); + } + auto it = operator_infershape_funcs_->find(operator_type); + if (it != operator_infershape_funcs_->end()) { + return GRAPH_FAILED; + } + (void)operator_infershape_funcs_->emplace(operator_type, infer_shape_func); + return GRAPH_SUCCESS; +} + +graphStatus OperatorFactoryImpl::RegisterInferFormatFunc(const std::string &operator_type, + InferFormatFunc const infer_format_func) { + if (operator_inferformat_funcs_ == nullptr) { + GELOGI("operator_inferformat_funcs_ init"); + operator_inferformat_funcs_.reset(new (std::nothrow) std::map()); + } + auto it = operator_inferformat_funcs_->find(operator_type); + if (it != operator_inferformat_funcs_->end()) { + return GRAPH_FAILED; + } + (void)operator_inferformat_funcs_->emplace(operator_type, infer_format_func); + return GRAPH_SUCCESS; +} + +graphStatus OperatorFactoryImpl::RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func) { + if (operator_verify_funcs_ == nullptr) { + GELOGI("operator_verify_funcs_ init"); + operator_verify_funcs_.reset(new (std::nothrow) std::map()); + } + auto it = operator_verify_funcs_->find(operator_type); + if (it != operator_verify_funcs_->end()) { + return GRAPH_FAILED; + } + (void)operator_verify_funcs_->emplace(operator_type, verify_func); + return GRAPH_SUCCESS; +} + +graphStatus OperatorFactoryImpl::RegisterInferDataSliceFunc(const std::string &operator_type, + InferDataSliceFunc const infer_data_slice_func) { + if (operator_infer_data_slice_funcs_ == nullptr) { + GELOGI("operator_infer_data_slice_funcs_ init"); + operator_infer_data_slice_funcs_.reset(new (std::nothrow) std::map()); + } + auto it = operator_infer_data_slice_funcs_->find(operator_type); + if (it != operator_infer_data_slice_funcs_->end()) { + return GRAPH_FAILED; + } + (void)operator_infer_data_slice_funcs_->emplace(operator_type, infer_data_slice_func); + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/metadef/graph/opsproto/opsproto_manager.cc b/metadef/graph/opsproto/opsproto_manager.cc new file mode 100644 index 00000000..7ab4ce06 --- /dev/null +++ b/metadef/graph/opsproto/opsproto_manager.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/opsproto_manager.h" +#include +#include +#include +#include +#include +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_log.h" +#include "mmpa/mmpa_api.h" + +namespace ge { +OpsProtoManager *OpsProtoManager::Instance() { + static OpsProtoManager instance; + return &instance; +} + +bool OpsProtoManager::Initialize(const std::map &options) { + std::lock_guard lock(mutex_); + + if (is_init_) { + GELOGI("OpsProtoManager is already initialized."); + return true; + } + + /*lint -e1561*/ + auto proto_iter = options.find("ge.opsProtoLibPath"); + /*lint +e1561*/ + if (proto_iter == options.end()) { + GELOGW("ge.opsProtoLibPath option not set, return."); + return false; + } + + pluginPath_ = proto_iter->second; + LoadOpsProtoPluginSo(pluginPath_); + + is_init_ = true; + + return true; +} + +void OpsProtoManager::Finalize() { + std::lock_guard lock(mutex_); + + if (!is_init_) { + GELOGI("OpsProtoManager is not initialized."); + return; + } + + for (auto handle : handles_) { + if (handle != nullptr) { + if (mmDlclose(handle) != 0) { + const char *error = mmDlerror(); + error = (error == nullptr) ? "" : error; + GELOGW("failed to close handle, message: %s", error); + continue; + } + GELOGI("close opsprotomanager handler success"); + } else { + GELOGW("close opsprotomanager handler failure, handler is nullptr"); + } + } + + is_init_ = false; +} + +static std::vector Split(const std::string &str, char delim) { + std::vector elems; + if (str.empty()) { + elems.emplace_back(""); + return elems; + } + + std::stringstream ss(str); + std::string item; + + while (getline(ss, item, delim)) { + elems.push_back(item); + } + + auto str_size = str.size(); + if (str_size > 0 && str[str_size - 1] == delim) { + elems.emplace_back(""); + } + + return elems; +} + +static void FindParserSo(const std::string &path, std::vector &file_list) { + // Lib plugin path not exist + if (path.empty()) { + GELOGI("realPath is empty"); + return; + } + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path.size() >= MMPA_MAX_PATH, return, "path is invalid"); + + char resolved_path[MMPA_MAX_PATH] = {0}; + + // Nullptr is returned when the path does not exist or there is no permission + // Return absolute path when path is accessible + INT32 result = mmRealPath(path.c_str(), resolved_path, MMPA_MAX_PATH); + if (result != EN_OK) { + GELOGW("the path [%s] not exsit.", path.c_str()); + return; + } + + INT32 is_dir = mmIsDir(resolved_path); + // Lib plugin path not exist + if (is_dir != EN_OK) { + GELOGW("Open directory %s failed,maybe it is not exit or not a dir", resolved_path); + return; + } + + mmDirent **entries = nullptr; + auto ret = mmScandir(resolved_path, &entries, nullptr, nullptr); + if (ret < EN_OK) { + GELOGW("scan dir failed. path = %s, ret = %d", resolved_path, ret); + return; + } + for (int i = 0; i < ret; ++i) { + mmDirent *dir_ent = entries[i]; + std::string name = std::string(dir_ent->d_name); + if (strcmp(name.c_str(), ".") == 0 || strcmp(name.c_str(), "..") == 0) { + continue; + } + std::string full_name = path + "/" + name; + const std::string so_suff = ".so"; + + if (dir_ent->d_type != DT_DIR && name.size() >= so_suff.size() && + name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { + file_list.push_back(full_name); + GELOGI("OpsProtoManager Parse full name = %s \n", full_name.c_str()); + } + } + mmScandirFree(entries, ret); + GELOGI("Found %d libs.", ret); +} + +static void GetPluginSoFileList(const std::string &path, std::vector &file_list) { + // Support multi lib directory with ":" as delimiter + std::vector v_path = Split(path, ':'); + + for (size_t i = 0; i < v_path.size(); ++i) { + FindParserSo(v_path[i], file_list); + GELOGI("OpsProtoManager full name = %s", v_path[i].c_str()); + } +} + +void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) { + if (path.empty()) { + GELOGE(GRAPH_FAILED, "filePath is invalid. please check your text file %s.", path.c_str()); + return; + } + std::vector file_list; + + // If there is .so file in the lib path + GetPluginSoFileList(path, file_list); + + // Not found any .so file in the lib path + if (file_list.empty()) { + GELOGW("OpsProtoManager can not find any plugin file in pluginPath: %s \n", path.c_str()); + return; + } + // Warning message + GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted."); + + // Load .so file + for (auto elem : file_list) { + void *handle = mmDlopen(elem.c_str(), MMPA_RTLD_NOW | MMPA_RTLD_GLOBAL); + if (handle == nullptr) { + const char *error = mmDlerror(); + error = (error == nullptr) ? "" : error; + GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), error); + continue; + } else { + // Close dl when the program exist, not close here + GELOGI("OpsProtoManager plugin load %s success.", elem.c_str()); + handles_.push_back(handle); + } + } +} +} // namespace ge diff --git a/metadef/graph/option/ge_context.cc b/metadef/graph/option/ge_context.cc new file mode 100644 index 00000000..ee9f4a1a --- /dev/null +++ b/metadef/graph/option/ge_context.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "./ge_context.h" +#include "./ge_global_options.h" +#include "./ge_local_context.h" +#include "framework/common/ge_types.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +namespace { +const int64_t kMinTrainingTraceJobId = 256; +const int kDecimal = 10; +const char *kHostExecPlacement = "HOST"; +} +GEContext &GetContext() { + static GEContext ge_context{}; + return ge_context; +} + +thread_local uint64_t GEContext::session_id_; + +graphStatus GEContext::GetOption(const std::string &key, std::string &option) { + return GetThreadLocalContext().GetOption(key, option); +} + +bool GEContext::GetHostExecFlag() { + std::string exec_placement; + if (GetThreadLocalContext().GetOption(GE_OPTION_EXEC_PLACEMENT, exec_placement) != GRAPH_SUCCESS) { + GELOGW("get option OPTION_EXEC_PLACEMENT failed."); + return false; + } + GELOGD("Option ge.exec.placement is %s.", exec_placement.c_str()); + return exec_placement == kHostExecPlacement; +} + +std::map &GetMutableGlobalOptions() { + static std::map global_options{}; + return global_options; +} + +void GEContext::Init() { + string session_id; + (void)GetOption("ge.exec.sessionId", session_id); + try{ + session_id_ = static_cast(std::stoi(session_id.c_str())); + } catch (std::invalid_argument &) { + GELOGW("%s transform to int failed.", session_id.c_str()); + } catch (std::out_of_range &) { + GELOGW("%s transform to int failed.", session_id.c_str()); + } + + string device_id; + (void)GetOption("ge.exec.deviceId", device_id); + try{ + device_id_ = static_cast(std::stoi(device_id.c_str())); + } catch (std::invalid_argument &) { + GELOGW("%s transform to int failed.", device_id.c_str()); + } catch (std::out_of_range &) { + GELOGW("%s transform to int failed.", device_id.c_str()); + } + + string job_id; + (void)GetOption("ge.exec.jobId", job_id); + std::string s_job_id = ""; + for (auto c : job_id) { + if (c >= '0' && c <= '9') { + s_job_id += c; + } + } + if (s_job_id == "") { + trace_id_ = kMinTrainingTraceJobId; + return; + } + int64_t d_job_id = std::strtoll(s_job_id.c_str(), nullptr, kDecimal); + if (d_job_id < kMinTrainingTraceJobId) { + trace_id_ = d_job_id + kMinTrainingTraceJobId; + } else { + trace_id_ = d_job_id; + } +} + +uint64_t GEContext::SessionId() { return session_id_; } + +uint32_t GEContext::DeviceId() { return device_id_; } + +uint64_t GEContext::TraceId() { return trace_id_; } + +void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; } + +void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } + +} // namespace ge diff --git a/metadef/graph/option/ge_local_context.cc b/metadef/graph/option/ge_local_context.cc new file mode 100644 index 00000000..3a441eb6 --- /dev/null +++ b/metadef/graph/option/ge_local_context.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "./ge_local_context.h" +#include + +namespace ge { +namespace { +thread_local GEThreadLocalContext thread_context; +} + +GEThreadLocalContext &GetThreadLocalContext() { return thread_context; } + +graphStatus GEThreadLocalContext::GetOption(const string &key, string &option) { + auto graph_iter = graph_options_.find(key); + if (graph_iter != graph_options_.end()) { + option = graph_iter->second; + return GRAPH_SUCCESS; + } + auto session_iter = session_options_.find(key); + if (session_iter != session_options_.end()) { + option = session_iter->second; + return GRAPH_SUCCESS; + } + auto global_iter = global_options_.find(key); + if (global_iter != global_options_.end()) { + option = global_iter->second; + return GRAPH_SUCCESS; + } + return GRAPH_PARAM_INVALID; +} + +void GEThreadLocalContext::SetGlobalOption(map options_map) { + global_options_.clear(); + global_options_ = std::move(options_map); +} + +void GEThreadLocalContext::SetSessionOption(map options_map) { + session_options_.clear(); + session_options_ = std::move(options_map); +} + +void GEThreadLocalContext::SetGraphOption(map options_map) { + graph_options_.clear(); + graph_options_ = std::move(options_map); +} + +map GEThreadLocalContext::GetAllGraphOptions() const { + return graph_options_; +} + +map GEThreadLocalContext::GetAllSessionOptions() const { + return session_options_; +} + +map GEThreadLocalContext::GetAllGlobalOptions() const { + return global_options_; +} + +map GEThreadLocalContext::GetAllOptions() const { + map options_all; + options_all.insert(graph_options_.begin(), graph_options_.end()); + options_all.insert(session_options_.begin(), session_options_.end()); + options_all.insert(global_options_.begin(), global_options_.end()); + return options_all; +} +} // namespace ge diff --git a/metadef/graph/proto/dump_task.proto b/metadef/graph/proto/dump_task.proto new file mode 100644 index 00000000..b1e346cd --- /dev/null +++ b/metadef/graph/proto/dump_task.proto @@ -0,0 +1,111 @@ +syntax = "proto3"; +package toolkit.dumpdata; + +enum OutputDataType { + DT_UNDEFINED = 0; + DT_FLOAT = 1; + DT_FLOAT16 = 2; + DT_INT8 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_UINT16 = 6; + DT_INT32 = 7; + DT_INT64 = 8; + DT_UINT32 = 9; + DT_UINT64 = 10; + DT_BOOL = 11; + DT_DOUBLE = 12; + DT_STRING = 13; + DT_DUAL_SUB_INT8 = 14; + DT_DUAL_SUB_UINT8 = 15; + DT_COMPLEX64 = 16; + DT_COMPLEX128 = 17; + DT_QINT8 = 18; + DT_QINT16 = 19; + DT_QINT32 = 20; + DT_QUINT8 = 21; + DT_QUINT16 = 22; + DT_RESOURCE = 23; + DT_STRING_REF = 24; + DT_DUAL = 25; +} + +enum OutputFormat { + FORMAT_NCHW = 0; + FORMAT_NHWC = 1; + FORMAT_ND = 2; + FORMAT_NC1HWC0 = 3; + FORMAT_FRACTAL_Z = 4; + FORMAT_NC1C0HWPAD = 5; + FORMAT_NHWC1C0 = 6; + FORMAT_FSR_NCHW = 7; + FORMAT_FRACTAL_DECONV = 8; + FORMAT_C1HWNC0 = 9; + FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; + FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; + FORMAT_NC1HWC0_C04 = 12; + FORMAT_FRACTAL_Z_C04 = 13; + FORMAT_CHWN = 14; + FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; + FORMAT_HWCN = 16; + FORMAT_NC1KHKWHWC0 = 17; + FORMAT_BN_WEIGHT = 18; + FORMAT_FILTER_HWCK = 19; + FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; + FORMAT_HASHTABLE_LOOKUP_KEYS = 21; + FORMAT_HASHTABLE_LOOKUP_VALUE = 22; + FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; + FORMAT_HASHTABLE_LOOKUP_HITS=24; + FORMAT_C1HWNCoC0 = 25; + FORMAT_MD = 26; + FORMAT_NDHWC = 27; + FORMAT_FRACTAL_ZZ = 28; + FORMAT_FRACTAL_NZ = 29; + FORMAT_RESERVED = 30; +} + +message OriginalOp { + string name = 1; + uint32 output_index = 2; + OutputDataType data_type = 3; + OutputFormat format = 4; +} + +message Shape { + repeated uint64 dim = 1; +} + +message OpOutput { + OutputDataType data_type = 1; + OutputFormat format = 2; + Shape shape = 3; + OriginalOp original_op = 4; // the original op corresponding to the output + bytes data = 5; + uint64 size = 6; +} + +message OpInput { + OutputDataType data_type = 1; + OutputFormat format = 2; + Shape shape = 3; + bytes data = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + bytes data = 2; + uint64 size = 3; +} + +message DumpData{ + string version = 1; + uint64 dump_time = 2; + repeated OpOutput output = 3; + repeated OpInput input = 4; + repeated OpBuffer buffer = 5; +} diff --git a/metadef/graph/proto/fusion_model.proto b/metadef/graph/proto/fusion_model.proto new file mode 100644 index 00000000..c92c5581 --- /dev/null +++ b/metadef/graph/proto/fusion_model.proto @@ -0,0 +1,21 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +import "om.proto"; + +package domi; + +message FusionModelDef { + string version = 1; + repeated OpDef fusion_op = 2; +} \ No newline at end of file diff --git a/metadef/graph/proto/fwk_adapter.proto b/metadef/graph/proto/fwk_adapter.proto new file mode 100644 index 00000000..9335c926 --- /dev/null +++ b/metadef/graph/proto/fwk_adapter.proto @@ -0,0 +1,37 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package aicpu.FWKAdapter; +option cc_enable_arenas = true; + + +// Defines an struct for input and output. +message TensorDataInfo { + + // value DataType + uint32 dtype = 1; + + // shape dim + repeated int64 dim = 2; + + // data point addr + int64 data_addr = 3; +} + +message KernelRunParam { + // input + repeated TensorDataInfo input = 1; + // output + repeated TensorDataInfo output = 2; +} + diff --git a/metadef/graph/proto/ge_api.proto b/metadef/graph/proto/ge_api.proto new file mode 100644 index 00000000..331c5aea --- /dev/null +++ b/metadef/graph/proto/ge_api.proto @@ -0,0 +1,88 @@ +syntax = "proto3"; +package ge.api_pb; + +import "ge_ir.proto"; + +// GE initialize +message GEInitialize { + map options = 1; +}; + +// initialize response +message GEInitializeResponse { + uint32 status = 1; + uint32 clientId = 2; +}; + +// GE finalize +message GEFinalize { + bool final = 1; + uint32 clientId = 2; +}; + +message GEFinalizeResponse { + uint32 status = 1; +}; + +// GE Session +message CreateSession{ + map options = 1; +}; + +message CreateSessionResponse { + uint32 status = 1; + uint64 sessionId = 2; +}; + +//GE AddGraph +//model serialize :: serializegraph +message SessionAddGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; + ge.proto.GraphDef graph = 3; +}; + +message SessionAddGraphResponse { + uint32 status = 1; +}; + +//GE SessionRemoveGraph +message SessionRemoveGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; +}; + +message SessionRemoveGraphResponse { + uint32 status = 1; +}; + +message SessionRunGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; + repeated ge.proto.TensorDef tensor = 3; +}; + +message SessionBuildGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; + repeated ge.proto.TensorDef tensor = 3; + string savePath = 4; +}; + +message SessionRunGraphResponse { + uint32 status = 1; + repeated ge.proto.TensorDef tensor = 2; +}; + +message SessionBuildGraphResponse { + uint32 status = 1; +}; + +message DestroySession{ + bool final = 1; + uint64 sessionId = 2; +}; + +message DestroySessionResponse { + uint32 status = 1; +}; diff --git a/metadef/graph/proto/ge_ir.proto b/metadef/graph/proto/ge_ir.proto new file mode 100644 index 00000000..e7bfe0cb --- /dev/null +++ b/metadef/graph/proto/ge_ir.proto @@ -0,0 +1,190 @@ +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. + DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs +{ + string name = 1; + map attr = 2; +} + +// Shape / dimension description, using row-major order +message ShapeDef +{ + repeated int64 dim = 1; // Size of each dimension +} + +// Multidimensional data description +message TensorDescriptor +{ + string name = 1; // Optional parameter, tensor name + + DataType dtype = 2; // tensor datatype + ShapeDef shape = 3; // Shape / dimension + string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" + + bool has_out_attr = 9; + int64 size = 10; + int64 weight_size = 11; + bool reuse_input = 12; + bool output_tensor = 13; + string device_type = 14; + bool input_tensor =15; + int64 real_dim_cnt = 16; + int64 reuse_input_index = 17; + int64 data_offset = 18; + int64 cmps_size = 19; + string cmps_tab = 20; + int64 cmps_tab_offset = 21; + + map attr = 5; // Set of extra parameter fields +} + +// GeTensor definition +message TensorDef +{ + TensorDescriptor desc = 1; // Tensor description + bytes data = 2; // Tensor data +} + + +// Operator description +message OpDef +{ + string name = 1; // name + string type = 2; // type + + repeated string input = 5; // input original op name + outgoing index. op_name:index + + map attr = 10; // Set of operator parameter fields + + bool has_out_attr = 20; + int64 id = 21; + int64 stream_id =22; + repeated string input_name = 23; + repeated string src_name = 24; + repeated int64 src_index = 25; + repeated string dst_name = 26; + repeated int64 dst_index = 27; + repeated int64 input_i = 28; + repeated int64 output_i = 29; + repeated int64 workspace = 30; + repeated int64 workspace_bytes = 31; + repeated bool is_input_const = 32; + repeated TensorDescriptor input_desc = 33; + repeated TensorDescriptor output_desc = 34; + repeated string subgraph_name = 35; +} + +// Graph definition +message GraphDef +{ + string name = 1; // name + + repeated string input = 4; // Graph input + repeated string output = 5; // Graph output + + repeated OpDef op = 6; // List of operators + + map attr = 11; // Extended field +} + +// model definition +message ModelDef +{ + string name = 1; // name + uint32 version = 2; // IR Proto verion + string custom_version = 3; // User model version number, passed in by user + + repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef + + map attr = 11; // Extended field +} + diff --git a/metadef/graph/proto/ge_onnx.proto b/metadef/graph/proto/ge_onnx.proto new file mode 100644 index 00000000..4cd77f3a --- /dev/null +++ b/metadef/graph/proto/ge_onnx.proto @@ -0,0 +1,563 @@ +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto3"; + +package ge.onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION = 0x0000000000000006; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value= 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} diff --git a/metadef/graph/proto/insert_op.proto b/metadef/graph/proto/insert_op.proto new file mode 100644 index 00000000..bf918b20 --- /dev/null +++ b/metadef/graph/proto/insert_op.proto @@ -0,0 +1,139 @@ +syntax = "proto3"; + +package domi; + +message InsertNewOps { + repeated AippOpParams aipp_op = 1; + repeated MultiShapeOpParams multi_shape_op = 2; +} + +message AippOpParams { + enum InputFormat { + UNDEFINED = 0; + YUV420SP_U8 = 1; + XRGB8888_U8 = 2; + RGB888_U8 = 3; + YUV400_U8 = 4; + NC1HWC0DI_FP16 = 5; + NC1HWC0DI_S8 = 6; + ARGB8888_U8 = 7; + YUYV_U8 = 8; + YUV422SP_U8 = 9; + AYUV444_U8 = 10; + RAW10 = 11; + RAW12 = 12; + RAW16 = 13; + RAW24 = 14; + RGB16 = 15; + RGB20 = 16; + RGB24 = 17; + RGB8_IR = 18; + RGB16_IR = 19; + RGB24_IR = 20; + } + + enum AippMode { + undefined = 0; + static = 1; + dynamic = 2; + } + + // AIPPģʽ־̬AIPPͶ̬AIPP + AippMode aipp_mode = 1; + + // related_input_rankΪΪͣ÷Χ>=0, <=DataӵĸĬֵΪ0 + // ʶģ͵ĵڼAIPPģ룬ҪԵ2AIPPrelated_input_rankΪ1 + uint32 related_input_rank = 2; + + // related_input_name is optional and the top name of data node which inserts aipp + string related_input_name = 6; + + // input_edge_idxΪѡΪͣ÷ΧΪ>=0 + // øòãڶDataӲͬͬAIPPòûãĬ϶related_input_rankָģAIPP + // ֵ <= Dataߵĸ + repeated uint32 input_edge_idx = 3; + + // [Begin] ̬AIPPþ̬AIPPʱЧ + uint32 max_src_image_size = 4; + + // Ƿ֧תĬϲ֧֣֧תʱжĿռʧ + bool support_rotation = 5; + + // [End] ̬AIPP + + + // [Begin] ̬AIPPö̬AIPPʱЧ + InputFormat input_format = 51; + bool csc_switch = 52; + float cpadding_value = 53; + bool rbuv_swap_switch = 54; + bool ax_swap_switch = 55; + bool single_line_mode = 56; + + int32 src_image_size_w = 57; + int32 src_image_size_h = 58; + + bool crop = 59; + int32 load_start_pos_w = 60; + int32 load_start_pos_h = 61; + int32 crop_size_w = 62; + int32 crop_size_h = 63; + + bool resize = 64; + int32 resize_output_w = 65; + int32 resize_output_h = 66; + + bool padding = 67; + int32 left_padding_size = 68; + int32 right_padding_size = 69; + int32 top_padding_size = 70; + int32 bottom_padding_size = 71; + + int32 mean_chn_0 = 10; + int32 mean_chn_1 = 11; + int32 mean_chn_2 = 12; + int32 mean_chn_3 = 19; + float min_chn_0 = 13; + float min_chn_1 = 14; + float min_chn_2 = 15; + float min_chn_3 = 20; + repeated float var_reci_chn_0 = 16; + repeated float var_reci_chn_1 = 17; + repeated float var_reci_chn_2 = 18; + repeated float var_reci_chn_3 = 21; + + repeated int32 matrix_r0c0 = 30; + repeated int32 matrix_r0c1 = 31; + repeated int32 matrix_r0c2 = 32; + repeated int32 matrix_r1c0 = 33; + repeated int32 matrix_r1c1 = 34; + repeated int32 matrix_r1c2 = 35; + repeated int32 matrix_r2c0 = 36; + repeated int32 matrix_r2c1 = 37; + repeated int32 matrix_r2c2 = 38; + repeated int32 output_bias_0 = 39; + repeated int32 output_bias_1 = 40; + repeated int32 output_bias_2 = 41; + repeated int32 input_bias_0 = 42; + repeated int32 input_bias_1 = 43; + repeated int32 input_bias_2 = 44; + + // [End] ̬AIPP + + // The n number that is used for raw/rgbir data into f16 transformation. + // The transformation equation is x/(2^n). If set to 0, no transform is performed. + uint32 raw_rgbir_to_f16_n = 45; +} + +message MultiShapeOpParams { + enum MultiShapeMode { + batch = 0; //̬batch + resolution = 1; //ֱ̬ʣչ + } + + MultiShapeMode mode = 1; //ģʽ + uint32 related_input_rank = 2; //Ӳ뵽ĸ + + + repeated uint32 batch_list = 11; //batch_listֵbatch_listĸ28֮ +} diff --git a/metadef/graph/proto/module.mk b/metadef/graph/proto/module.mk new file mode 100644 index 00000000..dabcde4e --- /dev/null +++ b/metadef/graph/proto/module.mk @@ -0,0 +1,3 @@ +LOCAL_PATH := $(call my-dir) + +include $(LOCAL_PATH)/proto_common.mk diff --git a/metadef/graph/proto/om.proto b/metadef/graph/proto/om.proto new file mode 100644 index 00000000..e15e5f80 --- /dev/null +++ b/metadef/graph/proto/om.proto @@ -0,0 +1,396 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. + CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/metadef/graph/proto/op_mapping_info.proto b/metadef/graph/proto/op_mapping_info.proto new file mode 100644 index 00000000..e23b7ebe --- /dev/null +++ b/metadef/graph/proto/op_mapping_info.proto @@ -0,0 +1,73 @@ +syntax = "proto3"; +package aicpu.dump; + +message Shape { + repeated uint64 dim = 1; +} + +message Output { + int32 data_type = 1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + string original_name = 5; + int32 original_output_index = 6; + int32 original_output_data_type = 7; + int32 original_output_format = 8; + uint64 size = 9; +} + +message Input { + int32 data_type =1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + uint64 address = 2; + uint64 size = 3; +} + +message Op { + string op_name = 1; + string op_type = 2; +} + +message Task { + uint32 task_id = 1; + uint32 stream_id = 2; + Op op = 3; + repeated Output output = 4; + bool end_graph = 5; + repeated Input input = 6; + repeated OpBuffer buffer = 7; +} + +message OpMappingInfo { + string dump_path = 1; + oneof model_name_param { + string model_name = 2; + } + oneof model_id_param { + uint32 model_id = 3; + } + oneof step_id { + uint64 step_id_addr = 4; + } + oneof iterations_per_loop { + uint64 iterations_per_loop_addr = 5; + } + oneof loop_cond { + uint64 loop_cond_addr = 6; + } + uint32 flag = 7; // 0x01 load, 0x00 unload + repeated Task task = 8; + string dump_step = 9; +} \ No newline at end of file diff --git a/metadef/graph/proto/proto_common.mk b/metadef/graph/proto/proto_common.mk new file mode 100644 index 00000000..25640da7 --- /dev/null +++ b/metadef/graph/proto/proto_common.mk @@ -0,0 +1,88 @@ +LOCAL_PATH := $(call my-dir) + +COMMON_LOCAL_SRC_FILES := \ + om.proto \ + ge_ir.proto\ + ge_onnx.proto\ + insert_op.proto \ + task.proto \ + fwk_adapter.proto \ + op_mapping_info.proto \ + +COMMON_LOCAL_C_INCLUDES := \ + inc \ + inc/external \ + inc/external/graph \ + inc/common \ + inc/graph \ + common \ + common/graph \ + third_party/protobuf/include \ + libc_sec/include \ + ops/built-in/op_proto/inc \ + cann/ops/built-in/op_proto/inc \ + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := libproto_common + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -Dgoogle=ascend_private + +LOCAL_CPPFLAGS += -fexceptions +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libascend_protobuf \ + libslog \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_STATIC_LIBRARY) + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := libproto_common + +LOCAL_CFLAGS += -O2 -Dgoogle=ascend_private + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libascend_protobuf \ + libslog \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_STATIC_LIBRARY) + +# compile for ut/st +include $(CLEAR_VARS) +LOCAL_MODULE := libproto_common + +LOCAL_CFLAGS += -Werror -Wno-unused-variable -Dgoogle=ascend_private +LOCAL_CFLAGS += -DDAVINCI_MINI + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libascend_protobuf \ + libslog \ + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_LLT_STATIC_LIBRARY) diff --git a/metadef/graph/proto/task.proto b/metadef/graph/proto/task.proto new file mode 100644 index 00000000..d0c09840 --- /dev/null +++ b/metadef/graph/proto/task.proto @@ -0,0 +1,165 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/metadef/graph/proto_inner/ge_onnx.proto b/metadef/graph/proto_inner/ge_onnx.proto new file mode 100644 index 00000000..4cd77f3a --- /dev/null +++ b/metadef/graph/proto_inner/ge_onnx.proto @@ -0,0 +1,563 @@ +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto3"; + +package ge.onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION = 0x0000000000000006; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value= 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} diff --git a/metadef/graph/ref_relation.cc b/metadef/graph/ref_relation.cc new file mode 100644 index 00000000..2f2f8a8c --- /dev/null +++ b/metadef/graph/ref_relation.cc @@ -0,0 +1,484 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/ref_relation.h" + +#include +#include + +#include "utils/mem_utils.h" +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "debug/ge_attr_define.h" +#include "graph/ge_error_codes.h" +#include "graph/utils/graph_utils.h" +#include "framework/common/debug/ge_log.h" + +using namespace std; +using namespace ge; +namespace ge { +namespace { + const char *kRefIndex = "_parent_node_index"; + const string kWhile = "While"; + const string kIf = "If"; + const string kCase = "Case"; + + const uint16_t kMaxElementNum = 100; + + std::unordered_set function_op = { + kWhile, + kIf, + kCase + }; +} + +/* Impl */ +class RefRelations::Impl { +public: + graphStatus LookUpRefRelations(const RefCell &key, unordered_set &result) { + unsigned long number = static_cast(reinterpret_cast(key.node.get())); + std::string lookup_key = key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + + std::to_string(number); + auto iter = look_up_table_.find(lookup_key); + if (iter != look_up_table_.end()) { + for (auto &c : iter->second) { + result.insert(c); + } + return GRAPH_SUCCESS; + } + GELOGW("can not find any relations! key value of dest relation is %s", lookup_key.c_str()); + return GRAPH_SUCCESS; + }; + graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); + graphStatus Clear() { + GELOGD("Start clear boundary reflections between main graph and sub graph!"); + look_up_table_.clear(); + values_.clear(); + return GRAPH_SUCCESS; + }; +private: + graphStatus BuildLookUpTables(); + graphStatus BuildRefRelationsForBranch( + const NodePtr &root_node, + const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs); + graphStatus BuildRefRelationsForWhile( + const NodePtr &root_node, + const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs); + graphStatus BuildRelationsWithFuncNodeType( + const NodePtr &root_node, + const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs); + void GetDataAndNetoutputOfSubGraph( + const ge::ComputeGraph &root_graph, + vector &data_nodes, + vector &netoutput_nodes, + const std::vector &sub_graph_names, + const std::string &node_type); + + graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph); + graphStatus ProcessSubgraphDataNodes( + vector &data_nodes, + vector> &classed_data_nodes); + graphStatus ProcessSubgraphNetoutput( + const vector &netoutput_nodes, + vector>> &classed_netoutput_nodes); + + std::unordered_map> look_up_table_; + std::vector>> values_; +}; + +// Node Level +graphStatus RefRelations::Impl::BuildRefRelationsForBranch( + const NodePtr &root_node, + const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs) { + GELOGD("Enter BuildRefRelationsForBranch!"); + + size_t ref_i = 0; + for (const auto &ref_i_data_nodes : classed_data_nodes) { + vector in_ref_i_all_refs; + RefCell cell_root; + cell_root.node_name = root_node->GetName(); + cell_root.node = root_node; + cell_root.in_out = NODE_IN; + cell_root.in_out_idx = ref_i; + in_ref_i_all_refs.emplace_back(cell_root); + for (const auto &data : ref_i_data_nodes) { + RefCell cell_in; + RefCell cell_out; + cell_in.node_name = data->GetName(); + cell_in.node = data; + cell_in.in_out = NODE_IN; + cell_in.in_out_idx = 0; + cell_out.node_name = data->GetName(); + cell_out.node = data; + cell_out.in_out = NODE_OUT; + cell_out.in_out_idx = 0; + in_ref_i_all_refs.emplace_back(cell_in); + in_ref_i_all_refs.emplace_back(cell_out); + } + node_refs.emplace_back(in_ref_i_all_refs); + ref_i++; + } + + size_t ref_o = 0; + for (const auto &ref_o_net_nodes : classed_netoutput_nodes) { + vector out_ref_i_all_refs; + RefCell cell_root; + cell_root.node_name = root_node->GetName(); + cell_root.node = root_node; + cell_root.in_out = NODE_OUT; + cell_root.in_out_idx = ref_o; + out_ref_i_all_refs.emplace_back(cell_root); + for (const auto &ele : ref_o_net_nodes) { + RefCell cell_netoutput_in; + cell_netoutput_in.node_name = (ele.first)->GetName(); + cell_netoutput_in.node = ele.first; + cell_netoutput_in.in_out = NODE_IN; + cell_netoutput_in.in_out_idx = ele.second; + out_ref_i_all_refs.emplace_back(cell_netoutput_in); + } + node_refs.emplace_back(out_ref_i_all_refs); + ref_o++; + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::BuildLookUpTables() { + GELOGD("start to build look up table!"); + for (size_t i = 0; i < values_.size(); i++) { + vector> &val = values_[i]; + for (const auto &ele : val) { + for (const auto &ref_cell : ele) { + string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + + std::to_string(ref_cell.in_out_idx) + + std::to_string(static_cast(reinterpret_cast(ref_cell.node.get()))); + look_up_table_[key] = ele; + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::BuildRefRelationsForWhile( + const NodePtr &root_node, + const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs) { + GELOGD("Enter BuildRefRelations for while op!"); + // data_nodes has been sorted + // for while, input num must be same as output num + auto input_num = root_node->GetAllInDataAnchorsSize(); + NodePtr netoutput = nullptr; + + size_t ref_i = 0; + while (ref_i < input_num) { + auto &ref_i_data_nodes = classed_data_nodes[ref_i]; + auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i]; + + vector ref_i_all_refs; + RefCell cell_root_i; + RefCell cell_root_o; + cell_root_i.node_name = root_node->GetName(); + cell_root_i.node = root_node; + cell_root_i.in_out = NODE_IN; + cell_root_i.in_out_idx = ref_i; + ref_i_all_refs.emplace_back(cell_root_i); + cell_root_o.node_name = root_node->GetName(); + cell_root_o.node = root_node; + cell_root_o.in_out = NODE_OUT; + cell_root_o.in_out_idx = ref_i; + ref_i_all_refs.emplace_back(cell_root_o); + for (const auto &data : ref_i_data_nodes) { + RefCell cell_in; + RefCell cell_out; + cell_in.node_name = data->GetName(); + cell_in.node = data; + cell_in.in_out = NODE_IN; + cell_in.in_out_idx = 0; + cell_out.node_name = data->GetName(); + cell_out.node = data; + cell_out.in_out = NODE_OUT; + cell_out.in_out_idx = 0; + ref_i_all_refs.emplace_back(cell_in); + ref_i_all_refs.emplace_back(cell_out); + } + + for (const auto &ele : ref_i_net_nodes) { + RefCell cell_netoutput_in; + RefCell cell_netoutput_out; + cell_netoutput_in.node_name = (ele.first)->GetName(); + cell_netoutput_in.node = ele.first; + cell_netoutput_in.in_out = NODE_IN; + cell_netoutput_in.in_out_idx = ele.second; + ref_i_all_refs.emplace_back(cell_netoutput_in); + netoutput = ele.first; + } + node_refs.emplace_back(ref_i_all_refs); + ref_i++; + } + /* There exist scene like the follows, it means data0 data1 netoutput 0'th + * and 1'th tensor should be the same addr. + * Data0 Data1 + * \/ + * /\ + * netoutput + */ + if (netoutput == nullptr) { + return GRAPH_SUCCESS; + } + for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) { + auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_data_anchor == nullptr) { + continue; + } + auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); + if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { + GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str()); + continue; + } + if (peer_out_data_node->GetType() != DATA) { + continue; + } + auto in_data_anchor_idx = in_anchor->GetIdx(); + auto net_in_desc = + netoutput->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); + int ref_d = 0; + int ref_n = 0; + (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d); + (void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n); + + node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end()); + node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end()); + } + + + return GRAPH_SUCCESS; +} +// build ref relations according to diff func op type +graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( + const NodePtr &root_node, + const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs) { + // data_nodes has been sorted + auto node_type = root_node->GetType(); + + auto status = GRAPH_SUCCESS; + if (node_type != kWhile) { + status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); + } else { + status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); + } + return status; +} + +void RefRelations::Impl::GetDataAndNetoutputOfSubGraph( + const ge::ComputeGraph &root_graph, + vector &data_nodes, + vector &netoutput_nodes, + const std::vector &sub_graph_names, + const std::string &node_type) { + int sub_graph_idx = 0; + for (const auto &name : sub_graph_names) { + auto sub_graph = root_graph.GetSubgraph(name); + if (sub_graph == nullptr) { + GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); + continue; + } + for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { + auto sub_graph_node_type = sub_graph_node->GetType(); + + if (sub_graph_node_type == DATA) { + data_nodes.emplace_back(sub_graph_node); + } else if (sub_graph_node_type == NETOUTPUT) { + // if while, the first subgraph must be cond subgraph. + // There is no meaning for refs ,so continue + if (node_type == kWhile && sub_graph_idx == 0) { + continue; + } + netoutput_nodes.emplace_back(sub_graph_node); + } + continue; + } + sub_graph_idx++; + } +} + +graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) { + auto parent_graph_ptr = graph.GetParentGraph(); + if (parent_graph_ptr == nullptr) { + root_graph = graph; + return GRAPH_SUCCESS; + } + auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr); + if (root_graph_ptr == nullptr) { + GE_LOGE("Get null root graph"); + return GRAPH_PARAM_INVALID; + } + root_graph = *root_graph_ptr; + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::ProcessSubgraphDataNodes( + vector &data_nodes, + vector> &classed_data_nodes) { + GELOGD("start to process subgraph data nodes!"); + int max_ref_idx = 0; + for (const auto &e : data_nodes) { + int i; + bool is_exist = true; + is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i); + if (!is_exist) { + GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", + e->GetName().c_str(), kRefIndex); + return GRAPH_FAILED; + } + max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx; + } + + while (!data_nodes.empty()) { + auto data = data_nodes.back(); + data_nodes.pop_back(); + int ref_idx = 0; + (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); + if (ref_idx >= static_cast(classed_data_nodes.size())) { + return GRAPH_FAILED; + } + classed_data_nodes[ref_idx].emplace_back(data); + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( + const vector &netoutput_nodes, + vector>> &classed_netoutput_nodes) { + GELOGD("[RefRelations]Start to process subgraph netoutput!"); + for (const auto &sub_netoutput_node : netoutput_nodes) { + auto op_desc = sub_netoutput_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { + auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx()); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it", + sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); + return GRAPH_FAILED; + } + int ref_o; + if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { + if (ref_o >= static_cast(classed_netoutput_nodes.size())) { + return GRAPH_FAILED; + } + classed_netoutput_nodes[ref_o].emplace_back(std::pair( + {sub_netoutput_node, static_cast(in_data_anchor->GetIdx())} + )); + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { + GELOGD("Start to build ref relations!"); + /* First Step: Get root graph */ + ge::ComputeGraph &root_graph = graph; + auto status = GetRootGraph(graph, root_graph); + if (status != GRAPH_SUCCESS) { + return status; + } + + for (const auto &node : graph.GetAllNodes()) { + auto node_type = node->GetType(); + std::vector ref_nodes; + auto op_desc = node->GetOpDesc(); + auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + continue; + } + vector data_nodes; + vector netoutput_nodes; + // Get data and netoutput of sub_graph + GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); + size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? data_nodes.size() : kMaxElementNum; + vector> classed_data_nodes(max_elem_num); // according to ref_idx + vector>> classed_netoutput_nodes(max_elem_num); // according to ref_idx + status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); + return status; + } + + // for netoutput + // check netoutput + // here main graph output number must be the same as every sub_graph netoutput node + // key: netoutput node_ptr , + status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "process netoutput failed!"); + return status; + } + + vector> node_refs; + status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs); + if (status != GRAPH_SUCCESS) { + GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str()); + return status; + } + if (!node_refs.empty()) { + values_.push_back(node_refs); + } + } + /* Seconde Step: generate map */ + status = BuildLookUpTables(); + if (status != GRAPH_SUCCESS) { + GELOGE(status, "Build look up tables failed!"); + return status; + } + return GRAPH_SUCCESS; +} + +/* Ref Relations Interface */ +RefRelations::RefRelations() { + impl_ = MakeShared(); + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "MakeShared failed!"); + return; + } +} + +graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set &result) { + GE_CHECK_NOTNULL(impl_); + return impl_->LookUpRefRelations(key, result); +} + +graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) { + GE_CHECK_NOTNULL(impl_); + return impl_->BuildRefRelations(root_graph); +} + +graphStatus RefRelations::Clear() { + GE_CHECK_NOTNULL(impl_); + return impl_->Clear(); +} +} \ No newline at end of file diff --git a/metadef/graph/runtime_inference_context.cc b/metadef/graph/runtime_inference_context.cc new file mode 100644 index 00000000..4438eaf9 --- /dev/null +++ b/metadef/graph/runtime_inference_context.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/runtime_inference_context.h" +#include "graph/utils/tensor_adapter.h" +#include +#include "framework/common/debug/ge_log.h" + +namespace ge { +std::map> RuntimeInferenceContext::contexts_; +std::mutex RuntimeInferenceContext::ctx_mu_; + +graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id) { + GELOGI("To create context. session id = %s", context_id.c_str()); + auto ctx = std::unique_ptr(new (std::nothrow)RuntimeInferenceContext()); + if (ctx == nullptr) { + GELOGE(GRAPH_FAILED, + "Failed to create instance of RuntimeInferenceContext. context_id = %s", + context_id.c_str()); + return GRAPH_FAILED; + } + + std::lock_guard lk(ctx_mu_); + auto emplace_ret = contexts_.emplace(context_id, std::move(ctx)); + if (!emplace_ret.second) { + GELOGE(GRAPH_FAILED, "Old context not destroyed"); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +void RuntimeInferenceContext::DestroyContext(const std::string &context_id) { + GELOGI("To destroy context. session id = %s", context_id.c_str()); + std::lock_guard lk(ctx_mu_); + contexts_.erase(context_id); +} + +graphStatus RuntimeInferenceContext::GetContext(const std::string &context_id, RuntimeInferenceContext **ctx) { + std::lock_guard lk(ctx_mu_); + auto it = contexts_.find(context_id); + if (it != contexts_.end()) { + *ctx = it->second.get(); + return GRAPH_SUCCESS; + } + + GELOGD("Runtime inference context not created. session id = %s", context_id.c_str()); + return GRAPH_FAILED; +} + + +graphStatus RuntimeInferenceContext::SetTensor(int64_t node_id, int output_id, Tensor &&tensor) { + std::lock_guard lk(mu_); + auto &output_tensors = tensors_[node_id]; + if (static_cast(output_id) >= output_tensors.size()) { + output_tensors.resize(output_id + 1); + } + + GELOGD("Set tensor for node_id = %ld, output_id = %d", node_id, output_id); + output_tensors[output_id] = std::move(tensor); + + auto &output_ge_tensors = ge_tensors_[node_id]; + if (static_cast(output_id) >= output_ge_tensors.size()) { + output_ge_tensors.resize(output_id + 1); + } + + GELOGD("Set ge tensor for node_id = %ld, output_id = %d", node_id, output_id); + output_ge_tensors[output_id] = TensorAdapter::AsGeTensorPtr(tensor); + return GRAPH_SUCCESS; +} + +graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, Tensor &tensor) { + if (output_id < 0) { + GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id); + return GRAPH_PARAM_INVALID; + } + + std::lock_guard lk(mu_); + auto iter = tensors_.find(node_id); + if (iter == tensors_.end()) { + GELOGE(INTERNAL_ERROR, "Node not register. Id = %ld", node_id); + return INTERNAL_ERROR; + } + + auto &output_tensors = iter->second; + if (static_cast(output_id) >= output_tensors.size()) { + GELOGE(GRAPH_FAILED, "Node output is not registered. node_id = %ld, output index = %d", node_id, output_id); + return GRAPH_FAILED; + } + + GELOGD("Get tensor for node_id = %ld, output_id = %d", node_id, output_id); + tensor = output_tensors[output_id]; + return GRAPH_SUCCESS; +} + +graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, GeTensorPtr &tensor) { + if (output_id < 0) { + GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id); + return GRAPH_PARAM_INVALID; + } + + std::lock_guard lk(mu_); + auto iter = ge_tensors_.find(node_id); + if (iter == ge_tensors_.end()) { + GELOGE(INTERNAL_ERROR, "Node not register. Id = %ld", node_id); + return INTERNAL_ERROR; + } + + auto &output_tensors = iter->second; + if (static_cast(output_id) >= output_tensors.size()) { + GELOGE(GRAPH_FAILED, "Node output is not registered. node_id = %ld, output index = %d", node_id, output_id); + return GRAPH_FAILED; + } + + GELOGD("Get ge tensor for node_id = %ld, output_id = %d", node_id, output_id); + tensor = output_tensors[output_id]; + return GRAPH_SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/metadef/graph/shape_refiner.cc b/metadef/graph/shape_refiner.cc new file mode 100644 index 00000000..99eda1a5 --- /dev/null +++ b/metadef/graph/shape_refiner.cc @@ -0,0 +1,788 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/shape_refiner.h" + +#include +#include +#include +#include +#include +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" + +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "external/graph/operator.h" +#include "external/graph/operator_factory.h" +#include "framework/common/debug/ge_log.h" +#include "graph/compute_graph.h" +#include "graph/operator_factory_impl.h" +#include "utils/node_utils.h" +#include "utils/op_desc_utils.h" +#include "utils/tensor_utils.h" +#include "utils/type_utils.h" + +namespace ge { +namespace { +const uint32_t kWhileBodySubGraphIdx = 1; + +graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { + GELOGD("Enter reverse brush while body subgraph process!"); + auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); + if (sub_graph_body == nullptr) { + GELOGE(GRAPH_FAILED, "Get while body graph failed!"); + return GRAPH_FAILED; + } + + for (const auto &node_sub : sub_graph_body->GetAllNodes()) { + // const/constant/variale etc. no need to reverse brush + if (node_sub->GetInDataNodes().empty() && node_sub->GetType() != DATA) { + continue; + } + for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) { + auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i); + GE_IF_BOOL_EXEC(input_desc == nullptr, + GELOGW("Get null input by index %zu from node %s ", + i, node_sub->GetName().c_str()); + continue); + (void)input_desc->SetUnknownDimNumShape(); + } + for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) { + auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i); + (void)output_desc->SetUnknownDimNumShape(); + } + } + + for (const auto &node_sub : sub_graph_body->GetAllNodes()) { + if (!node_sub->GetInDataNodes().empty() || node_sub->GetType() == DATA) { + continue; + } + for (const auto &out_data_anchor : node_sub->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(out_data_anchor); + auto out_data_anchor_idx = out_data_anchor->GetIdx(); + for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHECK_NOTNULL(peer_in_data_anchor); + auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_in_data_node); + GE_CHECK_NOTNULL(peer_in_data_node->GetOpDesc()); + int idx = peer_in_data_anchor->GetIdx(); + auto shape = node_sub->GetOpDesc()->MutableOutputDesc(out_data_anchor_idx)->GetShape(); + peer_in_data_node->GetOpDesc()->MutableInputDesc(idx)->SetShape(shape); + } + } + } + + return GRAPH_SUCCESS; +} + +graphStatus UpdataOutputForMultiBatcch(const ConstNodePtr &node, + std::vector> &ref_out_tensors) { + // check sub_graph shape. Get max for update. + for (size_t i = 0; i < ref_out_tensors.size(); ++i) { + if (ref_out_tensors[i].empty()) { + continue; + } + + int64_t max_size = 0; + size_t max_shape_index = 0; + auto &ref_out_tensor = ref_out_tensors[i].at(0); + for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) { + auto &tensor = ref_out_tensors[i].at(j); + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); + return GRAPH_FAILED; + } + + auto shape = tensor.MutableShape(); + int64_t size = 1; + for (auto dim : shape.GetDims()) { + if (INT64_MAX / dim < size) { + GELOGE(PARAM_INVALID, "The shape size overflow"); + return PARAM_INVALID; + } + size *= dim; + } + + if (size > max_size) { + max_size = size; + max_shape_index = j; + } + } + + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index)); + } + + return GRAPH_SUCCESS; +} + +graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, + std::vector> &ref_out_tensors) { + GELOGD("Enter update parent node shape for class branch op process"); + if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { + return UpdataOutputForMultiBatcch(node, ref_out_tensors); + } + + // check sub_graph shape.If not same ,do unknown shape process + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].empty()) { + continue; + } + auto ref_out_tensor = ref_out_tensors[i].at(0); + ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); + for (auto &tensor : ref_out_tensors[i]) { + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); + return GRAPH_FAILED; + } + auto shape = tensor.MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", + node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + ref_out_tensor_shape = GeShape(UNKNOWN_RANK); + break; + } + for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { + if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { + continue; + } + GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", + node->GetName().c_str(), i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); + } + } + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); + } + return GRAPH_SUCCESS; +} + +graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, + std::vector> &ref_data_tensors, + std::vector> &ref_out_tensors) { + GELOGD("Enter update parent node shape for class while op process"); + if (ref_data_tensors.size() != ref_out_tensors.size()) { + GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", + node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size()); + return GRAPH_FAILED; + } + for (size_t i = 0; i < ref_data_tensors.size(); i++) { + if (ref_out_tensors[i].size() != 1) { + GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!"); + return GRAPH_FAILED; + } + } + bool is_need_reverse_brush = false; + // check input and output + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].empty()) { + continue; + } + auto ref_out_tensor = ref_out_tensors[i].at(0); + auto tmp_shape = ref_out_tensor.MutableShape(); + // ref_i's data and output tensor shape should be same + for (auto &tensor : ref_data_tensors[i]) { + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str()); + return GRAPH_FAILED; + } + auto shape = tensor.MutableShape(); + if (shape.GetDims() != tmp_shape.GetDims()) { + ref_out_tensor.SetUnknownDimNumShape(); + is_need_reverse_brush = true; + break; + } + } + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); + } + // reverse refresh while body shape + if (is_need_reverse_brush) { + return ReverseBrushWhileBodySubGraph(node); + } + return GRAPH_SUCCESS; +} + +graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { + auto op_desc = node->GetOpDesc(); + auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + return GRAPH_SUCCESS; + } + + auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + for (const auto &name : sub_graph_names) { + if (name.empty()) { + GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); + continue; + } + auto sub_graph = root_graph->GetSubgraph(name); + if (sub_graph == nullptr) { + GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + for (const auto &node_sub : sub_graph->GetDirectNode()) { + if (node_sub->GetType() != DATA) { + continue; + } + int ref_i; + auto data_opdesc = node_sub->GetOpDesc(); + if (data_opdesc == nullptr) { + GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", + name.c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", + name.c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { + continue; + } + auto input_desc = op_desc->MutableInputDesc(ref_i); + if (input_desc == nullptr) { + GE_LOGE("The ref index(%d) on the data %s on the sub graph %s " + "parent node %s are incompatible, inputs num %u", + ref_i, node_sub->GetName().c_str(), name.c_str(), + node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); + return GRAPH_FAILED; + } + GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), + node->GetName().c_str()); + auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); + + if (ret != GRAPH_SUCCESS) { + GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", + node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); + return ret; + } + ret = data_opdesc->UpdateOutputDesc(0, *input_desc); + if (ret != GRAPH_SUCCESS) { + GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s", + node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); + return ret; + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr &sub_graph, + NodePtr &netoutput, const ConstNodePtr &node, + std::vector> &ref_data_tensors) { + auto sub_nodes = sub_graph->GetDirectNode(); + for (size_t i = sub_nodes.size(); i > 0; --i) { + auto sub_node = sub_nodes.at(i - 1); + if (sub_node->GetType() == NETOUTPUT) { + netoutput = sub_node; + } + if (sub_node->GetType() == DATA) { + if (sub_node->GetOpDesc() == nullptr) { + return GRAPH_FAILED; + } + + int ref_i; + if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); + return GRAPH_FAILED; + } + if (ref_i < 0 || static_cast(ref_i) >= node->GetAllInDataAnchorsSize()) { + GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", + sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize()); + return GRAPH_FAILED; + } + ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); + } + } + return GRAPH_SUCCESS; +} + +graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { + auto op_desc = node->GetOpDesc(); + auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + return GRAPH_SUCCESS; + } + + std::vector> ref_data_tensors(node->GetAllInDataAnchorsSize()); + std::vector> ref_out_tensors(node->GetAllOutDataAnchorsSize()); + auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + + for (const auto &name : sub_graph_names) { + if (name.empty()) { + GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); + continue; + } + auto sub_graph = root_graph->GetSubgraph(name); + if (sub_graph == nullptr) { + GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + NodePtr netoutput = nullptr; + auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); + if (ret != GRAPH_SUCCESS) { + return ret; + } + if (netoutput == nullptr) { + GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + auto netoutput_opdesc = netoutput->GetOpDesc(); + if (netoutput_opdesc == nullptr) { + GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", + name.c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { + auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); + if (edge_desc == nullptr) { + GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", + name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx()); + return GRAPH_FAILED; + } + GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", + edge_anchor->GetIdx(), edge_desc->GetShape().GetDimNum()); + int ref_i; + if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. + continue; + } + GELOGI("Parent node index of edge desc is %d", ref_i); + if (ref_i < 0 || static_cast(ref_i) >= node->GetAllOutDataAnchorsSize()) { + return GRAPH_FAILED; + } + ref_out_tensors[ref_i].emplace_back(*edge_desc); + } + } + + if (node->GetType() == WHILE) { + return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); + } + return UpdateParentNodeForBranch(node, ref_out_tensors); +} + +string Serial(const vector &dims) { + string serial_string; + serial_string += "["; + for (int64_t dim : dims) { + serial_string += std::to_string(dim) + " "; + } + serial_string += "]"; + return serial_string; +} + +void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { + desc_str += "["; + std::vector> shape_range; + (void)desc->GetShapeRange(shape_range); + for (const auto &pair : shape_range) { + desc_str += "{"; + desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); + desc_str += "},"; + } + desc_str += "] "; +} + +void SerialShapeAndDtype(const GeTensorDescPtr &desc, bool is_origin_info, std::string &desc_str) { + desc_str += "["; + if (!is_origin_info) { + for (int64_t dim : desc->GetShape().GetDims()) { + desc_str += std::to_string(dim) + " "; + } + desc_str += "]"; + desc_str += ":" + TypeUtils::DataTypeToSerialString(desc->GetDataType()) + ":" + + TypeUtils::FormatToSerialString(desc->GetFormat()) + " "; + } else { + for (int64_t dim : desc->GetOriginShape().GetDims()) { + desc_str += std::to_string(dim) + " "; + } + desc_str += "]"; + desc_str += ":" + TypeUtils::DataTypeToSerialString(desc->GetOriginDataType()) + ":" + + TypeUtils::FormatToSerialString(desc->GetOriginFormat()) + " "; + } +} + +graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { + GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); + GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); + for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) { + auto in_idx = in_anchor->GetIdx(); + auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_data_anchor == nullptr) { + continue; + } + auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); + if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { + continue; + } + int peer_out_idx = peer_out_data_anchor->GetIdx(); + auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast(peer_out_idx)); + + // check shape and dtype continuity. do not stop process + auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast(in_idx)); + if (in_desc == nullptr) { + continue; + } + auto in_shape = in_desc->MutableShape().GetDims(); + auto in_dtype = in_desc->GetDataType(); + auto peer_out_shape = peer_out_desc->MutableShape().GetDims(); + auto peer_out_dtype = peer_out_desc->GetDataType(); + if (peer_out_dtype != in_dtype) { + GELOGW("current node [%s] [%d]\'th out_dtype is [%s].peer output node [%s] [%d]\'th " + "output_dtype is [%s].The two dtype should be same! Please check graph and fix it", + node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(), + peer_out_data_node->GetName().c_str(), peer_out_idx, + TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str()); + } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) { + string in_shape_str = Serial(in_shape); + string peer_out_shape_str = Serial(peer_out_shape); + GELOGW("current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " + "input_shape is [%s].The two shape should be same! Please check graph and fix it", + node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), + peer_out_data_node->GetName().c_str(), peer_out_idx, peer_out_shape_str.c_str()); + } + // refresh current node input desc + in_desc->SetOriginShape(peer_out_desc->GetOriginShape()); + in_desc->SetShape(peer_out_desc->MutableShape()); + in_desc->SetDataType(peer_out_desc->GetDataType()); + in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType()); + if (peer_out_desc->MutableShape().GetDims() != UNKNOWN_RANK) { + std::vector> shape_range; + (void) peer_out_desc->GetShapeRange(shape_range); + in_desc->SetShapeRange(shape_range); + } + ge::TensorUtils::SetRealDimCnt(*in_desc, + static_cast(peer_out_desc->MutableShape().GetDims().size())); + } + return GRAPH_SUCCESS; +} +} // namespace +void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { + if (!IsLogEnable(GE, DLOG_DEBUG)) { + return; + } + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "node is null"); + return; + } + ge::OpDescPtr op_desc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return); + std::string str; + if (op_desc->GetInputsSize() != 0) { + std::string input_desc_str = "input shape: "; + for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { + SerialShapeAndDtype(input_desc, false, input_desc_str); + } + str += input_desc_str; + + input_desc_str = "input origin shape: "; + for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { + SerialShapeAndDtype(input_desc, true, input_desc_str); + } + str += input_desc_str; + + input_desc_str = "input shape range: "; + for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { + SerialShapeRange(input_desc, input_desc_str); + } + str += input_desc_str; + } + + if (op_desc->GetAllOutputsDescSize() != 0) { + std::string output_desc_str = "output shape: "; + for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { + if (output_desc == nullptr) { + continue; + } + SerialShapeAndDtype(output_desc, false, output_desc_str); + } + str += output_desc_str; + + output_desc_str = "output origin shape: "; + for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { + if (output_desc == nullptr) { + continue; + } + SerialShapeAndDtype(output_desc, true, output_desc_str); + } + str += output_desc_str; + + output_desc_str = "output shape range: "; + for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { + SerialShapeRange(output_desc, output_desc_str); + } + str += output_desc_str; + } + GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str()); +} + +InferenceContextPtr CreateInferenceContext(const std::unordered_map &context_map, + const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "node is null"); + return nullptr; + } + InferenceContextPtr inference_context = std::shared_ptr(InferenceContext::Create()); + if (inference_context == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext"); + return nullptr; + } + + auto all_in_data_anchors = node->GetAllInDataAnchors(); + std::vector> input_shapes_and_types(all_in_data_anchors.size()); + std::vector marks; + + bool has_input_shapes_and_types = false; + for (const auto &in_anchor : all_in_data_anchors) { + const auto &out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + continue; + } + + auto input_node = out_anchor->GetOwnerNode(); + if (input_node == nullptr) { + continue; + } + + auto iter = context_map.find(input_node); + if (iter != context_map.end()) { + const auto &src_context = iter->second; + GE_IF_BOOL_EXEC(src_context == nullptr, GELOGE(GRAPH_FAILED, "src_context is null."); return nullptr); + GELOGD("node:%s get %ld marks from node:%s", + node->GetName().c_str(), src_context->GetMarks().size(), input_node->GetName().c_str()); + for (auto mark : src_context->GetMarks()) { + marks.push_back(mark); + } + auto output_idx = out_anchor->GetIdx(); + auto input_idx = in_anchor->GetIdx(); + auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes(); + if (output_idx < static_cast(output_shape_and_type.size())) { + GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx, + node->GetName().c_str(), input_idx); + input_shapes_and_types[input_idx] = output_shape_and_type[output_idx]; + has_input_shapes_and_types = true; + } else { + GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx, + output_shape_and_type.size()); + } + } + } + + if (has_input_shapes_and_types) { + inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types)); + } + inference_context->SetMarks(marks); + + return inference_context; +} + +namespace { +thread_local std::unordered_map context_map; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +void ShapeRefiner::ClearContextMap() { + context_map.clear(); +} + +graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { + return InferShapeAndType(node, op, true); +} + +graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) { + auto op_desc = node->GetOpDesc(); + const auto &op_type = op_desc->GetType(); + + graphStatus ret; + if (before_subgraph) { + ret = UpdateSubGraphDataNodes(node); + if (ret != GRAPH_SUCCESS) { + return ret; + } + } + // Get infer func and execute + ret = op_desc->CallInferFunc(op); + if (ret == GRAPH_PARAM_INVALID) { + // Op ir no infer func, try to get infer func from operator factory + auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); + if (node_op.IsEmpty()) { + GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); + return ret; + } + + GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); + auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); + node_op.BreakConnect(); + if (temp_op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "temp op desc is null"); + return GRAPH_FAILED; + } + if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { + GELOGW("InferShapeAndType UpdateInputName failed"); + for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { + if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { + break; + } + return GRAPH_SUCCESS; + } + } + if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { + GELOGW("InferShapeAndType UpdateOutputName failed"); + } + op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); + ret = op_desc->CallInferFunc(op); + GELOGI("op CallInferFunc second. ret: %u", ret); + } + if (ret != GRAPH_SUCCESS) { + return ret; + } + + if (!before_subgraph) { + return UpdateParentNodeOutTensor(node); + } + return GRAPH_SUCCESS; +} + +graphStatus ShapeRefiner::InferShapeAndTypeForRunning(const ConstNodePtr &node, Operator &op, bool before_subgraph) { + auto op_desc = node->GetOpDesc(); + const auto &op_type = op_desc->GetType(); + + graphStatus ret; + if (before_subgraph) { + ret = UpdateSubGraphDataNodes(node); + if (ret != GRAPH_SUCCESS) { + return ret; + } + } + // Get infer func and execute + ret = op_desc->CallInferFunc(op); + if (ret == GRAPH_PARAM_INVALID) { + GELOGD("NodeUtils::GetNodeType return value is: [%s]", NodeUtils::GetNodeType(*node).c_str()); + auto origin_type = NodeUtils::GetNodeType(*node); + auto infer_func = ge::OperatorFactoryImpl::GetInferShapeFunc(origin_type); + if (infer_func == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to Get InferFunc.type is %s", origin_type.c_str()); + return GRAPH_FAILED; + } + op_desc->AddInferFunc(infer_func); + ret = op_desc->CallInferFunc(op); + GELOGI("op CallInferFunc second. ret: %u", ret); + } + if (ret != GRAPH_SUCCESS) { + return ret; + } + + if (!before_subgraph) { + return UpdateParentNodeOutTensor(node); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { + return InferShapeAndType(node, true); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus ShapeRefiner::InferShapeAndTypeForRunning(const NodePtr &node, bool before_subgraph) { + GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); + auto opdesc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); + + PrintInOutTensorShape(node, "before_infershape when running"); + Operator op = OpDescUtils::CreateOperatorFromNode(node); + + graphStatus status = InferShapeAndTypeForRunning(node, op, before_subgraph); + if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { + PrintInOutTensorShape(node, "after_infershape when running"); + return GRAPH_SUCCESS; + } else { + GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, bool before_subgraph) { + GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); + bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + auto opdesc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); + // some op can not infershape twice such as aipp + bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified"); + if (need_update_input) { + auto status = UpdateOpInputDesc(node); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "update op input_desc failed!"); + return status; + } + } + + if (node->Verify() != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + PrintInOutTensorShape(node, "before_infershape"); + Operator op = OpDescUtils::CreateOperatorFromNode(node); + + if (!is_unknown_graph) { + auto inference_context = CreateInferenceContext(context_map, node); + GE_CHECK_NOTNULL(inference_context); + GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); + op.SetInferenceContext(inference_context); + } + + graphStatus status = InferShapeAndType(node, op, before_subgraph); + if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { + if (is_unknown_graph) { + PrintInOutTensorShape(node, "after_infershape when running"); + return GRAPH_SUCCESS; + } + auto op_desc = node->GetOpDesc(); + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + if (output_tensor->MutableShape().GetDims().empty()) { + output_tensor->SetOriginShape(output_tensor->GetShape()); + } + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetOriginShape().GetDims() + .size())); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + + GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", + node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + } + } else { + GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + if (!is_unknown_graph) { + auto ctx_after_infer = op.GetInferenceContext(); + if (ctx_after_infer != nullptr) { + GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); + if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { + GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), + ctx_after_infer->GetMarks().size()); + (void)context_map.emplace(node, ctx_after_infer); + } + } + } + PrintInOutTensorShape(node, "after_infershape"); + + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/metadef/graph/stub/Makefile b/metadef/graph/stub/Makefile new file mode 100644 index 00000000..f339fa33 --- /dev/null +++ b/metadef/graph/stub/Makefile @@ -0,0 +1,6 @@ +inc_path := $(shell pwd)/metadef/inc/external/ +out_path := $(shell pwd)/out/graph/lib64/stub/ +stub_path := $(shell pwd)/metadef/graph/stub/ + +mkdir_stub := $(shell mkdir -p $(out_path)) +graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) diff --git a/metadef/graph/stub/gen_stubapi.py b/metadef/graph/stub/gen_stubapi.py new file mode 100644 index 00000000..05b45a23 --- /dev/null +++ b/metadef/graph/stub/gen_stubapi.py @@ -0,0 +1,585 @@ +#!/usr/bin/python3.7 +# -*- coding: UTF-8 -*- +#------------------------------------------------------------------- +# Purpose: +# Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved. +#------------------------------------------------------------------- + +import os +import re +import sys +import logging + +logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', + level=logging.INFO) + +""" + this attr is used for symbol table visible +""" +GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' + +""" + generate stub func body by return type +""" +RETURN_STATEMENTS = { + 'graphStatus': ' std::cout << "[ERROR]: stub library libgraph or libge_compiler cannot be used for execution, please check your "\n ' + ' << "environment variables and compilation options to make sure you use the correct library."\n' + ' << std::endl;\n' + ' return ACL_ERROR_COMPILING_STUB_MODE;', + 'Status': ' return SUCCESS;', + 'Graph': ' return Graph();', + 'Graph&': ' return *this;', + 'Format': ' return Format();', + 'Format&': ' return *this;', + 'Shape': ' return Shape();', + 'Shape&': ' return *this;', + 'TensorDesc': ' return TensorDesc();', + 'TensorDesc&': ' return *this;', + 'Tensor': ' return Tensor();', + 'Tensor&': ' return *this;', + 'Operator': ' return Operator();', + 'Operator&': ' return *this;', + 'Ptr': ' return nullptr;', + 'std::string': ' return "";', + 'std::string&': ' return "";', + 'string': ' return "";', + 'int': ' return 0;', + 'DataType': ' return DT_FLOAT;', + 'InferenceContextPtr': ' return nullptr;', + 'SubgraphBuilder': ' return nullptr;', + 'OperatorImplPtr': ' return nullptr;', + 'OutHandler': ' return nullptr;', + 'std::vector': ' return {};', + 'std::vector': ' return {};', + 'std::map': ' return {};', + 'uint32_t': ' return 0;', + 'int64_t': ' return 0;', + 'uint64_t': ' return 0;', + 'size_t': ' return 0;', + 'float': ' return 0.0f;', + 'bool': ' return false;', +} + +""" + max code len per line in hua_wei software programming specifications +""" +max_code_len_per_line = 100 + +""" + white_list_for_debug, include_dir_key_words is to + determines which header files to generate cc files from + when DEBUG on +""" +white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h", "inference_context.h", + "ge_ir_build.h", "ge_api.h", "ascend_string.h", "gnode.h"] +include_dir_key_words = ["ge", "graph"] +DEBUG = True + + +def need_generate_func(func_line): + """ + :param func_line: + :return: + """ + if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ + or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): + return False + return True + + +def file_endswith_white_list_suffix(file): + """ + :param file: + :return: + """ + if DEBUG: + for suffix in white_list_for_debug: + if file.endswith(suffix): + return True + return False + else: + return True + + +""" + belows are patterns used for analyse .h file +""" +# pattern function +pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after +([a-zA-Z~_] # void int likely +.* +[)] #we find ) +(?!.*{) # we do not want the case int abc() const +.*) +(;.*) #we want to find ; and after for we will replace these later +\n$ +""", re.VERBOSE | re.MULTILINE | re.DOTALL) + +# pattern comment +pattern_comment = re.compile(r'^\s*//') +pattern_comment_2_start = re.compile(r'^\s*/[*]') +pattern_comment_2_end = re.compile(r'[*]/\s*$') +# pattern define +pattern_define = re.compile(r'^\s*#define') +pattern_define_return = re.compile(r'\\\s*$') +# blank line +pattern_blank_line = re.compile(r'^\s*$') +# virtual,explicit,friend,static +pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') +# lead space +pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') +# functions will have patterns such as func ( or func( +# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist +# format like :"operator = ()" +pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') +# template +pattern_template = re.compile(r'^\s*template') +pattern_template_end = re.compile(r'>\s*$') +# namespace +pattern_namespace = re.compile(r'namespace.*{') +# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with +pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+ 0 and not friend_match: + line, func_name = self.handle_class_member_func(line, template_string) + # Normal functions + else: + line, func_name = self.handle_normal_func(line, template_string) + + need_generate = need_generate_func(line) + # func body + line += self.implement_function(line) + # comment + line = self.gen_comment(start_i) + line + # write to out file + self.write_func_content(line, func_name, need_generate) + # next loop + self.line_index += 1 + + logging.info('Added %s functions', len(self.func_list_exist)) + logging.info('Successfully converted,please see ' + self.output_file) + + def handle_func1(self, line): + """ + :param line: + :return: + """ + find1 = re.search('[(]', line) + if not find1: + self.line_index += 1 + return "continue", line, None + find2 = re.search('[)]', line) + start_i = self.line_index + space_match = pattern_leading_space.search(line) + # deal with + # int abc(int a, + # int b) + if find1 and (not find2): + self.line_index += 1 + line2 = self.input_content[self.line_index] + if space_match: + line2 = re.sub('^' + space_match.group(1), '', line2) + line += line2 + while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): + self.line_index += 1 + line2 = self.input_content[self.line_index] + line2 = re.sub('^' + space_match.group(1), '', line2) + line += line2 + + match_start = pattern_start.search(self.input_content[self.line_index]) + match_end = pattern_end.search(self.input_content[self.line_index]) + if match_start: # like ) { or ) {} int the last line + if not match_end: + self.stack.append('normal_now') + ii = start_i + while ii <= self.line_index: + ii += 1 + self.line_index += 1 + return "continue", line, start_i + logging.info("line[%s]", line) + # ' int abc();'->'int abc()' + (line, match) = pattern_func.subn(r'\2\n', line) + logging.info("line[%s]", line) + # deal with case: + # 'int \n abc(int a, int b)' + if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): + line = self.input_content[start_i - 1] + line + line = line.lstrip() + if not match: + self.line_index += 1 + return "continue", line, start_i + return "pass", line, start_i + + def handle_stack(self, match_start): + """ + :param match_start: + :return: + """ + line = self.input_content[self.line_index] + match_end = pattern_end.search(line) + if match_start: + self.stack.append('normal_now') + if match_end: + top_status = self.stack.pop() + if top_status == 'namespace_now': + self.output_fd.write(line + '\n') + elif top_status == 'class_now': + self.stack_class.pop() + self.stack_template.pop() + if match_start or match_end: + self.line_index += 1 + return "continue" + + if len(self.stack) > 0 and self.stack[-1] == 'normal_now': + self.line_index += 1 + return "continue" + return "pass" + + def handle_class(self, template_string, line, match_start, match_class): + """ + :param template_string: + :param line: + :param match_start: + :param match_class: + :return: + """ + if match_class: # we face a class + self.stack_template.append(template_string) + self.stack.append('class_now') + class_name = match_class.group(3) + + # class template specializations: class A > + if '<' in class_name: + k = line.index('<') + fit = 1 + for ii in range(k + 1, len(line)): + if line[ii] == '<': + fit += 1 + if line[ii] == '>': + fit -= 1 + if fit == 0: + break + class_name += line[k + 1:ii + 1] + logging.info('class_name[%s]', class_name) + self.stack_class.append(class_name) + while not match_start: + self.line_index += 1 + line = self.input_content[self.line_index] + match_start = pattern_start.search(line) + self.line_index += 1 + return "continue" + return "pass" + + def handle_template(self): + line = self.input_content[self.line_index] + match_template = pattern_template.search(line) + template_string = '' + if match_template: + match_template_end = pattern_template_end.search(line) + template_string = line + while not match_template_end: + self.line_index += 1 + line = self.input_content[self.line_index] + template_string += line + match_template_end = pattern_template_end.search(line) + self.line_index += 1 + return template_string + + def handle_namespace(self): + line = self.input_content[self.line_index] + match_namespace = pattern_namespace.search(line) + if match_namespace: # we face namespace + self.output_fd.write(line + '\n') + self.stack.append('namespace_now') + self.line_index += 1 + + def handle_normal_func(self, line, template_string): + template_line = '' + self.stack_template.append(template_string) + if self.stack_template[-1] != '': + template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) + # change '< class T = a, class U = A(3)>' to '' + template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) + template_line = re.sub(r'\s*=.*,', ',', template_line) + template_line = re.sub(r'\s*=.*', '', template_line) + line = re.sub(r'\s*=.*,', ',', line) + line = re.sub(r'\s*=.*\)', ')', line) + line = template_line + line + self.stack_template.pop() + func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() + logging.info("line[%s]", line) + logging.info("func_name[%s]", func_name) + return line, func_name + + def handle_class_member_func(self, line, template_string): + template_line = '' + x = '' + if template_string != '': + template_string = re.sub(r'\s*template', 'template', template_string) + template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) + template_string = re.sub(r'\s*=.*,', ',', template_string) + template_string = re.sub(r'\s*=.*', '', template_string) + if self.stack_template[-1] != '': + if not (re.search(r'<\s*>', stack_template[-1])): + template_line = re.sub(r'^\s*template', 'template', stack_template[-1]) + if not (re.search(r'<.*>', self.stack_class[-1])): + # for x we get like template -> + x = re.sub(r'template\s*<', '<', template_line) # remove template -> + x = re.sub(r'\n', '', x) + x = re.sub(r'\s*=.*,', ',', x) + x = re.sub(r'\s*=.*\>', '>', x) + x = x.rstrip() # remove \n + x = re.sub(r'(class|typename)\s+|(|\s*class)', '', + x) # remove class,typename -> + x = re.sub(r'<\s+', '<', x) + x = re.sub(r'\s+>', '>', x) + x = re.sub(r'\s+,', ',', x) + x = re.sub(r',\s+', ', ', x) + line = re.sub(r'\s*=\s+0', '', line) + line = re.sub(r'\s*=\s+.*,', ',', line) + line = re.sub(r'\s*=\s+.*\)', ')', line) + logging.info("x[%s]\nline[%s]", x, line) + # if the function is long, void ABC::foo() + # breaks into two lines void ABC::\n foo() + temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1) + if len(temp_line) > max_code_len_per_line: + line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1) + else: + line = temp_line + logging.info("line[%s]", line) + # add template as the above if there is one + template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) + template_line = re.sub(r'\s*=.*,', ',', template_line) + template_line = re.sub(r'\s*=.*', '', template_line) + line = template_line + template_string + line + func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() + logging.info("line[%s]", line) + logging.info("func_name[%s]", func_name) + return line, func_name + + def write_func_content(self, content, func_name, need_generate): + if not (func_name in self.func_list_exist) and need_generate: + self.output_fd.write(content) + self.func_list_exist.append(func_name) + logging.info('add func:[%s]', func_name) + + def gen_comment(self, start_i): + comment_line = '' + # Function comments are on top of function declarations, copy them over + k = start_i - 1 # one line before this func start + if pattern_template.search(self.input_content[k]): + k -= 1 + if pattern_comment_2_end.search(self.input_content[k]): + comment_line = self.input_content[k].lstrip() + while not pattern_comment_2_start.search(self.input_content[k]): + k -= 1 + comment_line = self.input_content[k].lstrip() + comment_line + else: + for j in range(k, 0, -1): + c_line = self.input_content[j] + if pattern_comment.search(c_line): + c_line = re.sub(r'\s*//', '//', c_line) + comment_line = c_line + comment_line + else: + break + return comment_line + + @staticmethod + def implement_function(func): + function_def = '' + function_def += '{\n' + + all_items = func.split() + start = 0 + return_type = all_items[start] + if return_type == "const": + start += 1 + return_type = all_items[start] + if return_type.startswith(('std::map', 'std::set', 'std::vector')): + return_type = "std::map" + if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): + return_type = "Ptr" + if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): + return_type += "&" + if RETURN_STATEMENTS.__contains__(return_type): + function_def += RETURN_STATEMENTS[return_type] + else: + logging.warning("Unhandled return type[%s]", return_type) + + function_def += '\n' + function_def += '}\n' + function_def += '\n' + return function_def + + +def collect_header_files(path): + """ + :param path: + :return: + """ + header_files = [] + shared_includes_content = [] + for root, dirs, files in os.walk(path): + files.sort() + for file in files: + if file.find("git") >= 0: + continue + if not file.endswith('.h'): + continue + file_path = os.path.join(root, file) + file_path = file_path.replace('\\', '/') + header_files.append(file_path) + include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) + shared_includes_content.append(include_str) + # for acl error code + shared_includes_content.append('#include \n') + shared_includes_content.append('const int ACL_ERROR_COMPILING_STUB_MODE = 100039;\n') + return header_files, shared_includes_content + + +def generate_stub_file(inc_dir, out_cc_dir): + """ + :param inc_dir: + :param out_cc_dir: + :return: + """ + target_header_files, shared_includes_content = collect_header_files(inc_dir) + for header_file in target_header_files: + if not file_endswith_white_list_suffix(header_file): + continue + cc_file = re.sub('.h*$', '.cc', header_file) + h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) + h_2_cc.h2cc() + + +def gen_code(inc_dir, out_cc_dir): + """ + :param inc_dir: + :param out_cc_dir: + :return: + """ + if not inc_dir.endswith('/'): + inc_dir += '/' + if not out_cc_dir.endswith('/'): + out_cc_dir += '/' + for include_dir_key_word in include_dir_key_words: + generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) + + +if __name__ == '__main__': + inc_dir = sys.argv[1] + out_cc_dir = sys.argv[2] + gen_code(inc_dir, out_cc_dir) diff --git a/metadef/graph/tensor.cc b/metadef/graph/tensor.cc new file mode 100644 index 00000000..d7668395 --- /dev/null +++ b/metadef/graph/tensor.cc @@ -0,0 +1,772 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "external/graph/tensor.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_tensor.h" +#include "securec.h" +#include "utils/attr_utils.h" +#include "utils/tensor_adapter.h" +#include "utils/tensor_utils.h" +#include "utils/type_utils.h" + +namespace { +/// Extra 8 bytes store pointer of string +/// Extra 1 byte store '\0' +const int EXTRA_STORE_POINTER_FOR_STRING = 8; +const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9; +const int64_t UNKNOWN_DIM_SIZE = -1; +} // namespace + +namespace ge { +// If not overflow return true +static bool Int64MulNotOverflow(int64_t a, int64_t b) { + if (a > 0) { + if (b > 0) { + if (a > (INT64_MAX / b)) { + return false; + } + } else { + if (b < (INT64_MIN / a)) { + return false; + } + } + } else { + if (b > 0) { + if (a < (INT64_MIN / b)) { + return false; + } + } else { + if ((a != 0) && (b < (INT64_MAX / a))) { + return false; + } + } + } + return true; +} + +class TensorDescImpl { + public: + TensorDescImpl() = default; + ~TensorDescImpl() = default; + TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {} + + Shape shape_; + std::vector> range_; + Format format_ = FORMAT_ND; + Format origin_format_ = FORMAT_ND; + DataType data_type_ = DT_FLOAT; + Shape origin_shape_; + int64_t size_ = 0; + int64_t real_dim_cnt_ = 0; + std::string name_; +}; + +class TensorImpl { + public: + TensorImpl() = default; + ~TensorImpl() = default; + + explicit TensorImpl(const TensorDesc &tensor_desc) : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)) {} + TensorImpl(const TensorDesc &tensor_desc, const std::vector &data) + : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data) {} + TensorImpl(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) + : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data, size) {} + TensorImpl(TensorDesc &&tensor_desc, std::vector &&data) + : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), std::move(data)) {} + + graphStatus SetData(const std::string &data) { + if (!data.empty()) { + /// Extra 8 bytes store pointer of string + /// Extra 1 byte store '\0' + size_t total_size = data.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL; + std::unique_ptr buff(new (std::nothrow) char[total_size]()); + if (buff == nullptr) { + GELOGE(GRAPH_FAILED, "allocate string raw data buff failed"); + return GRAPH_FAILED; + } + uint64_t *p = reinterpret_cast(buff.get()); + // Front 8 bytes store pointer of string + char *raw_data = buff.get() + EXTRA_STORE_POINTER_FOR_STRING; + p[0] = reinterpret_cast(raw_data); + int32_t memcpy_ret = memcpy_s(raw_data, + total_size - EXTRA_STORE_POINTER_FOR_STRING, + data.c_str(), + data.size() + 1); + GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); + (void)ge_tensor.SetData(reinterpret_cast(buff.get()), total_size); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; + } + + graphStatus SetData(const std::vector &data) { + if (data.empty()) { + GELOGE(GRAPH_FAILED, "there is no data, please check the input variable"); + return GRAPH_FAILED; + } + size_t total_size = 0; + for (auto str : data) { + /// Extra 8 bytes store pointer of each string + /// Extra 1 byte store '\0' + total_size += (str.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL); + } + std::unique_ptr buff(new (std::nothrow) char[total_size]); + if (buff == nullptr) { + GELOGE(GRAPH_FAILED, "allocate string raw data buff failed"); + return GRAPH_FAILED; + } + uint64_t *p = reinterpret_cast(buff.get()); + // Front some bytes store pointer of each string + char *raw_data = buff.get() + data.size() * sizeof(uint64_t); + uint64_t ptr_size = data.size() * sizeof(uint64_t); + for (size_t i = 0; i < data.size(); ++i) { + p[i] = reinterpret_cast(raw_data); + if (total_size < ptr_size) { + GELOGE(GRAPH_FAILED, "Subtraction invalid, total_size: %zu, ptr_size: %lu", total_size, ptr_size); + return GRAPH_FAILED; + } + int32_t memcpy_ret = memcpy_s(raw_data, total_size - ptr_size, data[i].c_str(), data[i].size() + 1); + GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); + raw_data += (data[i].size() + 1); + ptr_size += (data[i].size() + 1); + } + + (void)ge_tensor.SetData(reinterpret_cast(buff.get()), total_size); + return GRAPH_SUCCESS; + } + + GeTensor ge_tensor; +}; + +class ShapeImpl { + public: + ShapeImpl() = default; + ~ShapeImpl() = default; + explicit ShapeImpl(const std::vector &dims) { + bool is_unknown_dim_num = false; + for (const auto &dim : dims) { + if (dim == UNKNOWN_DIM_NUM) { + is_unknown_dim_num = true; + break; + } + } + dims_ = is_unknown_dim_num ? std::vector({UNKNOWN_DIM_NUM}) : dims; + } + + std::vector dims_; +}; + +Shape::Shape() { impl_ = ComGraphMakeShared(); } + +Shape::Shape(const std::vector &dims) { impl_ = ComGraphMakeShared(dims); } + +size_t Shape::GetDimNum() const { + if (impl_ != nullptr) { + for (auto i : impl_->dims_) { + if (i == UNKNOWN_DIM_NUM) { + return 0; + } + } + return impl_->dims_.size(); + } + return 0; +} + +int64_t Shape::GetDim(size_t idx) const { + if (impl_ != nullptr) { + if (idx >= impl_->dims_.size()) { + return 0; + } + return impl_->dims_[idx]; + } + return 0; +} + +graphStatus Shape::SetDim(size_t idx, int64_t value) { + if (impl_ != nullptr) { + if (idx >= impl_->dims_.size()) { + return GRAPH_FAILED; + } + impl_->dims_[idx] = value; + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +std::vector Shape::GetDims() const { + vector dims; + if (impl_ != nullptr) { + return impl_->dims_; + } + return dims; +} + +int64_t Shape::GetShapeSize() const { + if (impl_ != nullptr) { + if (impl_->dims_.empty()) { + return 0; + } + int64_t size = 1; + for (auto i : impl_->dims_) { + if (i == UNKNOWN_DIM_NUM || i == UNKNOWN_DIM) { + return UNKNOWN_DIM_SIZE; + } + + if (!Int64MulNotOverflow(size, i)) { + GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); + size = 0; + return size; + } + size *= i; + } + return size; + } + return 0; +} + +TensorDesc::TensorDesc() { + impl = ComGraphMakeShared(); // lint !e665 +} + +TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) { + impl = ComGraphMakeShared(shape, format, dt); // lint !e665 + SetRealDimCnt(shape.GetDimNum()); +} + +TensorDesc::TensorDesc(const TensorDesc &desc) { + // Copy + impl = ComGraphMakeShared(); // lint !e665 + if (desc.impl != nullptr && impl != nullptr) { + *impl = *desc.impl; + } +} + +TensorDesc::TensorDesc(TensorDesc &&desc) { + // Move + impl = std::move(desc.impl); +} + +TensorDesc &TensorDesc::operator=(const TensorDesc &desc) { + // Copy + if (&desc != this) { + impl = ComGraphMakeShared(); + if (desc.impl != nullptr && impl != nullptr) { + *impl = *desc.impl; + } + } + return *this; +} + +TensorDesc &TensorDesc::operator=(TensorDesc &&desc) { + if (&desc != this) { + impl = std::move(desc.impl); + } + return *this; +} + +void TensorDesc::Update(const Shape &shape, Format format, DataType dt) { + if (impl != nullptr) { + impl->shape_ = shape; + impl->format_ = format; + impl->data_type_ = dt; + } +} + +Shape TensorDesc::GetShape() const { + if (impl != nullptr) { + return impl->shape_; + } + return Shape(); +} + +void TensorDesc::SetShape(const Shape &shape) { + if (impl != nullptr) { + impl->shape_ = shape; + } +} + +// set shape with -2, it stand for unknown shape +graphStatus TensorDesc::SetUnknownDimNumShape() { + if (impl != nullptr) { + impl->shape_ = Shape({UNKNOWN_DIM_NUM}); + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Set unknown shape failed,because no impl class!"); + return GRAPH_FAILED; +} + +// for unknown shape +graphStatus TensorDesc::SetShapeRange(const std::vector> &range) { + if (impl != nullptr) { + impl->range_ = range; + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "SetShapeRange failed!impl is nullptr!"); + return GRAPH_FAILED; +} +graphStatus TensorDesc::GetShapeRange(std::vector> &range) const { + if (impl != nullptr) { + range = impl->range_; + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "impl is nullptr!"); + return GRAPH_FAILED; +} + +Shape TensorDesc::GetOriginShape() const { + if (impl != nullptr) { + return impl->origin_shape_; + } + return Shape(); +} + +void TensorDesc::SetOriginShape(const Shape &origin_shape) { + if (impl != nullptr) { + impl->origin_shape_ = origin_shape; + } +} + +Format TensorDesc::GetFormat() const { + if (impl != nullptr) { + return impl->format_; + } + return FORMAT_RESERVED; +} + +void TensorDesc::SetFormat(Format format) { + if (impl != nullptr) { + impl->format_ = format; + } +} + +Format TensorDesc::GetOriginFormat() const { + if (impl != nullptr) { + return impl->origin_format_; + } + return FORMAT_RESERVED; +} + +void TensorDesc::SetOriginFormat(Format origin_format) { + if (impl != nullptr) { + impl->origin_format_ = origin_format; + } +} + +DataType TensorDesc::GetDataType() const { + if (impl != nullptr) { + return impl->data_type_; + } + return DT_UNDEFINED; +} + +void TensorDesc::SetDataType(DataType dt) { + if (impl != nullptr) { + impl->data_type_ = dt; + } +} + +void TensorDesc::SetSize(int64_t size) { + if (impl != nullptr) { + impl->size_ = size; + } +} + +int64_t TensorDesc::GetSize() const { + if (impl != nullptr) { + return impl->size_; + } + return 0; +} + +void TensorDesc::SetRealDimCnt(const int64_t real_dim_cnt) { + if (impl != nullptr) { + impl->real_dim_cnt_ = real_dim_cnt; + } +} + +int64_t TensorDesc::GetRealDimCnt() const { + if (impl != nullptr) { + return impl->real_dim_cnt_; + } + return 0; +} + +std::string TensorDesc::GetName() const { + if (impl != nullptr) { + return impl->name_; + } + return ""; +} + +void TensorDesc::SetName(const std::string &name) { + if (impl != nullptr) { + impl->name_ = name; + } +} + +graphStatus TensorDesc::GetName(AscendString &name) { + if (impl != nullptr) { + name = AscendString(impl->name_.c_str()); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +void TensorDesc::SetName(const char *name) { + if (impl != nullptr && name != nullptr) { + impl->name_ = name; + } +} + +Tensor::Tensor() { impl = ComGraphMakeShared(); } + +Tensor::Tensor(const TensorDesc &tensor_desc) { + impl = ComGraphMakeShared(tensor_desc); // lint !e665 +} + +Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector &data) { + uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); + DataType data_type = tensor_desc.GetDataType(); + uint32_t type_length; + bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); + if (!ret) { + GELOGW("datatype %d is not found.", data_type); + } + + auto data_size = data.size(); + if (ret && (shape_size || (data_size != type_length))) { + if (type_length != 0 && UINT64_MAX / type_length < shape_size) { + GELOGW("mul overflow: %lu, %u", shape_size, type_length); + } else { + if (shape_size * type_length != data_size) { + GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, + data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); + } + } + } + impl = ComGraphMakeShared(tensor_desc, data); // lint !e665 +} + +Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) { + uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); + DataType data_type = tensor_desc.GetDataType(); + uint32_t type_length; + bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); + if (!ret) { + GELOGW("datatype %d is not found.", data_type); + } + if (ret && (shape_size || (size != type_length))) { + if (type_length != 0 && UINT64_MAX / type_length < shape_size) { + GELOGW("mul overflow: %lu, %u", shape_size, type_length); + } else { + if (shape_size * type_length != size) { + GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, + size, TypeUtils::DataTypeToSerialString(data_type).c_str()); + } + } + } + + impl = ComGraphMakeShared(tensor_desc, data, size); // lint !e665 +} + +Tensor::Tensor(TensorDesc &&tensor_desc, std::vector &&data) { + uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); + DataType data_type = tensor_desc.GetDataType(); + uint32_t type_length; + bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); + if (!ret) { + GELOGW("datatype %d is not found.", data_type); + } + + auto data_size = data.size(); + if (ret && (shape_size || (data_size != type_length))) { + if (type_length != 0 && UINT64_MAX / type_length < shape_size) { + GELOGW("mul overflow: %lu, %u", shape_size, type_length); + } else { + if (shape_size * type_length != data_size) { + GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, + data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); + } + } + } + impl = ComGraphMakeShared(std::move(tensor_desc), std::move(data)); // lint !e665 +} + +TensorDesc Tensor::GetTensorDesc() const { + if (impl != nullptr) { + return TensorAdapter::GeTensorDesc2TensorDesc(impl->ge_tensor.MutableTensorDesc()); + } + return TensorDesc(); +} + +graphStatus Tensor::SetTensorDesc(const TensorDesc &tensor_desc) { + if (impl != nullptr) { + impl->ge_tensor.SetTensorDesc(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +const uint8_t *Tensor::GetData() const { + if (impl != nullptr) { + return impl->ge_tensor.GetData().data(); + } + return nullptr; +} + +uint8_t *Tensor::GetData() { + if (impl != nullptr) { + return impl->ge_tensor.MutableData().data(); + } + return nullptr; +} + +size_t Tensor::GetSize() const { + if (impl != nullptr) { + return impl->ge_tensor.GetData().size(); + } + return 0; +} + +graphStatus Tensor::SetData(std::vector &&data) { + if (impl != nullptr) { + (void)impl->ge_tensor.SetData(data); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::SetData(const std::vector &data) { + if (impl != nullptr) { + (void)impl->ge_tensor.SetData(data); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::SetData(const uint8_t *data, size_t size) { + if (impl != nullptr) { + (void)impl->ge_tensor.SetData(data, size); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::SetData(const std::string &data) { + if (impl != nullptr) { + if (impl->SetData(data) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Tensor set data failed."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::SetData(const std::vector &data) { + if (impl != nullptr) { + if (impl->SetData(data) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Tensor set vector data failed."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::SetData(const char *data) { + if (impl != nullptr && data != nullptr) { + std::string tensor_data = data; + if (impl->SetData(tensor_data) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Tensor set data failed."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::SetData(const std::vector &datas) { + if (impl != nullptr) { + std::vector tensor_data; + for (auto &data : datas) { + if (data.GetString() == nullptr) { + GELOGE(GRAPH_FAILED, "Data is nullptr."); + return GRAPH_FAILED; + } + tensor_data.emplace_back(data.GetString()); + } + if (impl->SetData(tensor_data) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Tensor set vector data failed."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Tensor::IsValid() { + uint64_t shape_size = GetTensorDesc().GetShape().GetShapeSize(); + DataType data_type = GetTensorDesc().GetDataType(); + uint32_t type_length; + bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); + if (!ret) { + GELOGW("datatype %d is not found.", data_type); + return GRAPH_SUCCESS; + } + + size_t data_size = GetSize(); + if (data_type != DT_STRING) { + if (shape_size || (data_size != type_length)) { + if (type_length != 0 && UINT64_MAX / type_length < shape_size) { + GELOGW("mul overflow: %lu, %u", shape_size, type_length); + } else { + if (shape_size * type_length != data_size) { + GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, + data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); + return GRAPH_FAILED; + } + } + } + } + + return GRAPH_SUCCESS; +} + +Tensor Tensor::Clone() const { + Tensor tensor; + if (impl != nullptr && tensor.impl != nullptr) { + tensor.impl->ge_tensor = impl->ge_tensor.Clone(); + } + return tensor; +} + +GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_desc) { + GeTensorDesc ge_tensor_desc(GeShape(tensor_desc.GetShape().GetDims()), tensor_desc.GetFormat(), + tensor_desc.GetDataType()); + ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); + ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); + ge_tensor_desc.SetName(tensor_desc.GetName()); + std::vector> shape_range; + auto status = tensor_desc.GetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get shape range failed!"); + return ge_tensor_desc; + } + status = ge_tensor_desc.SetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set shape range failed!"); + return ge_tensor_desc; + } + auto size = tensor_desc.GetSize(); + TensorUtils::SetSize(ge_tensor_desc, size); + + auto real_dim_cnt = static_cast(tensor_desc.GetRealDimCnt()); + TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); + return ge_tensor_desc; +} + +TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_desc) { + TensorDesc tensor_desc(Shape(ge_tensor_desc.GetShape().GetDims()), ge_tensor_desc.GetFormat(), + ge_tensor_desc.GetDataType()); + tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); + tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); + tensor_desc.SetName(ge_tensor_desc.GetName()); + std::vector> shape_range; + auto status = ge_tensor_desc.GetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get shape range failed!"); + return tensor_desc; + } + status = tensor_desc.SetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set shape range failed!"); + return tensor_desc; + } + int64_t size = 0; + (void)TensorUtils::GetSize(ge_tensor_desc, size); + tensor_desc.SetSize(size); + + uint32_t real_dim_cnt = 0; + (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); + tensor_desc.SetRealDimCnt(real_dim_cnt); + return tensor_desc; +} + +GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) { + GeTensorPtr ge_tensor; + if (tensor.impl != nullptr) { + ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor.Clone()); // lint !e665 + } + return ge_tensor; +} + +Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) { + Tensor tensor; + if (ge_tensor != nullptr && tensor.impl != nullptr) { + tensor.impl->ge_tensor = ge_tensor->Clone(); + } + return tensor; +} + +ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { + GeTensorPtr ge_tensor; + if (tensor.impl != nullptr) { + ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); // lint !e665 + } + return ge_tensor; +} + +GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) { + GeTensorPtr ge_tensor; + if (tensor.impl != nullptr) { + ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); // lint !e665 + } + return ge_tensor; +} + +const GeTensor TensorAdapter::AsGeTensor(const Tensor &tensor) { + if (tensor.impl != nullptr) { + return tensor.impl->ge_tensor; + } + return GeTensor(); +} + +GeTensor TensorAdapter::AsGeTensor(Tensor &tensor) { + if (tensor.impl != nullptr) { + return tensor.impl->ge_tensor; + } + return GeTensor(); +} + +const Tensor TensorAdapter::AsTensor(const GeTensor &ge_tensor) { + Tensor tensor; + if (tensor.impl != nullptr) { + tensor.impl->ge_tensor = ge_tensor; + } + return tensor; +} + +Tensor TensorAdapter::AsTensor(GeTensor &ge_tensor) { + Tensor tensor; + if (tensor.impl != nullptr) { + tensor.impl->ge_tensor = ge_tensor; + } + return tensor; +} +} // namespace ge diff --git a/metadef/graph/utils/anchor_utils.cc b/metadef/graph/utils/anchor_utils.cc new file mode 100644 index 00000000..5a042283 --- /dev/null +++ b/metadef/graph/utils/anchor_utils.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/anchor_utils.h" +#include +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Format AnchorUtils::GetFormat(const DataAnchorPtr &data_anchor) { + if (data_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "The input data anchor is invalid."); + return FORMAT_RESERVED; + } + return data_anchor->format_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetFormat(const DataAnchorPtr &data_anchor, + Format data_format) { + if ((data_anchor == nullptr) || (data_format == FORMAT_RESERVED)) { + GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid ."); + return GRAPH_FAILED; + } + data_anchor->format_ = data_format; + return GRAPH_SUCCESS; +} + +// Get anchor status +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorStatus AnchorUtils::GetStatus(const DataAnchorPtr &data_anchor) { + if (data_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "The input data anchor is invalid."); + return ANCHOR_RESERVED; + } + return data_anchor->status_; +} + +// Set anchor status +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetStatus(const DataAnchorPtr &data_anchor, + AnchorStatus anchor_status) { + if ((data_anchor == nullptr) || (anchor_status == ANCHOR_RESERVED)) { + GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid ."); + return GRAPH_FAILED; + } + data_anchor->status_ = anchor_status; + return GRAPH_SUCCESS; +} + +bool AnchorUtils::HasControlEdge(const AnchorPtr &anchor) { + auto control_anchor = Anchor::DynamicAnchorCast(anchor); + if (control_anchor != nullptr) { + return (control_anchor->GetPeerAnchors().size() != 0); + } + + auto data_anchor = Anchor::DynamicAnchorCast(anchor); + if (data_anchor) { + for (const auto &peer : data_anchor->GetPeerAnchors()) { + auto peer_cast = Anchor::DynamicAnchorCast(peer); + if (peer_cast) { + return true; + } + } + return false; + } + GELOGE(GRAPH_FAILED, "the anchor is neither control anchor nor data anchor"); + return false; +} + +bool AnchorUtils::IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst) { + GE_CHK_BOOL_EXEC(src != nullptr, return false, "src is null."); + GE_CHK_BOOL_RET_STATUS_NOLOG(src->IsLinkedWith(dst), false); + auto src_control_anchor = Anchor::DynamicAnchorCast(src); + auto dst_control_anchor = Anchor::DynamicAnchorCast(dst); + return (src_control_anchor || dst_control_anchor); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int AnchorUtils::GetIdx(const AnchorPtr &anchor) { + // Check if it can add edge between DataAnchor + auto data_anchor = Anchor::DynamicAnchorCast(anchor); + if (data_anchor != nullptr) { + return data_anchor->GetIdx(); + } + // Check if it can add edge between ControlAnchor + auto control_anchor = Anchor::DynamicAnchorCast(anchor); + if (control_anchor != nullptr) { + return control_anchor->GetIdx(); + } + return -1; +} +} // namespace ge diff --git a/metadef/graph/utils/ge_ir_utils.cc b/metadef/graph/utils/ge_ir_utils.cc new file mode 100644 index 00000000..5d8430aa --- /dev/null +++ b/metadef/graph/utils/ge_ir_utils.cc @@ -0,0 +1,1193 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/utils/ge_ir_utils.h" +#include +#include "framework/common/debug/ge_log.h" +#include "mmpa/mmpa_api.h" + +namespace { +const char *const kControlAnchorIndex = ":-1"; +const char *const kNodeTypeForSubgraph = "subgraph"; +const char *const kPrefixForInputDesc = "input_desc_attr_"; +const char *const kPrefixForOutputDesc = "output_desc_attr_"; +const char *const kDumpGEGraph = "DUMP_GE_GRAPH"; +const int8_t kMaxRecursionDepth = 10; +const int kBase = 10; +char kDumpGeGraph[MMPA_MAX_PATH] = { 0x00 }; +const int64_t kDumpLevel = + (mmGetEnv(kDumpGEGraph, kDumpGeGraph, MMPA_MAX_PATH) == EN_OK) ? + std::strtol(kDumpGeGraph, nullptr, kBase) : ge::OnnxUtils::NO_DUMP; +const int64_t kInputPrefixLength = 5; +const int64_t kOutputPrefixLength = 6; +using AttrDefPair = ::google::protobuf::MapPair; +} // namespace + +namespace ge { +// Part 1: from IR convert to ONNX Protobuf +namespace{ +const std::map kGeDataTypeToOnnxMap = { + {DT_INT64, onnx::TensorProto_DataType_INT64}, {DT_UINT64, onnx::TensorProto_DataType_UINT64}, + {DT_FLOAT, onnx::TensorProto_DataType_FLOAT}, {DT_INT32, onnx::TensorProto_DataType_INT32}, + {DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8}, + {DT_UINT8, onnx::TensorProto_DataType_UINT8}, {DT_INT16, onnx::TensorProto_DataType_INT16}, + {DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16}, + {DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL}, +}; +} + +struct AttrNameComp { + inline bool operator()(const onnx::AttributeProto &lsh, const onnx::AttributeProto &rsh) { + return lsh.name() < rsh.name(); + } +}; + +onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) { + auto it = kGeDataTypeToOnnxMap.find(data_type); + if (it != kGeDataTypeToOnnxMap.end()) { + return it->second; + } else { + GELOGW("EncodeDataType: datatype not support %u", data_type); + return onnx::TensorProto_DataType_UNDEFINED; + } +} + +void OnnxUtils::AddAttrProtoFromAttribute(const std::pair &string_attr_value, + onnx::NodeProto *node_proto) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node proto is nullptr."); + return; + } + auto attr = node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + auto attr_name = string_attr_value.first; + attr->set_name(attr_name); + auto attr_value = string_attr_value.second; + auto value_type = attr_value.GetValueType(); + switch (value_type) { + case GeAttrValue::VT_FLOAT: { + GeAttrValue::FLOAT data_f = 0; + (void)attr_value.GetValue(data_f); + attr->set_f(data_f); + attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); + break; + } + case GeAttrValue::VT_LIST_FLOAT: { + GeAttrValue::LIST_FLOAT data_fs = {}; + (void)attr_value.GetValue(data_fs); + attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); + for (auto &v : data_fs) { + attr->add_floats(v); + } + break; + } + case GeAttrValue::VT_INT: { + GeAttrValue::INT data_i = 0; + (void)attr_value.GetValue(data_i); + attr->set_type(onnx::AttributeProto_AttributeType_INT); + attr->set_i(data_i); + break; + } + case GeAttrValue::VT_LIST_INT: { + GeAttrValue::LIST_INT data_is = {}; + (void)attr_value.GetValue(data_is); + attr->set_type(onnx::AttributeProto_AttributeType_INTS); + for (auto &v : data_is) { + attr->add_ints(v); + } + break; + } + case GeAttrValue::VT_STRING: { + GeAttrValue::STR data_s; + (void)attr_value.GetValue(data_s); + attr->set_type(onnx::AttributeProto_AttributeType_STRING); + attr->set_s(data_s); + break; + } + case GeAttrValue::VT_LIST_STRING: { + GeAttrValue::LIST_STR data_ss = {}; + (void)attr_value.GetValue(data_ss); + attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); + for (auto &v : data_ss) { + attr->add_strings(v); + } + break; + } + default: + GELOGW("GeAttrValue ValueType: %u is not supported for now", value_type); + break; + } +} + +void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, + void *data) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); + return; + } + auto attr = node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + attr->set_name(name); + switch (type) { + case onnx::AttributeProto_AttributeType_FLOAT: + attr->set_f((*(static_cast(data)))); + attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); + break; + + case onnx::AttributeProto_AttributeType_FLOATS: + attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); + for (auto &v : (*(static_cast *>(data)))) { + attr->add_floats(v); + } + break; + + case onnx::AttributeProto_AttributeType_INT: + attr->set_type(onnx::AttributeProto_AttributeType_INT); + attr->set_i((*(static_cast(data)))); + break; + + case onnx::AttributeProto_AttributeType_INTS: + attr->set_type(onnx::AttributeProto_AttributeType_INTS); + for (auto &v : *(static_cast *>(data))) { + attr->add_ints(v); + } + break; + + case onnx::AttributeProto_AttributeType_STRING: + attr->set_type(onnx::AttributeProto_AttributeType_STRING); + attr->set_s((*(static_cast(data)))); + break; + + case onnx::AttributeProto_AttributeType_STRINGS: + attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); + for (auto &v : *(static_cast *>(data))) { + attr->add_strings(v); + } + break; + + default: + GELOGW("AttributeProto AttributeType: %u is not supported for now", type); + break; + } +} + +void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, + ::google::protobuf::RepeatedField<::google::protobuf::int64> data) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); + return; + } + if (!data.empty()) { + auto attr = node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + attr->set_name(name); + for (auto &v : data) { + attr->add_ints(v); + } + attr->set_type(type); + } +} + +void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, + ::google::protobuf::RepeatedField data) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); + return; + } + if (!data.empty()) { + auto attr = node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + attr->set_name(name); + for (auto &v : data) { + attr->add_ints(static_cast(v)); + } + attr->set_type(type); + } +} + +void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, + ::google::protobuf::RepeatedField data) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); + return; + } + if (!data.empty()) { + auto attr = node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + attr->set_name(name); + for (auto &v : data) { + attr->add_floats(v); + } + attr->set_type(type); + } +} + +void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, + ::google::protobuf::RepeatedPtrField<::std::string> data) { + if (node_proto == nullptr) { + GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); + return; + } + if (!data.empty()) { + auto attr = node_proto->add_attribute(); + if (attr == nullptr) { + GELOGE(GRAPH_FAILED, "attr is nullptr."); + return; + } + attr->set_name(name); + for (auto &v : data) { + attr->add_strings(v); + } + attr->set_type(type); + } +} + +void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc) { + if (node_proto == nullptr || op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "node_proto or op_desc is nullptr"); + return; + } + // Input describes + auto size_in = op_desc->GetAllInputsSize(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_nums", &size_in); + if (size_in > 0) { + for (uint32_t i = 0; i < size_in; i++) { + auto input_desc = op_desc->GetInputDescPtrDfault(i); + if (input_desc != nullptr) { + auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "input_desc_dtype:" + std::to_string(i), &data_type); + auto data_type_origin = TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "input_desc_origin_dtype:" + std::to_string(i), &data_type_origin); + auto dims = input_desc->GetShape().GetDims(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, + "input_desc_shape:" + std::to_string(i), &dims); + auto dims_origin = input_desc->GetOriginShape().GetDims(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, + "input_desc_origin_shape:" + std::to_string(i), &dims_origin); + auto layout = TypeUtils::FormatToSerialString(input_desc->GetFormat()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "input_desc_layout:" + std::to_string(i), &layout); + auto layout_origin = TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "input_desc_origin_layout:" + std::to_string(i), &layout_origin); + auto tensor_descriptor = input_desc->tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor != nullptr) { + auto size = tensor_descriptor->size(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_size:" + std::to_string(i), &size); + auto weight_size = tensor_descriptor->weight_size(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_weight_size:" + std::to_string(i), &weight_size); + auto reuse_input = tensor_descriptor->reuse_input(); + auto reuse_input_int = static_cast(reuse_input); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_reuse_input:" + std::to_string(i), &reuse_input_int); + auto output_tensor = tensor_descriptor->output_tensor(); + auto output_tensor_int = static_cast(output_tensor); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_output_tensor:" + std::to_string(i), &output_tensor_int); + auto device_type = tensor_descriptor->device_type(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "input_desc_device_type:" + std::to_string(i), &device_type); + auto input_tensor = tensor_descriptor->input_tensor(); + auto input_tensor_int = static_cast(input_tensor); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_input_tensor:" + std::to_string(i), &input_tensor_int); + auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); + auto data_offset = tensor_descriptor->data_offset(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_data_offset:" + std::to_string(i), &data_offset); + auto cmps_size = tensor_descriptor->cmps_size(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_cmps_size:" + std::to_string(i), + &cmps_size); + auto cmps_tab = tensor_descriptor->cmps_tab(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "input_desc_cmps_tab:" + std::to_string(i), &cmps_tab); + auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset); + const auto &tensor_desc_map = tensor_descriptor->attr(); + std::string suffix = ":" + std::to_string(i); + AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix); + } else { + GELOGW("Tensor descriptor is nullptr"); + continue; + } + } else { + GELOGW("Input desc is nullptr"); + continue; + } + } + } + // Output describes + auto size_out = op_desc->GetOutputsSize(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_nums", &size_out); + if (size_out > 0) { + for (uint32_t i = 0; i < size_out; i++) { + auto output_desc = op_desc->GetOutputDescPtr(i); + if (output_desc != nullptr) { + auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "output_desc_dtype:" + std::to_string(i), &data_type); + auto origin_data_type = TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "output_desc_origin_dtype:" + std::to_string(i), &origin_data_type); + auto dims = output_desc->GetShape().GetDims(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, + "output_desc_shape:" + std::to_string(i), &dims); + auto dims_origin = output_desc->GetOriginShape().GetDims(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, + "output_desc_origin_shape:" + std::to_string(i), &dims_origin); + auto layout = TypeUtils::FormatToSerialString(output_desc->GetFormat()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_layout:" + std::to_string(i), + &layout); + auto layout_origin = TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "output_desc_origin_layout:" + std::to_string(i), &layout_origin); + auto tensor_descriptor = output_desc->tensor_descriptor_.GetProtoMsg(); + if (tensor_descriptor != nullptr) { + auto size = tensor_descriptor->size(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_size:" + std::to_string(i), + &size); + auto weight_size = tensor_descriptor->weight_size(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "output_desc_weight_size:" + std::to_string(i), &weight_size); + auto device_type = tensor_descriptor->device_type(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + "output_desc_device_type:" + std::to_string(i), &device_type); + auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, + "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); + const auto &tensor_desc_map = tensor_descriptor->attr(); + std::string suffix = ":" + std::to_string(i); + AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix); + } else { + GELOGW("Tensor descriptor is nullptr"); + continue; + } + } else { + GELOGW("Output desc is nullptr"); + continue; + } + } + } +} + +void OnnxUtils::AddAttrProtoForAttrsFromAttrMap( + const ::google::protobuf::Map &attr_map, onnx::NodeProto *node_proto, + const std::string& prefix, const std::string& suffix) { + for (const auto &item : attr_map) { + auto attr_name = item.first; + auto attr_def = item.second; + auto attr_type = attr_def.value_case(); + if (attr_type == ge::proto::AttrDef::kT) { + const auto &tensor_def = attr_def.t(); + const auto &tensor_desc = tensor_def.desc(); + auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + prefix + attr_name + "_desc_dtype" + suffix, &data_type); + auto dims = tensor_desc.shape().dim(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, + prefix + attr_name + "_desc_shape" + suffix, dims); + auto layout = tensor_desc.layout(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + prefix + attr_name + "_desc_layout" + suffix, &layout); + auto device_type = tensor_desc.device_type(); + AddAttrProto(node_proto, ge::onnx::AttributeProto_AttributeType_STRING, + prefix + attr_name + "_desc_device_type" + suffix, &device_type); + if (kDumpLevel == DUMP_ALL) { + auto data = tensor_def.data(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, + prefix + attr_name + "_data" + suffix, &data); + } + } + if (attr_type == ge::proto::AttrDef::kS) { + if (kDumpLevel == DUMP_ALL) { + auto str_value = attr_def.s(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value); + } + } + if (attr_type == ge::proto::AttrDef::kI) { + auto int_value = attr_def.i(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); + } + if (attr_type == ge::proto::AttrDef::kF) { + auto float_value = attr_def.f(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value); + } + if (attr_type == ge::proto::AttrDef::kB) { + auto int_value = static_cast(attr_def.b()); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); + } + if (attr_type == ge::proto::AttrDef::kList) { + const auto &list_value = attr_def.list(); + auto list_value_type = list_value.val_type(); + if (list_value_type == + ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { + if (kDumpLevel == DUMP_ALL) { + const auto &strings = list_value.s(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings); + } + } + if (list_value_type == + ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { + const auto &floats = list_value.f(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats); + } + if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { + const auto &ints = list_value.i(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints); + } + if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { + const auto &bools = list_value.b(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools); + } + } + } +} + +void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "node is nullptr"); + return; + } + // 1.Attributes added from node's methods + auto send_list = node->send_event_id_list_; + if (!send_list.empty()) { + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "send_event_id_list", &send_list); + } + auto recv_list = node->recv_event_id_list_; + if (!recv_list.empty()) { + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "recv_event_id_list", &recv_list); + } + auto op_desc = node->op_; + if (op_desc != nullptr) { + // for input_name_idx_ in opdesc + auto input_name_2_indexs = op_desc->GetAllInputName(); + ::google::protobuf::RepeatedPtrField<::std::string> input_names; + ::google::protobuf::RepeatedField<::google::protobuf::int64> input_indexes; + for (const auto &input_name_2_index: input_name_2_indexs) { + std::string input_name = input_name_2_index.first; + input_names.Add(std::move(input_name)); + input_indexes.Add(input_name_2_index.second); + } + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "_input_name_key", input_names); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "_input_name_value", input_indexes); + // 2.Attributes added from node's op_(message OpDef) + // Input and out describes + AddAttrProtoForOpInAndOutDesc(node_proto, op_desc); + // Others + auto op_def = op_desc->op_def_.GetProtoMsg(); + if (op_def != nullptr) { + auto id = op_def->id(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "id", &id); + auto stream_id = op_def->stream_id(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "stream_id", &stream_id); + const auto &input_name = op_def->input_name(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "input_name", input_name); + const auto &src_name = op_def->src_name(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "src_name", src_name); + const auto &src_index = op_def->src_index(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "src_index", src_index); + const auto &dst_name = op_def->dst_name(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "dst_name", dst_name); + const auto &dst_index = op_def->dst_index(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "dst_index", dst_index); + const auto &input_i = op_def->input_i(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_i", input_i); + const auto &output_i = op_def->output_i(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_i", output_i); + const auto &workspace = op_def->workspace(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace", workspace); + const auto &workspace_bytes = op_def->workspace_bytes(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes); + const auto &is_input_const = op_def->is_input_const(); + AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const); + const auto &op_def_attr_map = op_def->attr(); + AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto); + } else { + GELOGE(FAILED, "Opdef is nullptr"); + return; + } + } else { + GELOGE(FAILED, "Opdesc is nullptr"); + return; + } +} + +bool OnnxUtils::EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto) { + if ((node == nullptr) || (node_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeOpDesc: Input Para Node Invalid"); + return false; + } + + // 2.Encode map attrs_ to AttributeProto + for (auto &node_attr : node->attrs_) { + AddAttrProtoFromAttribute(node_attr, node_proto); + } + // 3.Encode ge::Node members to AttributeProto + AddAttrProtoFromNodeMembers(node, node_proto); + + // 4. Sort node attributes by name. + std::sort(node_proto->mutable_attribute()->begin(), node_proto->mutable_attribute()->end(), AttrNameComp()); + return true; +} + +void OnnxUtils::EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto) { + if ((node == nullptr) || (node_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeNodeLinkForNetronVisual: Input Para Node Invalid"); + return; + } + const auto &node_name = node->GetName(); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + if ((out_data_anchor != nullptr) && (!out_data_anchor->GetPeerInDataAnchors().empty())) { + node_proto->add_output(node_name + ":" + std::to_string(out_data_anchor->GetIdx())); + } + } + auto out_control_anchor = node->GetOutControlAnchor(); + if ((out_control_anchor != nullptr) && (!out_control_anchor->GetPeerInControlAnchors().empty())) { + node_proto->add_output(node_name + kControlAnchorIndex); + } +} + +bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) { + if ((node == nullptr) || (node_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeNodeLink: Input Para Node Invalid"); + return false; + } + node_proto->clear_input(); + // 1. Add input by in data edge + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { + node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + + std::to_string(peer_out_anchor->GetIdx())); + } else { + // Add "" input + node_proto->add_input(""); + } + } + + // 2. Add input by in control edge + auto in_control_anchor = node->GetInControlAnchor(); + if (in_control_anchor != nullptr) { + auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors(); + for (const auto &peer_out_anchor : peer_out_anchors) { + if (peer_out_anchor->GetOwnerNode()) { + node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); + } + } + } else { + GELOGE(FAILED, "Incontrol anchor is nullptr"); + return false; + } + + // 3. Add output for Netron visual support + EncodeNodeLinkForNetronVisual(node, node_proto); + return true; +} + +bool OnnxUtils::EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto) { + if ((node == nullptr) || (node_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeNode: Input Para Node Invalid"); + return false; + } + // 1. Encode name and type + node_proto->set_name(node->GetName()); + /// Netron believes that some operators, such as the activation operator of softplus, only have one input, + /// while the link relation of control anchor may exist in ge, resulting in two inputs. Therefore, "ge:" prefix + /// is added to correctly display the link relation at the expense of some color features + node_proto->set_op_type("ge:" + node->GetType()); + + if (kDumpLevel != DUMP_WITH_OUT_DESC) { + // 2.for attr + if (!EncodeNodeDesc(node, node_proto)) { + GELOGE(GRAPH_FAILED, "Encode NodeDesc: %s failed", node->GetName().c_str()); + return false; + } + } + // 3.for link info + return EncodeNodeLink(node, node_proto); +} + +void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type) { + if ((node == nullptr) || (tensor_type == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeTypeProtoTensorType: Input Para Node or tensor_type Invalid"); + return; + } + const auto &op_desc = node->GetOpDesc(); + if (op_desc != nullptr) { + uint32_t size_out = static_cast(op_desc->GetOutputsSize()); + if (size_out > 0) { + for (uint32_t i = 0; i < size_out; i++) { + const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); + if (ge_tensor != nullptr) { + auto ge_data_type = ge_tensor->GetDataType(); + auto onnx_data_type = EncodeDataType(ge_data_type); + tensor_type->set_elem_type(onnx_data_type); + onnx::TensorShapeProto *shape = tensor_type->mutable_shape(); + if (shape != nullptr) { + for (auto d : ge_tensor->GetShape().GetDims()) { + auto dim = shape->add_dim(); + dim->set_dim_value(d); + } + } else { + GELOGW("Shape is nullptr"); + continue; + } + } else { + GELOGW("Ge tensor is nullptr"); + continue; + } + } + } + } else { + GELOGW("OpDesc Is Empty, nodeName %s nodeType %s", node->GetName().c_str(), node->GetType().c_str()); + return; + } +} + +void OnnxUtils::EncodeValueInfo(const NodePtr &node, onnx::ValueInfoProto *value_info_proto) { + if ((node == nullptr) || (value_info_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeValueInfo: Input Para Node or value_info_proto Invalid"); + return; + } + value_info_proto->set_name(node->GetName()); + onnx::TypeProto *t = value_info_proto->mutable_type(); + onnx::TypeProto_Tensor *tensor_type = t->mutable_tensor_type(); + EncodeTypeProtoTensorType(node, tensor_type); +} + +bool OnnxUtils::EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto) { + if ((graph == nullptr) || (graph_proto == nullptr)) { + GELOGE(GRAPH_FAILED, "EncodeGraph: Input para Invalid"); + return false; + } + graph_proto->set_name(graph->GetName()); + // 1. Add graph inputs + for (const auto &input : graph->GetInputNodes()) { + auto value_info_proto = graph_proto->add_input(); + EncodeValueInfo(input, value_info_proto); + } + // 2. Add graph outputs + for (const auto &output : graph->GetOutputNodes()) { + auto value_info_proto = graph_proto->add_output(); + EncodeValueInfo(output, value_info_proto); + } + // 3. Add nodes + for (const auto &node : graph->GetDirectNode()) { + if (!EncodeNode(node, graph_proto->add_node())) { + GELOGW("EncodeNode failed"); + continue; + } + } + return true; +} + +bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto) { + model_proto.set_model_version(model.GetVersion()); + model_proto.set_ir_version(onnx::IR_VERSION); + model_proto.set_producer_name(model.GetName()); + auto &graph = model.graph_; + auto compute_graph = GraphUtils::GetComputeGraph(graph); + if (compute_graph == nullptr) { + GELOGE(GRAPH_FAILED, "GetComputeGraph: return nullptr"); + return false; + } + auto graph_proto = model_proto.mutable_graph(); + if (graph_proto == nullptr) { + GELOGE(GRAPH_FAILED, "mutable_graph: %s return nullptr", compute_graph->GetName().c_str()); + return false; + } + if (!EncodeGraph(compute_graph, graph_proto)) { + GELOGE(GRAPH_FAILED, "EncodeGraph: %s fail", compute_graph->GetName().c_str()); + return false; + } + + // For subgraphs: a subgraph is represented by a node + for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) { + if (sub_compute_graph != nullptr) { + auto node_proto = graph_proto->add_node(); + if (node_proto == nullptr) { + GELOGW("Node proto is nullptr"); + continue; + } + node_proto->set_name(sub_compute_graph->GetName()); + node_proto->set_op_type(kNodeTypeForSubgraph); + auto attr = node_proto->add_attribute(); + attr->set_name("graph"); + attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); + auto sub_graph_proto = attr->mutable_g(); + if (sub_graph_proto == nullptr) { + GELOGW("Sub graph proto is nullptr"); + continue; + } + if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { + GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); + continue; + } + } else { + GELOGW("Graph: %s subgraph is nullptr, skip EncodeGraph", compute_graph->GetName().c_str()); + continue; + } + } + return true; +} + +// Part 2: from ONNX Protobuf convert to IR +static std::map onnxDataTypeToGeMap = { + {onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64}, + {onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32}, + {onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8}, + {onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16}, + {onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16}, + {onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL}, +}; + +ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) { + auto it = onnxDataTypeToGeMap.find(data_type); + if (it != onnxDataTypeToGeMap.end()) { + return it->second; + } else { + GELOGW("DecodeDataType: datatype not support %u", data_type); + return ge::DT_UNDEFINED; + } +} + +bool OnnxUtils::ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index) { + auto sep = node_name_index.rfind(':'); + if (sep == std::string::npos) { + return false; + } + node_name = node_name_index.substr(0, sep); + auto index_str = node_name_index.substr(sep + 1); + index = static_cast(std::strtol(index_str.c_str(), nullptr, 10)); + return true; +} + +bool OnnxUtils::DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr) { + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp: node_ptr is nullptr"); + return false; + } + // Data edge + if (item.src_out_index >= 0) { + auto src_anchor = node_ptr->GetOutDataAnchor(item.src_out_index); + auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); + if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { + GELOGE(GRAPH_FAILED, "Get data anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, + item.dst_node_name.c_str(), item.dst_in_index); + return false; + } + if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Data Anchor: src_anchor->LinkTo(dst_anchor) failed"); + return false; + } + // Control edge + } else { + auto src_anchor = node_ptr->GetOutControlAnchor(); + auto dst_anchor = item.dst_node->GetInControlAnchor(); + if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { + GELOGE(GRAPH_FAILED, "Get control anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, + item.dst_node_name.c_str(), item.dst_in_index); + return false; + } + if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Control Anchor: src_anchor->LinkTo(dst_anchor) failed"); + return false; + } + } + return true; +} + +bool OnnxUtils::DecodeNodeLink(const std::vector &node_proto_vector, + const std::map &node_map) { + for (const auto &node_proto : node_proto_vector) { + const auto &node_name = node_proto.name(); + auto dst_node = node_map.find(node_name); + if ((dst_node == node_map.end()) || (dst_node->second == nullptr)) { + GELOGE(GRAPH_FAILED, "destination node: %s find failed or is nullptr", node_name.c_str()); + return false; + } + int32_t dst_index = 0; + for (const auto &input : node_proto.input()) { + std::string input_node_name; + int32_t index = 0; + if (ParseNameIndex(input, input_node_name, index)) { + auto item = NodeLinkInfo{input_node_name, index, dst_node->second, dst_index, node_proto.name()}; + auto src_node = node_map.find(input_node_name); + if (src_node == node_map.end()) { + GELOGE(GRAPH_FAILED, "find src node: %s failed", input_node_name.c_str()); + return false; + } + auto node_ptr = src_node->second; + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "src node: %s is nullptr", input_node_name.c_str()); + return false; + } + if (!DecodeNodeLinkImp(item, node_ptr)) { + GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp node: %s failed", input_node_name.c_str()); + return false; + } + } + if (index >= 0) { + dst_index++; + } + } + } + return true; +} + +void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::vector &strings) { + if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_STRINGS) { + GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); + return; + } + for (int i = 0; i < attr_proto.strings_size(); i++) { + strings.push_back(attr_proto.strings(i)); + } +} + +void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::string &value) { + if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_STRING) { + GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); + return; + } + value = attr_proto.s(); +} + +void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::vector &ints) { + if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_INTS) { + GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); + return; + } + for (int i = 0; i < attr_proto.ints_size(); i++) { + ints.push_back(attr_proto.ints(i)); + } +} + +void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, int64_t &value) { + if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_INT) { + GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); + return; + } + value = attr_proto.i(); +} + +void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_input_desc, int32_t index, + OpDescPtr &op_desc) { + if (op_desc->MutableInputDesc(static_cast(index)) == nullptr) { + GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast(index)) is nullptr", + op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); + return; + } + if (attr_name_for_input_desc == "input_desc_dtype") { + auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); + op_desc->MutableInputDesc(static_cast(index))->SetDataType(data_type); + } else if (attr_name_for_input_desc == "input_desc_shape") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + GeShape ge_shape(ints); + op_desc->MutableInputDesc(static_cast(index))->SetShape(ge_shape); + } else if (attr_name_for_input_desc == "input_desc_layout") { + auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); + op_desc->MutableInputDesc(static_cast(index))->SetFormat(data_format); + } else if (attr_name_for_input_desc == "input_desc_origin_shape") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + GeShape ge_shape(ints); + op_desc->MutableInputDesc(static_cast(index))->SetOriginShape(ge_shape); + } else if (attr_name_for_input_desc == "input_desc_origin_layout") { + auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); + op_desc->MutableInputDesc(static_cast(index))->SetOriginFormat(data_format); + } else if (attr_name_for_input_desc == "input_desc_size") { + int64_t input_size = 0; + auto tensor_descriptor = op_desc->MutableInputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); + DecodeAttribute(attr_proto, input_size); + tensor_descriptor->set_size(input_size); + } else if (attr_name_for_input_desc == "input_desc_data_offset") { + auto tensor_descriptor = op_desc->MutableInputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); + int64_t offset = 0; + DecodeAttribute(attr_proto, offset); + tensor_descriptor->set_data_offset(offset); + } else { + return; + } +} + +void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_output_desc, int32_t index, + OpDescPtr &op_desc) { + if (op_desc->MutableOutputDesc(static_cast(index)) == nullptr) { + GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast(index)) is nullptr", + op_desc->GetName().c_str(), attr_name_for_output_desc.c_str()); + return; + } + if (attr_name_for_output_desc == "output_desc_dtype") { + auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); + op_desc->MutableOutputDesc(static_cast(index))->SetDataType(data_type); + } else if (attr_name_for_output_desc == "output_desc_shape") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + GeShape ge_shape(ints); + op_desc->MutableOutputDesc(static_cast(index))->SetShape(ge_shape); + } else if (attr_name_for_output_desc == "output_desc_layout") { + auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); + op_desc->MutableOutputDesc(static_cast(index))->SetFormat(data_format); + } else if (attr_name_for_output_desc == "output_desc_origin_shape") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + GeShape ge_shape(ints); + op_desc->MutableOutputDesc(static_cast(index))->SetOriginShape(ge_shape); + } else if (attr_name_for_output_desc == "output_desc_origin_layout") { + auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); + op_desc->MutableOutputDesc(static_cast(index))->SetOriginFormat(data_format); + } else if (attr_name_for_output_desc == "output_desc_size") { + int64_t output_size = 0; + auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); + DecodeAttribute(attr_proto, output_size); + tensor_descriptor->set_size(output_size); + } else if (attr_name_for_output_desc == "output_desc_data_offset") { + auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); + int64_t offset = 0; + DecodeAttribute(attr_proto, offset); + tensor_descriptor->set_data_offset(offset); + } else { + return; + } +} + +void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_input_output_desc, int32_t index, + OpDescPtr &op_desc) { + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "op_desc is nullptr"); + return; + } + if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") { + DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); + } else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") { + DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); + } else { + return; + } +} + +void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) { + auto attr_map = op_def.mutable_attr(); + const auto &attr_name = attr_proto.name(); + ge::proto::AttrDef op_attr; + int64_t value = 0; + DecodeAttribute(attr_proto, value); + op_attr.set_i(value); + attr_map->insert(AttrDefPair(attr_name, op_attr)); +} + +void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); + return; + } + const auto &attr_name = attr_proto.name(); + std::string attr_name_for_input_output_desc; + int32_t index = 0; + if (!ParseNameIndex(attr_name, attr_name_for_input_output_desc, index)) { + if (attr_name == "id") { + op_desc->SetId(attr_proto.i()); + } else if (attr_name == "stream_id") { + op_desc->SetStreamId(attr_proto.i()); + } else if (attr_name == "src_name") { + std::vector strings; + DecodeAttribute(attr_proto, strings); + op_desc->SetSrcName(strings); + } else if (attr_name == "dst_name") { + std::vector strings; + DecodeAttribute(attr_proto, strings); + op_desc->SetDstName(strings); + } else if (attr_name == "src_index") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + op_desc->SetSrcIndex(ints); + } else if (attr_name == "dst_index") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + op_desc->SetDstIndex(ints); + } else if (attr_name == "fusion_scope") { + DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg()); + } else if (attr_name == "input_i") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + op_desc->SetInputOffset(ints); + } else if (attr_name == "output_i") { + std::vector ints; + DecodeAttribute(attr_proto, ints); + op_desc->SetOutputOffset(ints); + } else { + return; + } + // Update input and output desc + } else { + DecodeNodeAttributeForOpInAndOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); + } +} + +bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_desc) { + if (op_desc == nullptr || node_proto == nullptr) { + GELOGE(GRAPH_FAILED, " Op_desc is nullptr or node_proto is nullptr"); + return false; + } + // 1. Decode node_proto name and type + op_desc->SetName(node_proto->name()); + const auto &node_type_with_ge_prefix = node_proto->op_type(); + auto sep = node_type_with_ge_prefix.find(':'); + if (sep == std::string::npos) { + return false; + } + auto node_type = node_type_with_ge_prefix.substr(sep + 1); + op_desc->SetType(node_type); + // 2. Add empty input and output desc + for (const auto &attr : node_proto->attribute()) { + if (attr.name() == "input_desc_nums") { + auto size_in = attr.i(); + for (int64_t i = 0; i < size_in; i++) { + GeTensorDesc ge_tensor_desc; + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed."); + } + } + if (attr.name() == "output_desc_nums") { + auto size_out = attr.i(); + for (int64_t i = 0; i < size_out; i++) { + GeTensorDesc ge_tensor_desc; + GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed."); + } + } + } + // 3.Decode node_proto attributes + for (int i = 0; i < node_proto->attribute_size(); i++) { + DecodeNodeAttributeForOpDesc(node_proto->attribute(i), op_desc); + } + return true; +} + +bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph) { + if (recursion_depth > kMaxRecursionDepth) { + GELOGE(GRAPH_FAILED, "DecodeGraph: recursion depth is too large, abort"); + return false; + } + + graph = ComGraphMakeShared(graph_proto.name()); + GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed"); + /// 1. Decode all nodes first, node should include input + /// and output nodes and nodes which represent sub graphs + std::map node_map; + std::vector node_proto_vector; + for (const auto &node_proto : graph_proto.node()) { + // a. nodes represent sub graphs + if (node_proto.op_type() == kNodeTypeForSubgraph) { + ComputeGraphPtr compute_graph; + // in this case, node only have one attr, whose type is AttributeProto_AttributeType_GRAPH + const auto &node_attr = node_proto.attribute(0); + if ((node_attr.type() == onnx::AttributeProto_AttributeType_GRAPH) && + DecodeGraph(recursion_depth + 1, node_attr.g(), compute_graph)) { + (void)graph->AddSubGraph(compute_graph); + } else { + GELOGE(GRAPH_FAILED, "Decode sub graph %s failed with node type:%d", node_proto.name().c_str(), + node_attr.type()); + return false; + } + // b. direct nodes in graph + } else { + node_proto_vector.push_back(node_proto); + OpDescPtr op_desc = ComGraphMakeShared(); + // b.1 For node desc + if (!DecodeNodeDesc(&node_proto, op_desc)) { + GELOGE(GRAPH_FAILED, "Decode node desc %s failed ", node_proto.name().c_str()); + return false; + } + auto node = graph->AddNode(op_desc); + node_map.insert(std::make_pair(node_proto.name(), node)); + } + } + /// We get all nodes in graph here + /// b.2 For node link + if (!DecodeNodeLink(node_proto_vector, node_map)) { + GELOGE(GRAPH_FAILED, "Decode node link failed"); + return false; + } + + // 2. Add inputs nodes for graph + for (const auto &input : graph_proto.input()) { + const auto &input_node_name = input.name(); + auto input_node_item = node_map.find(input_node_name); + if (input_node_item == node_map.end()) { + GELOGE(GRAPH_FAILED, "cannot find graph's input node %s in node_", input_node_name.c_str()); + return false; + } + auto ret = graph->AddInputNode(input_node_item->second); + GE_CHK_BOOL_EXEC(ret != nullptr, continue, "Add inputnode failed"); + } + // 3. Add outputs nodes for graph + for (const auto &output : graph_proto.output()) { + const auto &output_node_name = output.name(); + auto output_node_item = node_map.find(output_node_name); + if (output_node_item == node_map.end()) { + GELOGE(GRAPH_FAILED, "cannot find graph's output node %s in node_", output_node_name.c_str()); + return false; + } + auto ret = graph->AddOutputNode(output_node_item->second); + if (ret == nullptr) { + GELOGW("Add outputnode failed,out put node is %s", output_node_name.c_str()); + continue; + } + } + return true; +} + +bool OnnxUtils::ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model) { + model.name_ = model_proto.producer_name(); + model.version_ = static_cast(model_proto.model_version()); + + auto &graph_proto = model_proto.graph(); + ComputeGraphPtr compute_graph; + // 0 means recursion depth, father call + if (!DecodeGraph(0, graph_proto, compute_graph)) { + GELOGE(GRAPH_FAILED, "Decode compute graph from graph_proto failed"); + return false; + } + model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + return true; +} +} // namespace ge diff --git a/metadef/graph/utils/ge_ir_utils.h b/metadef/graph/utils/ge_ir_utils.h new file mode 100644 index 00000000..4c54f171 --- /dev/null +++ b/metadef/graph/utils/ge_ir_utils.h @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ +#define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "proto/ge_ir.pb.h" +#include "proto_inner/ge_onnx.pb.h" + +namespace ge { +const int kOffsetToString = 2; + +/// +/// @ingroup ge_ir_utils +/// @brief RepeatedField->String +/// @param [in] const rpd_field RepeatedField +/// @return String +/// +template +const std::string ToString(const google::protobuf::RepeatedField &rpd_field) { + std::stringstream ss; + ss << "["; + for (const T &x : rpd_field) { + ss << x; + ss << ", "; + } + std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); + str_ret += "]"; + return str_ret; +} + +/// +/// @ingroup ge_ir_utils +/// @brief RepeatedPtrField->String +/// @param [in] const rpd_field RepeatedPtrField +/// @return String +/// +template +const std::string ToString(const google::protobuf::RepeatedPtrField &rpd_ptr_field) { + std::stringstream ss; + ss << "["; + for (const T &x : rpd_ptr_field) { + ss << x; + ss << ", "; + } + std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); + str_ret += "]"; + return str_ret; +} + +/// +/// @ingroup ge_ir_utils +/// @brief check, if not equal, log with tag +/// @param [in] const left_value, right_value reference, log_info_tag +/// @return bool +/// +template +bool IsEqual(const T &l_value, const T &r_value, const std::string &log_info_tag) { + if (l_value == r_value) { + return true; + } else { + GELOGE(GRAPH_FAILED, "Check failed with %s", log_info_tag.c_str()); + return false; + } +} + +class OnnxUtils { + public: + enum DumpLevel { NO_DUMP = 0, DUMP_ALL = 1, DUMP_WITH_OUT_DATA = 2, DUMP_WITH_OUT_DESC = 3, DUMP_LEVEL_END }; + + static bool ConvertGeModelToModelProto(const ge::Model &model, ge::onnx::ModelProto &model_proto); + + static bool ConvertModelProtoToGeModel(const ge::onnx::ModelProto &model_proto, ge::Model &model); + + private: + // Part 1: from IR convert to ONNX Protobuf + static void AddAttrProto(ge::onnx::NodeProto *node_proto, ge::onnx::AttributeProto_AttributeType type, + const std::string &name, void *data); + + static void AddAttrProto(ge::onnx::NodeProto *node_proto, ge::onnx::AttributeProto_AttributeType type, + const std::string &name, ::google::protobuf::RepeatedField<::google::protobuf::int64> data); + + static void AddAttrProto(ge::onnx::NodeProto *node_proto, ge::onnx::AttributeProto_AttributeType type, + const std::string &name, ::google::protobuf::RepeatedField data); + + static void AddAttrProto(ge::onnx::NodeProto *node_proto, ge::onnx::AttributeProto_AttributeType type, + const std::string &name, ::google::protobuf::RepeatedField data); + + static void AddAttrProto(ge::onnx::NodeProto *node_proto, ge::onnx::AttributeProto_AttributeType type, + const std::string &name, ::google::protobuf::RepeatedPtrField<::std::string> data); + + static void AddAttrProtoFromNodeMembers(const NodePtr &node, ge::onnx::NodeProto *node_proto); + + static void AddAttrProtoFromAttribute(const std::pair &string_attr_value, + ge::onnx::NodeProto *node_proto); + + static void AddAttrProtoForOpInAndOutDesc(ge::onnx::NodeProto *node_proto, const OpDescPtr &op_desc); + + static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map &attr_map, + ge::onnx::NodeProto *node_proto, + const std::string& prefix = "", + const std::string& suffix = ""); + + static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, ge::onnx::NodeProto *node_proto); + + static ge::onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); + + static void EncodeNodeLinkForNetronVisual(const NodePtr &node, ge::onnx::NodeProto *node_proto); + + static bool EncodeNodeLink(const NodePtr &node, ge::onnx::NodeProto *node_proto); + + static bool EncodeNodeDesc(const NodePtr &node, ge::onnx::NodeProto *node_proto); + + static bool EncodeNode(const NodePtr &node, ge::onnx::NodeProto *node_proto); + + static void EncodeTypeProtoTensorType(const NodePtr &node, ge::onnx::TypeProto_Tensor *tensor_type); + + static void EncodeValueInfo(const NodePtr &n, ge::onnx::ValueInfoProto *v); + + static bool EncodeGraph(const ConstComputeGraphPtr &graph, ge::onnx::GraphProto *graph_proto); + + /// Part 2: from ONNX Protobuf convert to IR + /// Describes node's link relationships + struct NodeLinkInfo { + std::string src_node_name; + int32_t src_out_index; + NodePtr dst_node; + int32_t dst_in_index; + std::string dst_node_name; + }; + + // Parse node name and index + static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index); + + static ge::DataType DecodeDataType(ge::onnx::TensorProto_DataType data_type); + + static void DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::vector &strings); + + static void DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::vector &ints); + + static void DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, int64_t &value); + + static void DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::string &value); + + static void DecodeNodeAttributeForOpOutDesc(const ge::onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_output_desc, int32_t index, + OpDescPtr &op_desc); + + static void DecodeNodeAttributeForOpInDesc(const ge::onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_input_desc, int32_t index, + OpDescPtr &op_desc); + + static void DecodeNodeAttributeForOpInAndOutDesc(const ge::onnx::AttributeProto &attr_proto, + const std::string &attr_name_for_input_output_desc, int32_t index, + OpDescPtr &op_desc); + + static void DecodeNodeAttributeForOpDef(const ge::onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def); + + static void DecodeNodeAttributeForOpDesc(const ge::onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); + + static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); + + static bool DecodeNodeLink(const std::vector &node_proto_vector, + const std::map &node_map); + + static bool DecodeNodeDesc(const ge::onnx::NodeProto *node_proto, OpDescPtr &node); + + static bool DecodeGraph(int recursion_depth, const ge::onnx::GraphProto &graph_proto, ComputeGraphPtr &graph); +}; +} // namespace ge + +#endif // COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ diff --git a/metadef/graph/utils/graph_utils.cc b/metadef/graph/utils/graph_utils.cc new file mode 100644 index 00000000..dad2d487 --- /dev/null +++ b/metadef/graph/utils/graph_utils.cc @@ -0,0 +1,3029 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/graph_utils.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "./ge_context.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "proto/ge_ir.pb.h" +#include "utils/attr_utils.h" +#include "utils/ge_ir_utils.h" +#include "utils/node_utils.h" +#include "debug/ge_op_types.h" +#include "external/ge/ge_api_types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "mmpa/mmpa_api.h" + +using google::protobuf::io::FileOutputStream; + +namespace ge { +enum DumpGraphLevel { + kDumpLevel1 = 1, + kDumpLevel2 = 2, + kDumpLevel3 = 3, + kDumpLevelOther, +}; + +namespace{ +const int32_t kBaseOfIntegerValue = 10; +#ifdef FMK_SUPPORT_DUMP +const char *const kDumpGeGraph = "DUMP_GE_GRAPH"; +const int kDumpGraphIndexWidth = 8; +#endif + +const char *const kDumpGraphPath = "DUMP_GRAPH_PATH"; +const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; +const char *const kDumpStrBuild = "Build"; +const char *const kDumpStrPartition = "partition"; +const char *const kDumpStrOptimizeSubgraph = "OptimizeSubGraph"; +const char *const kDumpStrSubgraphFunc = "sub_graph"; +const char *const kDumpStrAicpu = "Aicpu"; +const int32_t kNameMax = 255; +}; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, + const InDataAnchorPtr &dst) { + if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Add edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const AnchorPtr &src, + const AnchorPtr &dst) { + OutDataAnchorPtr src_data = Anchor::DynamicAnchorCast(src); + InDataAnchorPtr dst_data = Anchor::DynamicAnchorCast(dst); + OutControlAnchorPtr src_control = Anchor::DynamicAnchorCast(src); + InControlAnchorPtr dst_control = Anchor::DynamicAnchorCast(dst); + if ((src_data != nullptr) && (dst_data != nullptr) && (src_data->LinkTo(dst_data) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + if ((src_data != nullptr) && (dst_control != nullptr) && (src_data->LinkTo(dst_control) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + if ((src_control != nullptr) && (dst_control != nullptr) && (src_control->LinkTo(dst_control) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + if ((src_control != nullptr) && (dst_data != nullptr) && (src_control->LinkTo(dst_data) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Add edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, + const Format &src_format, + const InDataAnchorPtr &dst, + const Format &dst_format) { + if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { + auto ret = AnchorUtils::SetFormat(src, src_format); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed, format is %d", static_cast(src_format)); + return ret; + } + ret = AnchorUtils::SetFormat(dst, dst_format); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(dst_format)); + return ret; + } + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Add edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutControlAnchorPtr &src, + const InControlAnchorPtr &dst) { + if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Add edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, + const InControlAnchorPtr &dst) { + if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Add edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutDataAnchorPtr &src, + const InDataAnchorPtr &dst) { + if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Remove edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const AnchorPtr &src, + const AnchorPtr &dst) { + if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Remove edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutControlAnchorPtr &src, + const InControlAnchorPtr &dst) { + if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Remove edge Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutDataAnchorPtr &src, + const InControlAnchorPtr &dst) { + if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Remove edge Failed."); + return GRAPH_FAILED; +} + +graphStatus GraphUtils::ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, + const InDataAnchorPtr &new_dst) { + if (RemoveEdge(src, dst) == GRAPH_SUCCESS && AddEdge(src, new_dst) == GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Replace edge dst Failed."); + return GRAPH_FAILED; +} + +graphStatus GraphUtils::ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, + const InControlAnchorPtr &new_dst) { + if (RemoveEdge(src, dst) == GRAPH_SUCCESS && AddEdge(src, new_dst) == GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Replace edge dst Failed."); + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNodeBetweenDataAnchors( + const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, const NodePtr &new_node) { + GE_CHECK_NOTNULL(src); + GE_CHECK_NOTNULL(dst); + GE_CHECK_NOTNULL(new_node); + + InDataAnchorPtr node_in_anchor = new_node->GetInDataAnchor(0); + GE_CHK_BOOL_RET_STATUS(node_in_anchor != nullptr, GRAPH_FAILED, "this node has not inDataAnchor"); + OutDataAnchorPtr node_out_anchor = new_node->GetOutDataAnchor(0); + GE_CHK_BOOL_RET_STATUS(node_out_anchor != nullptr, GRAPH_FAILED, "this node has not outDataAnchor"); + GE_CHK_STATUS_RET(src->ReplacePeer(dst, node_in_anchor, node_out_anchor), "ReplacePeer Failed"); + return GRAPH_SUCCESS; +} + + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, + const NodePtr &remove_node) { + GE_CHECK_NOTNULL(compute_graph); + if (remove_node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return GRAPH_FAILED; + } + + // Check if this node is belong to this compute graph, maybe a little slow + const auto &all_nodes_in_graph = compute_graph->GetDirectNode(); + if (std::find(all_nodes_in_graph.begin(), all_nodes_in_graph.end(), remove_node) == all_nodes_in_graph.end()) { + GELOGE(GRAPH_FAILED, "Can not find node %s in graph %s.", + remove_node->GetName().c_str(), + compute_graph->GetName().c_str()); + return GRAPH_FAILED; + } + // Find all subgraph of this node + const auto &root_graph = GraphUtils::FindRootGraph(compute_graph); + std::vector subgraphs; + std::vector all_nodes; + std::deque candidates; + NodePtr remove_node_new = remove_node; + candidates.emplace_back(remove_node_new); + while (!candidates.empty()) { + const NodePtr node = candidates.front(); + all_nodes.emplace_back(node); + candidates.pop_front(); + + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + + const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); + for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { + auto subgraph = root_graph->GetSubgraph(*name_iter); + if (subgraph != nullptr) { + subgraphs.emplace_back(subgraph); + candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); + } + } + } + // Remove all subgraph + for (const auto &remove_graph : subgraphs) { + if (root_graph->RemoveSubGraph(remove_graph) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Remove subgraph failed, sub graph name is %s, compute graph is %s.", + remove_node->GetName().c_str(), compute_graph->GetName().c_str()); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node) { + GE_CHECK_NOTNULL(compute_graph); + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); + return GRAPH_FAILED; + } + + // If the node save as input node, delete it + (void)compute_graph->RemoveInputNode(node); + + // If the node save as output node, delete it + (void)compute_graph->RemoveOutputNode(node); + + // If the node has sub-graphs, delete them + auto ret = RemoveSubgraphRecursively(compute_graph, node); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Remove subgraph recursively failed."); + return GRAPH_FAILED; + } + + auto iter = find(compute_graph->nodes_.begin(), compute_graph->nodes_.end(), node); + if (iter != compute_graph->nodes_.end()) { + compute_graph->nodes_.erase(iter); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +/// Add two edges to the new node, respectively connecting the SRC and DST +/// associated with the original edge +/// A ---> B transfered to A ---> N ---> B +graphStatus InsertTransNode(ComputeGraph &compute_graph, const InDataAnchorPtr &in_data_anchor, + const std::vector &vec_op_desc) { + GE_CHECK_NOTNULL(in_data_anchor); + for (const auto &op_desc : vec_op_desc) { + GE_CHECK_NOTNULL(op_desc); + + auto ret = op_desc->AddInputDesc(GeTensorDesc()); + GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, "Add input desc failed"); + ret = op_desc->AddOutputDesc(GeTensorDesc()); + GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, "Add input desc failed"); + auto node_to_insert = compute_graph.AddNode(op_desc); + + GE_CHECK_NOTNULL(node_to_insert); + GE_CHECK_NOTNULL(in_data_anchor->GetPeerOutAnchor()); + + auto src = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); + if (!src) { + GELOGE(GRAPH_FAILED, "src nullptr error."); + return GRAPH_FAILED; + } + + auto src_out_index = in_data_anchor->GetPeerOutAnchor()->GetIdx(); + + auto dst = in_data_anchor->GetOwnerNode(); + if (!dst) { + GELOGE(GRAPH_FAILED, "dst nullptr error."); + return GRAPH_FAILED; + } + + auto dst_in_index = in_data_anchor->GetIdx(); + + auto in_data_anchor_src_format = AnchorUtils::GetFormat(in_data_anchor->GetPeerOutAnchor()); + auto in_data_anchor_dst_format = AnchorUtils::GetFormat(in_data_anchor); + + GE_CHECK_NOTNULL(src->GetOutDataAnchor(src_out_index)); + GE_CHECK_NOTNULL(dst->GetInDataAnchor(dst_in_index)); + + ret = GraphUtils::RemoveEdge(src->GetOutDataAnchor(src_out_index), dst->GetInDataAnchor(dst_in_index)); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Remove edge failed"); + return GRAPH_FAILED; + } + + GE_CHECK_NOTNULL(node_to_insert->GetInDataAnchor(0)); + GE_CHECK_NOTNULL(node_to_insert->GetOutDataAnchor(0)); + + ret = GraphUtils::AddEdge(src->GetOutDataAnchor(src_out_index), node_to_insert->GetInDataAnchor(0)); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add edge failed"); + return ret; + } + ret = GraphUtils::AddEdge(node_to_insert->GetOutDataAnchor(0), dst->GetInDataAnchor(dst_in_index)); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add edge failed"); + return ret; + } + + if (op_desc->HasAttr("input_format")) { + int64_t input_format = 0; + int64_t output_format = 0; + if (!AttrUtils::GetInt(op_desc, "input_format", input_format)) { + GELOGW("get attr input_format failed"); + continue; + } + if (!AttrUtils::GetInt(op_desc, "output_format", output_format)) { + GELOGW("get attr output_format failed"); + continue; + } + + GE_CHECK_NOTNULL(node_to_insert->GetInDataAnchor(0)->GetPeerOutAnchor()); + GE_CHK_BOOL_RET_STATUS(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().empty(), GRAPH_FAILED, + "Vistor is empty"); + GE_CHECK_NOTNULL(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)); + + auto status = + AnchorUtils::SetFormat(node_to_insert->GetInDataAnchor(0)->GetPeerOutAnchor(), in_data_anchor_src_format); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(in_data_anchor_src_format)); + return status; + } + status = AnchorUtils::SetFormat(node_to_insert->GetInDataAnchor(0), static_cast(input_format)); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed,format is %ld", input_format); + return status; + } + status = AnchorUtils::SetFormat(node_to_insert->GetOutDataAnchor(0), static_cast(output_format)); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed,format is %ld", output_format); + return status; + } + status = AnchorUtils::SetFormat(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0), + in_data_anchor_dst_format); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(in_data_anchor_dst_format)); + return status; + } + } + std::vector original_nodes; + GraphUtils::RecordOriginalNames(original_nodes, node_to_insert); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTransNode( + ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, const std::vector &vec_op_desc) { + GE_CHECK_NOTNULL(compute_graph); + GE_CHECK_NOTNULL(in_data_anchor); + graphStatus ret = + ge::InsertTransNode(*compute_graph, in_data_anchor, vec_op_desc) == GRAPH_SUCCESS ? GRAPH_SUCCESS : GRAPH_FAILED; + return ret; +} + +/// +/// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst +/// @param [in] src +/// @param [in] dsts +/// @param [in] insert_node +/// @param [in] input_index +/// @param [in] output_index +/// @return graphStatus +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, + const std::vector &dsts, const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { + GE_CHECK_NOTNULL(src); + GE_CHECK_NOTNULL(insert_node); + + NodePtr src_node = src->GetOwnerNode(); + if (src_node->GetOwnerComputeGraph() != insert_node->GetOwnerComputeGraph()) { + GELOGE(GRAPH_FAILED, "src:%s and insert_node:%s not exist in the same graph.", + src_node->GetName().c_str(), insert_node->GetName().c_str()); + return GRAPH_FAILED; + } + + if (AddEdge(src, insert_node->GetInDataAnchor(input_index)) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "AddEdge %s->%s failed.", src_node->GetName().c_str(), insert_node->GetName().c_str()); + return GRAPH_FAILED; + } + + OutControlAnchorPtr src_out_ctrl_anchor = src_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(src_out_ctrl_anchor); + + bool ctrl_edge_flag = true; + std::string type = NodeUtils::GetNodeType(src->GetOwnerNode()); + if ((type == SWITCH) || (type == REFSWITCH) || (type == SWITCHN)) { + ctrl_edge_flag = false; + } + + for (auto &dst : dsts) { + GE_CHECK_NOTNULL(dst); + NodePtr dst_node = dst->GetOwnerNode(); + GELOGI("Insert node %s between %s->%s.", + insert_node->GetName().c_str(), src_node->GetName().c_str(), dst_node->GetName().c_str()); + if (src_node->GetOwnerComputeGraph() != dst_node->GetOwnerComputeGraph()) { + GELOGE(GRAPH_FAILED, "src:%s and dst:%s not exist in the same graph.", + src_node->GetName().c_str(), dst_node->GetName().c_str()); + return GRAPH_FAILED; + } + + (void)RemoveEdge(src, dst); + if (AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), + dst_node->GetName().c_str(), insert_node->GetName().c_str(), dst_node->GetName().c_str()); + return GRAPH_FAILED; + } + + if (!ctrl_edge_flag) { continue; } + for (const InControlAnchorPtr& peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { + if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || + (AddEdge(insert_node->GetOutControlAnchor(), peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { + GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", + src_node->GetName().c_str(), peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(), + insert_node->GetName().c_str(), peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); + return GRAPH_FAILED; + } + } + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveJustNode(ComputeGraph &compute_graph, + const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr should be not null."); + return GRAPH_FAILED; + } + auto iter = find(compute_graph.nodes_.begin(), compute_graph.nodes_.end(), node); + if (iter != compute_graph.nodes_.end()) { + compute_graph.nodes_.erase(iter); + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveJustNode(ComputeGraphPtr compute_graph, + const NodePtr &node) { + GE_CHECK_NOTNULL(compute_graph); + GE_CHECK_NOTNULL(node); + graphStatus ret = (RemoveJustNode(*compute_graph, node) == GRAPH_SUCCESS ? GRAPH_SUCCESS : GRAPH_FAILED); + return ret; +} + +void GraphUtils::RecordOriginalNames(std::vector original_nodes, const ge::NodePtr &node) { + GE_CHK_BOOL_EXEC(node != nullptr, return, "node is null."); + std::vector original_names; + for (const auto &node_tmp : original_nodes) { + std::vector names_tmp; + ge::OpDescPtr opdesc_tmp = node_tmp->GetOpDesc(); + if (opdesc_tmp == nullptr) { + GELOGE(GRAPH_FAILED, "Node %s get opdesc is nullptr", node_tmp->GetName().c_str()); + continue; + } + auto ret = ge::AttrUtils::GetListStr(opdesc_tmp, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, names_tmp); + if (!ret) { + GELOGW("Get list str failed"); + continue; + } + if (names_tmp.size() != 0) { + original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); + } else { + original_names.push_back(opdesc_tmp->GetName()); + } + } + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names), + return, "Set original_op_names fail."); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::RecordOriginalNames(std::vector names_tmp, + const ge::NodePtr &node) { + GE_CHK_BOOL_EXEC(node != nullptr, return, "node is null."); + std::vector original_names; + if (names_tmp.size() != 0) { + original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); + } else { + std::string tmp; + original_names.push_back(tmp); + } + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names), + return, "Set original_op_names fail."); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::MatchDumpStr(const std::string &suffix) { + char dump_level[MMPA_MAX_PATH ] = { 0x00 }; + INT32 res = mmGetEnv(kDumpGraphLevel, dump_level, MMPA_MAX_PATH); + int64_t dump_graph_level = (res == EN_OK) ? std::strtol(dump_level, nullptr, kBaseOfIntegerValue) : kDumpLevel2; + + if (dump_graph_level == kDumpLevel1) { + return false; + } + + if (dump_graph_level == kDumpLevel2 && ((suffix.find(kDumpStrPartition) != std::string::npos) || + (suffix.find(kDumpStrOptimizeSubgraph) != std::string::npos) || + (suffix.find(kDumpStrAicpu) != std::string::npos) || + (suffix.find(kDumpStrSubgraphFunc) != std::string::npos))) { + return true; + } + + if (dump_graph_level == kDumpLevel3 && suffix.compare(kDumpStrBuild) != 0) { + return true; + } + + return false; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(const ge::ComputeGraphPtr &graph, + const std::string &suffix, + bool is_always_dump, + const std::string &user_graph_name) { +#ifdef FMK_SUPPORT_DUMP + char dump_ge_graph[MMPA_MAX_PATH] = { 0x00 }; + INT32 res = mmGetEnv(kDumpGeGraph, dump_ge_graph, MMPA_MAX_PATH); + GE_IF_BOOL_EXEC(res != EN_OK && !is_always_dump, return;); + + // dump the graph according to different graph level + if (GraphUtils::MatchDumpStr(suffix)) { + return; + } + + // file name + static std::atomic_long atomic_file_index(0); + auto file_index = atomic_file_index.fetch_add(1); + GELOGD("Start to dump om txt: %ld", file_index); + + thread_local long max_dump_file_num = 0; + if (max_dump_file_num == 0) { + string opt = "0"; + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); + max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); + } + if (max_dump_file_num != 0 && file_index > max_dump_file_num) { + GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileCnt=%ld.", max_dump_file_num); + return; + } + + std::stringstream stream_file_name; + char *dump_graph_path = std::getenv(kDumpGraphPath); + if (dump_graph_path != nullptr) { + std::string dump_graph_path_str(dump_graph_path); + stream_file_name << (dump_graph_path_str.empty() ? "" : dump_graph_path_str + "/"); + } + stream_file_name << "ge_proto_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; + stream_file_name << "_" << suffix << ".txt"; + std::string proto_file = user_graph_name.empty() ? stream_file_name.str() : user_graph_name; + + // Create buffer + ge::Model model("", ""); + model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast(graph))); + Buffer buffer; + const int64_t kDumpLevel = + (dump_ge_graph != nullptr) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : ge::OnnxUtils::NO_DUMP; + model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL && !is_always_dump); + + // Write file + ge::proto::ModelDef ge_proto; + if (buffer.GetData() != nullptr) { + std::string str(reinterpret_cast(buffer.GetData()), buffer.GetSize()); + if (!ge_proto.ParseFromString(str)) { + GELOGE(GRAPH_FAILED, "parse from string failed."); + return; + } + char real_path[MMPA_MAX_PATH] = {0x00}; + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(proto_file.c_str()) >= MMPA_MAX_PATH, return, "file path is too longer!"); + GE_IF_BOOL_EXEC(mmRealPath(proto_file.c_str(), real_path, MMPA_MAX_PATH) != EN_OK, + GELOGI("file %s does not exist, it will be created.", proto_file.c_str())); + + GraphUtils::WriteProtoToTextFile(ge_proto, real_path); + } +#else + GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); +#endif +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGrph(const ge::ComputeGraphPtr &graph, + const std::string &path, + const std::string &suffix) { + // file name + static std::atomic_long atomic_file_index(0); + auto file_index = atomic_file_index.fetch_add(1); + GELOGD("Start to dump om txt: %ld", file_index); + + thread_local long max_dump_file_num = 0; + if (max_dump_file_num == 0) { + string opt = "0"; + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); + max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); + } + if (max_dump_file_num != 0 && file_index > max_dump_file_num) { + GELOGW("Dump graph file cnt > maxDumpFileNum, maxDumpFileCnt=%ld.", max_dump_file_num); + return; + } + + std::stringstream stream_file_name; + stream_file_name << path.c_str() << "/ge_proto_" << std::setw(5) << std::setfill('0') + << file_index; + stream_file_name << "_" << suffix << ".txt"; + std::string proto_file = stream_file_name.str(); + + // Create buffer + ge::Model model("", ""); + model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast(graph))); + Buffer buffer; + const int64_t kDumpLevel = ge::OnnxUtils::NO_DUMP; + model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL); + + // Write file + ge::proto::ModelDef ge_proto; + if (buffer.GetData() != nullptr) { + std::string str(reinterpret_cast(buffer.GetData()), buffer.GetSize()); + if (!ge_proto.ParseFromString(str)) { + GELOGE(GRAPH_FAILED, "parse from string failed."); + return; + } + char real_path[MMPA_MAX_PATH] = {0x00}; + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(proto_file.c_str()) >= MMPA_MAX_PATH, return, "file path is too longer!"); + GE_IF_BOOL_EXEC(mmRealPath(proto_file.c_str(), real_path, MMPA_MAX_PATH) != EN_OK, + GELOGI("file %s does not exist, it will be created.", proto_file.c_str())); + + GraphUtils::WriteProtoToTextFile(ge_proto, real_path); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, + ge::ComputeGraph &compute_graph) { + ge::proto::ModelDef model_def; + // Get ModelDef object from file generated by DumpGEGraph() + if (!ReadProtoFromTextFile(file, &model_def)) { + GELOGE(GRAPH_FAILED, "Get ModelDef failed from file"); + return false; + } + ge::Model model; + // Get Model object from ModelDef by deserialize ModelDef + if (model.Load(model_def) == GRAPH_SUCCESS) { + GE_CHK_BOOL_EXEC(GraphUtils::GetComputeGraph(model.GetGraph()) != nullptr, return false, + "Get computer graph is nullptr"); + compute_graph = *(GraphUtils::GetComputeGraph(model.GetGraph())); + return true; + } else { + GELOGE(GRAPH_FAILED, "Get Model failed from ModelDef"); + return false; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, + ge::ComputeGraphPtr &compute_graph) { + ge::proto::ModelDef model_def; + // Get ModelDef object from file generated by DumpGEGraph() + if (!ReadProtoFromTextFile(file, &model_def)) { + GELOGE(GRAPH_FAILED, "Get ModelDef failed from file"); + return false; + } + ge::Model model; + // Get Model object from ModelDef by deserialize ModelDef + if (model.Load(model_def) == GRAPH_SUCCESS) { + GE_CHK_BOOL_EXEC(GraphUtils::GetComputeGraph(model.GetGraph()) != nullptr, return false, + "Get computer graph is nullptr"); + compute_graph = GraphUtils::GetComputeGraph(model.GetGraph()); + for (const auto &node : compute_graph->GetDirectNode()) { + GELOGI("Node %s set owner graph", node->GetName().c_str()); + GE_CHECK_NOTNULL(node); + if (node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Node %s set owner graph failed", node->GetName().c_str()); + return false; + } + } + return true; + } else { + GELOGE(GRAPH_FAILED, "Get Model failed from ModelDef"); + return false; + } +} + +// Printing protocol messages in text format is useful for debugging and human editing of messages. +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToTextFile( + const google::protobuf::Message &proto, const char *real_path) { +#ifdef FMK_SUPPORT_DUMP + const int FILE_AUTHORITY = 0600; + int fd = mmOpen2(real_path, M_WRONLY | M_CREAT | O_TRUNC, FILE_AUTHORITY); + if (fd < 0) { + GELOGE(GRAPH_FAILED, "fail to open the file: %s, %s", real_path, strerror(errno)); + return; + } + google::protobuf::io::FileOutputStream *output = new (std::nothrow) FileOutputStream(fd); + if (output == nullptr) { + GELOGE(GRAPH_FAILED, "Output is nullptr"); + if (mmClose(fd) != 0) { + GELOGE(GRAPH_FAILED, "Close fileoutputstream failed"); + } + return; + } + bool ret = google::protobuf::TextFormat::Print(proto, output); + if (!ret) { + GELOGE(GRAPH_FAILED, "Fail to write the file: %s", real_path); + delete output; + output = nullptr; + GE_CHK_BOOL_EXEC(mmClose(fd) == 0, return, "Close fileoutputstream failed"); + return; + } + delete output; + output = nullptr; + GE_CHK_BOOL_EXEC(mmClose(fd) == 0, return, "Close fileoutputstream failed"); + + FILE *file = fopen(real_path, "rb"); + if (file == nullptr) { + return; + } + if (fseek(file, 0L, SEEK_END) == 0) { + long fileSize = ftell(file); + thread_local long max_dump_file_size = 0; + if (max_dump_file_size == 0) { + string opt = "0"; + // Can not check return value + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); + max_dump_file_size = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); + } + if (max_dump_file_size != 0 && fileSize != -1 && fileSize > max_dump_file_size) { + GELOGW("dump graph file size > maxDumpFileSize, maxDumpFileSize=%ld.", max_dump_file_size); + GE_IF_BOOL_EXEC(remove(real_path) != 0, GELOGW("remove %s failed", real_path)); + GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose %s failed", real_path); + return; + } + } + GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose fileoutputstream failed"); +#else + GELOGW("Need to define FMK_SUPPORT_DUMP for dump graph."); +#endif +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::ReadProtoFromTextFile( + const char *file, google::protobuf::Message *proto) { + if (file == nullptr || proto == nullptr) { + GELOGE(GRAPH_FAILED, "incorrect parameter. file path or message is invalid"); + return false; + } + std::ifstream fs(file, std::ifstream::in); + if (!fs.is_open()) { + GELOGE(GRAPH_FAILED, "proto file '%s' open fail.", file); + return false; + } + google::protobuf::io::IstreamInputStream input(&fs); + bool ret = google::protobuf::TextFormat::Parse(&input, proto); + if (!ret) { + GELOGE(GRAPH_FAILED, "parse proto from text ret fail, please check your text file '%s'.", file); + } + fs.close(); + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, + const std::string &suffix) { +#ifdef FMK_SUPPORT_DUMP + char dump_ge_graph[MMPA_MAX_PATH] = { 0x00 }; + INT32 res = mmGetEnv(kDumpGeGraph, dump_ge_graph, MMPA_MAX_PATH); + int64_t dump_ge_graph_level = + (res == EN_OK) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : OnnxUtils::NO_DUMP; + if ((dump_ge_graph_level == OnnxUtils::NO_DUMP) || (dump_ge_graph_level >= OnnxUtils::DUMP_LEVEL_END)) { + GELOGD("Skip DumpGEGraphToOnnx with dump_ge_graph_level %ld.", dump_ge_graph_level); + return; + } + + // dump the graph according to different graph level + if (GraphUtils::MatchDumpStr(suffix)) { + return; + } + + // 1.Get ge::onnx::ModelProto from ge::Model + ge::Model model("GE", ""); + std::shared_ptr compute_graph_ptr = ComGraphMakeShared(compute_graph); + model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast(compute_graph_ptr))); + onnx::ModelProto model_proto; + if (!OnnxUtils::ConvertGeModelToModelProto(model, model_proto)) { + GELOGE(GRAPH_FAILED, "DumpGEGraphToOnnx failed."); + return; + } + + // 2.Set file name + static std::atomic_long atomic_file_index(0); + auto file_index = atomic_file_index.fetch_add(1); + GELOGD("Start to dump ge onnx file: %ld", file_index); + + thread_local long max_dump_file_num = 0; + if (max_dump_file_num == 0) { + string opt = "0"; + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); + max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); + } + if (max_dump_file_num != 0 && file_index > max_dump_file_num) { + GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileNum=%ld.", max_dump_file_num); + return; + } + + std::stringstream stream_file_name; + char *dump_graph_path = std::getenv(kDumpGraphPath); + if (dump_graph_path != nullptr) { + std::string dump_graph_path_str(dump_graph_path); + stream_file_name << (dump_graph_path_str.empty() ? "" : dump_graph_path_str + "/"); + } + stream_file_name << "ge_onnx_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; + stream_file_name << "_graph_" << compute_graph.GetGraphID(); + stream_file_name << "_" << suffix << ".pbtxt"; + std::string proto_file = stream_file_name.str(); + if ((proto_file.length()) >= kNameMax) { + GELOGE(GRAPH_FAILED, "File name is too longer!"); + return; + } + std::unique_ptr real_path(new (std::nothrow) char[MMPA_MAX_PATH]{0}); + if (real_path == nullptr) { + GELOGE(GRAPH_FAILED, "New real_path failed."); + return; + } + /// Returning nullptr means 3 case as follows: + /// a.path is MMPA_MAX_PATH chars or more + /// b.the file does not exist + /// c.the path has no permissions + /// Distinguish between last the two cases in the function WriteProtoToTextFile call open() + if (mmRealPath(proto_file.c_str(), real_path.get(), MMPA_MAX_PATH) != EN_OK) { + // For case a + int err_num = errno; + // linux: ENAMETOOLONG windows: ERANGE + if (err_num == ENAMETOOLONG || err_num == ERANGE) { + GELOGE(GRAPH_FAILED, "Call realpath failed: path is MMPA_MAX_PATH chars or more."); + return; + } + } + + // 3. Serialize to file in current path + GraphUtils::WriteProtoToTextFile(model_proto, real_path.get()); +#else + GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); +#endif +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGrphToOnnx(const ge::ComputeGraph &compute_graph, + const std::string &path, + const std::string &suffix) { + // 1.Get ge::onnx::ModelProto from ge::Model + ge::Model model("GE", ""); + std::shared_ptr compute_graph_ptr = ComGraphMakeShared(compute_graph); + model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast(compute_graph_ptr))); + onnx::ModelProto model_proto; + if (!OnnxUtils::ConvertGeModelToModelProto(model, model_proto)) { + GELOGE(GRAPH_FAILED, "DumpGEGraphToOnnx failed."); + return; + } + + // 2.Set file name + static std::atomic_long atomic_file_index(0); + auto file_index = atomic_file_index.fetch_add(1); + GELOGD("Start to dump ge onnx file: %ld", file_index); + + thread_local long max_dump_file_num = 0; + if (max_dump_file_num == 0) { + string opt = "0"; + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); + max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); + } + if (max_dump_file_num != 0 && file_index > max_dump_file_num) { + GELOGW("Dump graph file cnt > maxDumpFileNum, maxDumpFileNum=%ld.", max_dump_file_num); + return; + } + + std::stringstream stream_file_name; + stream_file_name << path.c_str() << "/ge_onnx_" << std::setw(5) << std::setfill('0') << file_index; + stream_file_name << "_graph_" << compute_graph.GetGraphID(); + stream_file_name << "_" << suffix << ".pbtxt"; + std::string proto_file = stream_file_name.str(); + if ((proto_file.length()) >= kNameMax) { + GELOGE(GRAPH_FAILED, "File name is too longer!"); + return; + } + std::unique_ptr real_path(new (std::nothrow) char[MMPA_MAX_PATH]{0}); + if (real_path == nullptr) { + GELOGE(GRAPH_FAILED, "New real_path failed."); + return; + } + /// Returning nullptr means 3 case as follows: + /// a.path is PATH_MAX chars or more + /// b.the file does not exist + /// c.the path has no permissions + /// Distinguish between last the two cases in the function WriteProtoToTextFile call open() + if (mmRealPath(proto_file.c_str(), real_path.get(), MMPA_MAX_PATH) != EN_OK) { + // For case a + if (errno == ENAMETOOLONG) { + GELOGE(GRAPH_FAILED, "Call realpath failed: path is PATH_MAX chars or more."); + return; + } + } + + // 3. Serialize to file in current path + GraphUtils::WriteProtoToTextFile(model_proto, real_path.get()); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraphFromOnnx(const char *file, + ge::ComputeGraph &compute_graph) { + if (file == nullptr) { + GELOGE(GRAPH_FAILED, "incorrect parameter. file path is invalid"); + return false; + } + onnx::ModelProto model_proto; + // 1. Get ModelDef object from file generated by DumpGEGraphToOnnx() + if (!ReadProtoFromTextFile(file, &model_proto)) { + GELOGE(GRAPH_FAILED, "Get ModelDef from file failed"); + return false; + } + // 2.Convert onnx::ModelProto To ge::Model + ge::Model model; + if (!OnnxUtils::ConvertModelProtoToGeModel(model_proto, model)) { + GELOGE(GRAPH_FAILED, "Convert ModelDef to Model failed"); + return false; + } + auto compute_graph_ptr = GraphUtils::GetComputeGraph(model.GetGraph()); + if (compute_graph_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "Get compute graph from Model failed"); + return false; + } + compute_graph = *(compute_graph_ptr); + return true; +} + +namespace { +using InNodesToOut = std::unordered_map>; + +inline std::string GetNodeNameByAnchor(const Anchor *anchor) { + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Anchor is nullptr"); + return "Null"; + } + auto node = anchor->GetOwnerNode(); + return node == nullptr ? "Null" : node->GetName(); +} + +graphStatus ReplaceOutDataAnchor(const OutDataAnchorPtr &new_anchor, const OutDataAnchorPtr &old_anchor, + InNodesToOut *in_nodes_to_out = nullptr) { + if (new_anchor == nullptr || old_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "new_anchor or old_anchor is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto new_node = new_anchor->GetOwnerNode(); + for (const auto &peer_in_anchor : old_anchor->GetPeerInDataAnchors()) { + auto ret = peer_in_anchor->Unlink(old_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to unlink old anchor link from %s(%d) to %s(%d)", + GetNodeNameByAnchor(old_anchor.get()).c_str(), old_anchor->GetIdx(), + GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); + return GRAPH_FAILED; + } + ret = peer_in_anchor->LinkFrom(new_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to relink new anchors from %s(%d) to %s(%d)", + GetNodeNameByAnchor(new_anchor.get()).c_str(), new_anchor->GetIdx(), + GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); + return GRAPH_FAILED; + } + + if (in_nodes_to_out != nullptr) { + (*in_nodes_to_out)[new_node].insert(peer_in_anchor->GetOwnerNode()); + } + } + return GRAPH_SUCCESS; +} + +graphStatus RelinkDataIO(const NodePtr &node, const std::vector &io_map, InNodesToOut &in_nodes_to_out) { + GE_CHECK_NOTNULL(node); + auto in_data_anchors = node->GetAllInDataAnchors(); + auto out_data_anchors = node->GetAllOutDataAnchors(); + if (out_data_anchors.size() < io_map.size()) { + GELOGE(GRAPH_FAILED, "The io_map specified for node %s type %s is larger %zu than the actual size %zu", + node->GetName().c_str(), node->GetType().c_str(), io_map.size(), out_data_anchors.size()); + return GRAPH_PARAM_INVALID; + } + + for (size_t i = 0; i < out_data_anchors.size(); ++i) { + auto out_data_anchor = out_data_anchors.at(i); + if (out_data_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to relink for node %s type %s, the out data anchor at index %zu is null", + node->GetName().c_str(), node->GetType().c_str(), i); + return GRAPH_FAILED; + } + + int in_index = -1; + if (i < io_map.size()) { + in_index = io_map.at(i); + } + if (in_index < 0) { + out_data_anchor->UnlinkAll(); + continue; + } + + if (in_index >= static_cast(in_data_anchors.size())) { + GELOGE(GRAPH_PARAM_INVALID, "Failed to relink for node %s type %s, invalid index %d specified for input(%zu)", + node->GetName().c_str(), node->GetType().c_str(), in_index, in_data_anchors.size()); + return GRAPH_PARAM_INVALID; + } + auto in_anchor = in_data_anchors.at(in_index); + if (in_anchor == nullptr) { + GELOGW("Invalid in data anchors(null) found at node %s type %s index %d, ignore it.", node->GetName().c_str(), + node->GetType().c_str(), in_index); + continue; + } + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + continue; + } + if (peer_out_anchor->Unlink(in_anchor) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, + "Failed relink node %s type %s, failed to unlink the data link" + " from %s(%d) to it at input-index %d", + node->GetName().c_str(), node->GetType().c_str(), GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), + peer_out_anchor->GetIdx(), in_index); + return GRAPH_FAILED; + } + auto ret = ReplaceOutDataAnchor(peer_out_anchor, out_data_anchor, &in_nodes_to_out); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to relink node %s type %s for relinking data anchors", node->GetName().c_str(), + node->GetType().c_str()); + return GRAPH_FAILED; + } + } + + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + in_anchor->UnlinkAll(); + } + return GRAPH_SUCCESS; +} + +InNodesToOut GetFullConnectIONodes(const NodePtr &node) { + InNodesToOut in_nodes_to_out; + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Node is nullptr"); + return in_nodes_to_out; + } + auto in_nodes_list = node->GetInNodes(); + auto out_nodes_list = node->GetOutNodes(); + auto out_nodes = std::unordered_set(out_nodes_list.begin(), out_nodes_list.end()); + + for (const auto &in_node : in_nodes_list) { + in_nodes_to_out.insert(std::make_pair(in_node, out_nodes)); + } + return in_nodes_to_out; +} + +graphStatus RelinkControlNodeIfNeed(const NodePtr &node, InNodesToOut &in_nodes_to_out, + InNodesToOut &connected_data_in_to_out) { + GE_CHECK_NOTNULL(node); + for (const auto &in_node_to_out : in_nodes_to_out) { + auto &in_node = in_node_to_out.first; + GE_CHECK_NOTNULL(in_node); + auto &connected_data_out = connected_data_in_to_out[in_node]; + for (const auto &out_node : in_node_to_out.second) { + GE_CHECK_NOTNULL(out_node); + if (connected_data_out.count(out_node) == 0) { + GE_CHECK_NOTNULL(in_node->GetOutControlAnchor()); + if (in_node->GetOutControlAnchor()->IsLinkedWith(out_node->GetInControlAnchor())) { + continue; + } + auto ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), out_node->GetInControlAnchor()); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when isolating node %s type %s", + in_node->GetName().c_str(), out_node->GetName().c_str(), node->GetName().c_str(), + node->GetType().c_str()); + return GRAPH_FAILED; + } + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus ReplaceOutDataAnchors(const Node::Vistor &new_outs, + const Node::Vistor &old_outs, const std::vector &outputs_map) { + auto new_out_size = new_outs.size(); + if (new_out_size < outputs_map.size()) { + GELOGE(GRAPH_PARAM_INVALID, + "Failed to replace out data anchors, the actual size %zu is less than the mapping size %zu", new_out_size, + outputs_map.size()); + return GRAPH_PARAM_INVALID; + } + for (size_t i = 0; i < new_out_size; ++i) { + auto &new_out_anchor = new_outs.at(i); + if (new_out_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to replace out data anchors, the out data anchor on new node is null, index %zu", i); + return GRAPH_FAILED; + } + if (i >= outputs_map.size()) { + continue; + } + auto old_index = outputs_map.at(i); + if (old_index < 0) { + continue; + } + + const OutDataAnchorPtr &old_out_anchor = old_outs.at(old_index); + if (old_out_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to replace out data anchors, the out data anchor on old node is null, index %d", + old_index); + return GRAPH_FAILED; + } + auto ret = ReplaceOutDataAnchor(new_out_anchor, old_out_anchor); + if (ret != GRAPH_SUCCESS) { + return ret; + } + } + + return GRAPH_SUCCESS; +} + +graphStatus ReplaceInDataAnchors(const Node::Vistor &new_ins, + const Node::Vistor &old_ins, const std::vector &inputs_map) { + auto new_in_size = new_ins.size(); + if (new_in_size < inputs_map.size()) { + GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the actual size %zu is less than the mapping size %zu", + new_in_size, inputs_map.size()); + return GRAPH_PARAM_INVALID; + } + + for (size_t i = 0; i < new_in_size; ++i) { + auto &new_in_anchor = new_ins.at(i); + if (new_in_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the out data anchor on new node is null, index %zu", i); + return GRAPH_FAILED; + } + if (i >= inputs_map.size()) { + continue; + } + auto old_index = inputs_map.at(i); + if (old_index < 0) { + continue; + } + const InDataAnchorPtr &old_in_anchor = old_ins.at(old_index); + if (old_in_anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the out data anchor on old node is null, index %d", + old_index); + return GRAPH_FAILED; + } + + auto peer_out_anchor = old_in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + GELOGW("Peer out anchor is nullptr"); + continue; + } + auto ret = peer_out_anchor->Unlink(old_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to unlink old anchors, unlink from %s(%d) to %s(%d)", + GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), + GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); + return GRAPH_FAILED; + } + ret = peer_out_anchor->LinkTo(new_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to link new anchors, link from %s(%d) to %s(%d)", + GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), + GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +graphStatus ReplaceControlAnchors(const NodePtr &new_node, const NodePtr &old_node) { + GE_CHECK_NOTNULL(new_node); + GE_CHECK_NOTNULL(new_node->GetInControlAnchor()); + GE_CHECK_NOTNULL(old_node); + GE_CHECK_NOTNULL(old_node->GetInControlAnchor()); + auto peer_out_anchors = old_node->GetInControlAnchor()->GetPeerAnchors(); + auto new_in_control_anchor = new_node->GetInControlAnchor(); + auto exists_out_anchors = new_in_control_anchor->GetPeerAnchors(); + auto exists_out_anchors_set = std::set(exists_out_anchors.begin(), exists_out_anchors.end()); + for (const auto &peer_out_anchor : peer_out_anchors) { + if (peer_out_anchor != nullptr) { + if (exists_out_anchors_set.count(peer_out_anchor) > 0) { + continue; + } + auto ret = GraphUtils::AddEdge(peer_out_anchor, new_in_control_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add edge failed"); + return GRAPH_FAILED; + } + } else { + GELOGW("peer outanchor is nullptr"); + continue; + } + } + auto old_out_control_anchor = old_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(old_out_control_anchor); + auto peer_in_anchors = old_out_control_anchor->GetPeerAnchors(); + auto new_out_control_anchor = new_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(new_out_control_anchor); + auto exists_in_anchors = new_out_control_anchor->GetPeerAnchors(); + auto exists_in_anchors_set = std::set(exists_in_anchors.begin(), exists_in_anchors.end()); + for (const auto &peer_in_anchor : peer_in_anchors) { + if (peer_in_anchor != nullptr) { + if (exists_in_anchors_set.count(peer_in_anchor) > 0) { + continue; + } + auto ret = GraphUtils::AddEdge(new_out_control_anchor, peer_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add edge failed"); + return GRAPH_FAILED; + } + } else { + GELOGW("Peer inanchor is nullptr"); + continue; + } + } + + return GRAPH_SUCCESS; +} +} // namespace + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNode(const NodePtr &node, + const std::vector &io_map) { + if (node == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "Failed to isolate node(null)"); + return GRAPH_PARAM_INVALID; + } + + /// We must get full connections info before re-link data io, because the data + /// edges may be unlinked when relink data io + auto in_nodes_to_out = GetFullConnectIONodes(node); + + InNodesToOut data_in_to_out; + auto ret = RelinkDataIO(node, io_map, data_in_to_out); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to isolate node %s type %s when relink data IO", node->GetName().c_str(), + node->GetType().c_str()); + return ret; + } + + ret = RelinkControlNodeIfNeed(node, in_nodes_to_out, data_in_to_out); + if (ret != GRAPH_SUCCESS) { + return ret; + } + NodeUtils::UnlinkAll(*node); + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::IsolateNode(const NodePtr &node, const std::initializer_list &io_map) { + return IsolateNode(node, std::vector(io_map)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNodeOneIO(const NodePtr &node) { + if (node == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "incorrect parameter. node is invalid"); + return GRAPH_PARAM_INVALID; + } + if (node->GetAllInDataAnchorsSize() != 1) { + return GRAPH_PARAM_INVALID; + } + if (node->GetAllOutDataAnchorsSize() != 1) { + return GRAPH_PARAM_INVALID; + } + return IsolateNode(node, {0}); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, const std::vector &inputs_map, + const std::vector &outputs_map) { + if ((new_node == nullptr) || (old_node == nullptr)) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto ret = ReplaceNodeDataAnchors(new_node, old_node, inputs_map, outputs_map); + if (ret != GRAPH_SUCCESS) { + // The error log was printed in `ReplaceNodeDataAnchors` + return GRAPH_FAILED; + } + ret = ReplaceControlAnchors(new_node, old_node); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, + "Failed to replace control anchors when replace node from old node %s type %s to new node %s type %s", + old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), + new_node->GetType().c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ReplaceNodeAnchors( + const NodePtr &new_node, const NodePtr &old_node, const std::initializer_list inputs_map, + const std::initializer_list outputs_map) { + return ReplaceNodeAnchors(new_node, old_node, std::vector(inputs_map), std::vector(outputs_map)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, + std::initializer_list inputs_map, std::initializer_list outputs_map) { + return ReplaceNodeDataAnchors(new_node, old_node, std::vector(inputs_map), std::vector(outputs_map)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, const std::vector &inputs_map, + const std::vector &outputs_map) { + if (new_node == nullptr || old_node == nullptr) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + + auto ret = ReplaceOutDataAnchors(new_node->GetAllOutDataAnchors(), old_node->GetAllOutDataAnchors(), outputs_map); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, + "Failed to replace out data anchors when replace node from old node %s type %s to new node %s type %s", + old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), + new_node->GetType().c_str()); + return GRAPH_FAILED; + } + ret = ReplaceInDataAnchors(new_node->GetAllInDataAnchors(), old_node->GetAllInDataAnchors(), inputs_map); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, + "Failed to replace in data anchors when replace node from old node %s type %s to new node %s type %s", + old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), + new_node->GetType().c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInCtrlEdges(const NodePtr &src_node, + NodePtr &dst_node) { + if ((src_node == nullptr) || (dst_node == nullptr)) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto src_ctrl_in_nodes = src_node->GetInControlNodes(); + if (src_ctrl_in_nodes.empty()) { + return GRAPH_SUCCESS; + } + + std::unordered_set exist_in_ctrl_nodes_set; + auto exist_in_ctrl_nodes = dst_node->GetInControlNodes(); + if (!exist_in_ctrl_nodes.empty()) { + exist_in_ctrl_nodes_set.insert(exist_in_ctrl_nodes.begin(), exist_in_ctrl_nodes.end()); + } + + auto dst_ctrl = dst_node->GetInControlAnchor(); + for (const auto &in_node : src_ctrl_in_nodes) { + if (exist_in_ctrl_nodes_set.count(in_node) > 0) { + continue; + } + auto ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), dst_ctrl); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when copy control dependencies from %s to %s", + in_node->GetName().c_str(), dst_node->GetName().c_str(), src_node->GetName().c_str(), + dst_node->GetName().c_str()); + return ret; + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveInCtrlEdges(const NodePtr &src_node, + NodePtr &dst_node) { + if (src_node == nullptr || dst_node == nullptr) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_FAILED; + } + auto ret = CopyInCtrlEdges(src_node, dst_node); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Copy in ctrl edges failed"); + return ret; + } + GE_CHECK_NOTNULL(src_node->GetInControlAnchor()); + src_node->GetInControlAnchor()->UnlinkAll(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyOutCtrlEdges(const NodePtr &src_node, + NodePtr &dst_node) { + if (src_node == nullptr || dst_node == nullptr) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_FAILED; + } + auto out_ctrl_nodes = src_node->GetOutControlNodes(); + if (out_ctrl_nodes.empty()) { + return GRAPH_SUCCESS; + } + + std::unordered_set exists_out_ctrl_nodes_set; + for (const auto &node : dst_node->GetOutControlNodes()) { + exists_out_ctrl_nodes_set.insert(node.get()); + } + + auto dst_out_ctrl = dst_node->GetOutControlAnchor(); + for (const auto &node : out_ctrl_nodes) { + if (exists_out_ctrl_nodes_set.count(node.get()) > 0) { + continue; + } + auto ret = GraphUtils::AddEdge(dst_out_ctrl, node->GetInControlAnchor()); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when copy control dependencies from %s to %s", + dst_node->GetName().c_str(), node->GetName().c_str(), src_node->GetName().c_str(), + dst_node->GetName().c_str()); + return ret; + } + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCtrlEdges(NodePtr &src_node, + NodePtr &dst_node) { + if (src_node == nullptr || dst_node == nullptr) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_FAILED; + } + auto ret = CopyOutCtrlEdges(src_node, dst_node); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Copyout ctrl edges failed"); + return ret; + } + GE_CHECK_NOTNULL(src_node->GetOutControlAnchor()); + src_node->GetOutControlAnchor()->UnlinkAll(); + return GRAPH_SUCCESS; +} + +/// +/// Copy all in-data edges from `src_node` to `dst_node`. +/// @param src_node +/// @param dst_node +/// @return +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node, + NodePtr &dst_node) { + if ((src_node == nullptr) || (dst_node == nullptr)) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto src_data_in_nodes = src_node->GetInDataNodes(); + if (src_data_in_nodes.empty()) { + return GRAPH_SUCCESS; + } + for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { + auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); + auto ret = + GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", + in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), + src_node->GetName().c_str(), dst_node->GetName().c_str()); + return ret; + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph, + const NodePtr &node) { + if (graph->AddInputNode(node) == nullptr) { + GELOGE(GRAPH_FAILED, "Copyout ctrl edges failed"); + return GRAPH_FAILED; + } + graph->SetInputSize(graph->GetInputSize() + 1); + graph->inputs_order_.emplace_back(node->GetName()); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +ComputeGraphPtr GraphUtils::FindRootGraph(ComputeGraphPtr graph) { + ComputeGraphPtr result = nullptr; + while (graph != nullptr) { + result = std::move(graph); + graph = result->GetParentGraph(); + } + return result; +} + +/// +/// Make a copy of ComputeGraph. +/// @param graph: original graph. +/// @param prefix: node name prefix of new graph. +/// @param output_nodes: output nodes of new graph. +/// @return ComputeGraphPtr +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +ComputeGraphPtr GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix, + std::vector &input_nodes, std::vector &output_nodes) { + GE_CHK_BOOL_EXEC(graph != nullptr, return nullptr, "Original graph is null"); + ComputeGraphPtr new_graph = ComGraphMakeShared(graph->GetName()); + GE_CHK_BOOL_EXEC(new_graph != nullptr, return nullptr, "Create new graph failed"); + + std::unordered_map all_new_nodes; + for (const auto &n : graph->GetDirectNode()) { + OpDescPtr op_desc = AttrUtils::CopyOpDesc(n->GetOpDesc()); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return nullptr, "Create new node failed"); + + if (CopyTensorAttrs(op_desc, n) != GRAPH_SUCCESS) { + return nullptr; + } + + op_desc->SetName(n->GetName() + prefix); + NodePtr node = new_graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Add node[%s] to graph failed", op_desc->GetName().c_str()); + all_new_nodes[node->GetName()] = node; + + if (node->GetType() == DATA) { + input_nodes.emplace_back(node); + } else if (node->GetType() == NETOUTPUT) { + output_nodes.emplace_back(node); + } + } + + for (const auto &n : graph->GetDirectNode()) { + if (RelinkGraphEdges(n, prefix, all_new_nodes) != GRAPH_SUCCESS) { + return nullptr; + } + } + + std::string session_graph_id; + if (AttrUtils::GetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) { + bool ret = AttrUtils::SetStr(*new_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); + if (!ret) { + GELOGE(GRAPH_FAILED, "Set attr ATTR_NAME_SESSION_GRAPH_ID failed."); + return nullptr; + } + } + + // copy info of output nodes from old graph to new graph. + std::vector> out_nodes_info = graph->GetGraphOutNodesInfo(); + std::vector> new_out_nodes_info; + for (const auto &info : out_nodes_info) { + auto it = all_new_nodes.find(info.first->GetName()); + if (it != all_new_nodes.end()) { + new_out_nodes_info.emplace_back(it->second, info.second); + } + } + new_graph->SetGraphOutNodesInfo(new_out_nodes_info); + return new_graph; +} + +/// +/// Copy tensor attribute to new node. +/// @param [in] dst_node: cloned node. +/// @param [in] src_node: original node. +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node) { + if (dst_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Input param dst node not valid"); + return GRAPH_FAILED; + } + if (src_node == nullptr || src_node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "Input param src node not valid"); + return GRAPH_FAILED; + } + + const auto &src_desc = src_node->GetOpDesc(); + dst_desc->CopyAttrsFrom(*src_desc); + + for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { + auto input_desc = dst_desc->MutableInputDesc(i); + if (input_desc == nullptr) { + continue; + } + input_desc->CopyAttrsFrom(src_desc->GetInputDesc(i)); + } + + for (uint32_t i = 0; i < src_node->GetAllOutDataAnchorsSize(); ++i) { + auto output_desc = dst_desc->MutableOutputDesc(i); + if (output_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Param dst node not valid"); + return GRAPH_FAILED; + } + output_desc->CopyAttrsFrom(src_desc->GetOutputDesc(i)); + } + + return GRAPH_SUCCESS; +} + +/// +/// Relink all edges for cloned ComputeGraph. +/// @param [in] node: original node. +/// @param [in] prefix: node name prefix of new node. +/// @param [in] all_nodes: all nodes in new graph. +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &prefix, + const std::unordered_map &all_nodes) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "Input node not valid"); + return GRAPH_FAILED; + } + + auto it = all_nodes.find(node->GetName() + prefix); + if (it == all_nodes.end()) { + GELOGE(GRAPH_FAILED, "node[%s] not found", node->GetName().c_str()); + return GRAPH_FAILED; + } + const auto &new_node = it->second; + + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, return GRAPH_FAILED, "In data anchor is null"); + const auto &out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + GELOGW("Peer out anchor is null: %s", node->GetName().c_str()); + continue; + } + GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); + + it = all_nodes.find(out_anchor->GetOwnerNode()->GetName() + prefix); + if (it == all_nodes.end()) { + GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); + return GRAPH_FAILED; + } + const auto &new_out_node = it->second; + + auto rslt = GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), + new_node->GetInAnchor(in_anchor->GetIdx())); + GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", + new_out_node->GetName().c_str(), new_node->GetName().c_str()); + } + + if (node->GetInControlAnchor() != nullptr) { + for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) { + GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "Peer out anchor is null: %s", node->GetName().c_str()); + GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); + + it = all_nodes.find(out_anchor->GetOwnerNode()->GetName() + prefix); + if (it == all_nodes.end()) { + GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); + return GRAPH_FAILED; + } + const auto &new_out_node = it->second; + + auto rslt = GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInControlAnchor()); + GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", + new_out_node->GetName().c_str(), new_node->GetName().c_str()); + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Get reference-mapping of all data_anchors in graph +/// @param [in] graph +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(graph); + for (const auto &node : graph->GetAllNodes()) { + // in_data_anchor + if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + + // out_data_anchor + if (HandleOutAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Find ref_mapping for out_data_anchors of node %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +NodePtr GraphUtils::FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name) { + auto root_graph = FindRootGraph(graph); + if (root_graph == nullptr) { + GE_LOGE("Failed find node %s, null root graph", name.c_str()); + return nullptr; + } + + for (const auto &node : root_graph->GetAllNodes()) { + if (node == nullptr) { + continue; + } + if (node->GetName() == name) { + return node; + } + } + + return nullptr; +} + +/// +/// Get reference-mapping for in_data_anchors of node +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + + if (NodeUtils::IsSubgraphOutput(node)) { + return HandleSubgraphOutput(node, symbol_to_anchors, anchor_to_symbol); + } + + if (NodeUtils::IsSubgraphInput(node)) { + return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); + } + + const std::string &type = node->GetType(); + if ((type == MERGE) || (type == STREAMMERGE)) { + return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); + } + + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + const std::string &symbol = cur_node_info.ToString(); + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + symbol_to_anchors[symbol] = { cur_node_info }; + anchor_to_symbol[symbol] = symbol; + } else { + NodeIndexIO exist_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); + if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Update symbol mapping failed."); + return GRAPH_FAILED; + } + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Get reference-mapping for out_data_anchors of node +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); + if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { + continue; + } + + int32_t reuse_in_index = -1; + bool reuse_input_flag = IsRefFromInput(out_data_anchor, reuse_in_index); + if (reuse_input_flag && (node->GetInDataAnchor(reuse_in_index) != nullptr)) { + NodeIndexIO exist_node_info(node, reuse_in_index, kIn); + if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Update symbol mapping failed."); + return GRAPH_FAILED; + } + } else { + if (reuse_input_flag) { + GELOGW("Invalid reuse_input attr on output %d of node %s, please check attr reuse_input and reuse_input_index", + out_data_anchor->GetIdx(), node->GetName().c_str()); + } + const std::string &symbol = cur_node_info.ToString(); + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + symbol_to_anchors.emplace(std::make_pair(symbol, std::list{ cur_node_info })); + anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Handle input of subgraph +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + + // Data in subgraph + uint32_t index = 0; + if (!ge::AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index)) { + GE_LOGE("Get attr ATTR_NAME_PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); + return GRAPH_FAILED; + } + NodePtr parent_node = node->GetOwnerComputeGraph()->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + InDataAnchorPtr parent_in_anchor = parent_node->GetInDataAnchor(index); + GE_CHECK_NOTNULL(parent_in_anchor); + OutDataAnchorPtr peer_out_anchor = parent_in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor != nullptr) { + // Data has and only has one input + NodeIndexIO cur_node_info(node, 0, kIn); + NodeIndexIO exist_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); + if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Update symbol mapping failed."); + return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Handle input of Merge op +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + std::vector exist_node_infos; + std::vector cur_node_infos; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + std::string next_name; + if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) { + ComputeGraphPtr graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + ge::NodePtr next_node = FindNodeFromAllNodes(graph, next_name); + GE_CHECK_NOTNULL(next_node); + // NextIteration has and only has one output + peer_out_anchor = next_node->GetOutDataAnchor(0); + GE_CHECK_NOTNULL(peer_out_anchor); + cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); + cur_node_infos.emplace_back(NodeIndexIO(next_node, peer_out_anchor->GetIdx(), kOut)); + } + } else { + cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); + exist_node_infos.emplace_back(NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut)); + } + } + + size_t anchor_nums = 0; + NodeIndexIO max_node_index_io(nullptr, 0, kOut); + for (const auto &temp_node_info : exist_node_infos) { + auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); + if (iter1 != anchor_to_symbol.end()) { + const std::string &temp_symbol = iter1->second; + auto iter2 = symbol_to_anchors.find(temp_symbol); + if (iter2 != symbol_to_anchors.end()) { + if (iter2->second.size() > anchor_nums) { + max_node_index_io = temp_node_info; + anchor_nums = iter2->second.size(); + } + } + } + } + + std::string symbol; + for (const auto &temp_node_info : exist_node_infos) { + if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != + GRAPH_SUCCESS) || + symbol.empty()) { + GE_LOGE("Union symbol map anchor1:%s & anchor2:%s.", max_node_index_io.ToString().c_str(), + temp_node_info.ToString().c_str()); + return GRAPH_FAILED; + } + } + + auto iter = symbol_to_anchors.find(symbol); + if (iter != symbol_to_anchors.end()) { + for (const auto &temp_node_info : cur_node_infos) { + GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); + iter->second.emplace_back(temp_node_info); + anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Handle output of subgraph +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + NodePtr parent_node = owner_graph->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + + GeTensorDesc in_tensor = op_desc->GetInputDesc(in_data_anchor->GetIdx()); + uint32_t index = 0; + if (!ge::AttrUtils::GetInt(in_tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { + continue; + } + GE_CHECK_NOTNULL(parent_node->GetOutDataAnchor(index)); + // Union symbol of peer_out_anchor & parent_out_anchor + NodeIndexIO peer_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); + NodeIndexIO parent_node_info(parent_node, index, kOut); + std::string symbol; + if ((UnionSymbolMapping(peer_node_info, parent_node_info, symbol_to_anchors, anchor_to_symbol, + symbol) != GRAPH_SUCCESS) || symbol.empty()) { + GE_LOGE("Union symbol map anchor1:%s, anchor2:%s.", + peer_node_info.ToString().c_str(), parent_node_info.ToString().c_str()); + return GRAPH_FAILED; + } + + NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + symbol_to_anchors[symbol].emplace_back(cur_node_info); + anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); + } + + return GRAPH_SUCCESS; +} + +/// +/// Union ref-mapping +/// @param [in] exist_node_info1 +/// @param [in] exist_node_info2 +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @param [out] symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol, std::string &symbol) { + const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; + const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; + if (symbol1 == symbol2) { + symbol = symbol1; + GELOGI("no need to union."); + return GRAPH_SUCCESS; + } + + auto iter1 = symbol_to_anchors.find(symbol1); + auto iter2 = symbol_to_anchors.find(symbol2); + if ((iter1 == symbol_to_anchors.end()) || (iter2 == symbol_to_anchors.end())) { + GE_LOGE("symbol %s or %s not exist.", symbol1.c_str(), symbol2.c_str()); + return GRAPH_FAILED; + } + + auto &max_iter = (iter1->second.size() > iter2->second.size() ? iter1 : iter2); + auto &min_iter = (iter1->second.size() > iter2->second.size() ? iter2 : iter1); + symbol = (iter1->second.size() > iter2->second.size() ? symbol1 : symbol2); + std::string min_symbol = (iter1->second.size() > iter2->second.size() ? symbol2 : symbol1); + for (auto &node_index_io : min_iter->second) { + GELOGD("Update anchor %s, symbol %s.", node_index_io.ToString().c_str(), symbol.c_str()); + max_iter->second.emplace_back(node_index_io); + auto iter = anchor_to_symbol.find(node_index_io.ToString()); + if (iter == anchor_to_symbol.end()) { + GE_LOGE("anchor %s not exist.", node_index_io.ToString().c_str()); + return GRAPH_FAILED; + } + if (iter->second != min_symbol) { + GELOGW("not expected symbol of anchor %s, expect %s but %s exactly.", + iter->first.c_str(), min_symbol.c_str(), iter->second.c_str()); + } + iter->second = symbol; + } + + GELOGI("Union symbol %s and %s succ.", symbol.c_str(), min_symbol.c_str()); + symbol_to_anchors.erase(min_iter); + return GRAPH_SUCCESS; +} + +/// +/// Update symbol mapping with a new reference pair +/// @param [in] cur_node_info +/// @param [in] exist_node_info +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + auto iter1 = anchor_to_symbol.find(exist_node_info.ToString()); + if (iter1 == anchor_to_symbol.end()) { + GE_LOGE("data_anchor %s is not visible before data_anchor %s, maybe TopoSorting is missing.", + exist_node_info.ToString().c_str(), cur_node_info.ToString().c_str()); + return GRAPH_FAILED; + } + + const std::string &symbol = iter1->second; + auto iter2 = symbol_to_anchors.find(symbol); + if (iter2 == symbol_to_anchors.end()) { + GE_LOGE("symbol %s not found.", symbol.c_str()); + return GRAPH_FAILED; + } + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + iter2->second.emplace_back(cur_node_info); + anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); + + return GRAPH_SUCCESS; +} + +/// +/// Check if out_data_anchor is reference of input +/// @param [in] out_data_anchor +/// @param [out] reuse_in_index +/// @return bool +/// +bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { + if (out_data_anchor == nullptr) { + GELOGW("out_data_anchor is NULL."); + return false; + } + int32_t output_index = out_data_anchor->GetIdx(); + + // pass-through op + NodePtr node = out_data_anchor->GetOwnerNode(); + const std::string &type = node->GetType(); + const std::set pass_through_set = { NETOUTPUT, WHILE, _WHILE, STATELESSWHILE }; + if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { + reuse_in_index = output_index; + GELOGI("Pass-Through node name[%s] index[%u].", node->GetName().c_str(), reuse_in_index); + return true; + } + + // Merge op 0th output + if ((type == MERGE) && (output_index == 0)) { + reuse_in_index = 0; + GELOGI("Merge name[%s] output_index[0].", node->GetName().c_str()); + return true; + } + + // ref op + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGW("op_desc is NULL."); + return false; + } + bool is_ref = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); + if (is_ref) { + const string &output_name = op_desc->GetOutputNameByIndex(output_index); + for (const auto &input_name : op_desc->GetAllInputNames()) { + if (!input_name.empty() && (output_name == input_name)) { + reuse_in_index = op_desc->GetInputIndexByName(input_name); + GELOGI("Reference name[%s] output[%s][%d] ref to input[%s][%d].", op_desc->GetName().c_str(), + output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); + return true; + } + } + } + + // reuse input + auto output_op_desc = op_desc->GetOutputDescPtr(output_index); + bool reuse_input = false; + if (output_op_desc != nullptr) { + if ((TensorUtils::GetReuseInput(*output_op_desc, reuse_input) == GRAPH_SUCCESS) && reuse_input) { + uint32_t reuse_input_index = 0; + if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { + reuse_in_index = static_cast(reuse_input_index); + GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), + output_index, reuse_in_index); + return true; + } + } + } + + return false; +} + +/// +/// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs +/// of the graph have UNKNOWN_SHAPE operators or not. +/// Note: This function will only look 'down' from the graph, not 'up'. For example, the following +/// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE +/// ROOT graph: A -----> B -----> C +/// K subgraph U +/// | +/// V +/// SUB graph: D --> E --> F +/// K K K +/// @param [in] graph +/// @return bool +/// +bool GraphUtils::IsUnknownShapeGraph(const ComputeGraphPtr &graph) { + if (graph == nullptr) { + GELOGW("Input graph is nullptr."); + return false; + } + for (const auto &node : graph->GetDirectNode()) { + bool is_unknown = false; + auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); + if (ret != GRAPH_SUCCESS) { + GELOGW("Get node unknown status failed, node name:%s, type:%s.", + node->GetName().c_str(), node->GetType().c_str()); + continue; + } + if (is_unknown) { + GELOGD("Node %s, type %s is unknown shape in graph %s.", + node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str()); + return true; + } + } + GELOGD("Graph %s does not have unknown shape node.", graph->GetName().c_str()); + return false; +} + +/// +/// @brief Add node to graph +/// @param [in] op_desc +/// @return ComputeGraphBuilder +/// +ComputeGraphBuilder& ComputeGraphBuilder::AddNode(const OpDescPtr &op_desc) { + nodes_.emplace_back(op_desc); + return *this; +} + +/// +/// @brief Add data-link among nodes in graph +/// @param [in] src_name +/// @param [in] out_anchor_ind +/// @param [in] dst_name +/// @param [in] in_anchor_ind +/// @return ComputeGraphBuilder +/// +ComputeGraphBuilder& ComputeGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, + const std::string &dst_name, uint32_t in_anchor_ind) { + data_links_.emplace_back(std::make_pair(std::make_pair(src_name, out_anchor_ind), + std::make_pair(dst_name, in_anchor_ind))); + return *this; +} + +/// +/// @brief Add ctrl-link among nodes in graph +/// @param [in] src_name +/// @param [in] dst_name +/// @return ComputeGraphBuilder +/// +ComputeGraphBuilder& ComputeGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { + ctrl_links_.emplace_back(std::make_pair(src_name, dst_name)); + return *this; +} + +/// +/// @brief Build nodes +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void ComputeGraphBuilder::BuildNodes(graphStatus &error_code, std::string &error_msg) { + if (owner_graph_ == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "graph is NULL."; + return; + } + + std::string node_name; + for (auto &op_desc : nodes_) { + if (op_desc == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "op_desc is NULL."; + return; + } + + node_name = op_desc->GetName(); + NodePtr node = owner_graph_->AddNode(op_desc); + if (node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "Add node " + node_name + " failed."; + return; + } + + GELOGD("Add node name:%s, type:%s.", node_name.c_str(), op_desc->GetType().c_str()); + node_names_[node_name] = node; + } + + GELOGD("BuildNodes succ."); +} + +/// +/// @brief Build data-links +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void ComputeGraphBuilder::BuildDataLinks(graphStatus &error_code, std::string &error_msg) { + for (auto &pair : data_links_) { + std::string src_name = pair.first.first; + uint32_t out_ind = pair.first.second; + std::string dst_name = pair.second.first; + uint32_t in_ind = pair.second.second; + std::string log_msg = "Add data-edge "; + log_msg.append(src_name).append(":").append(std::to_string(out_ind)).append("->") + .append(dst_name).append(":").append(std::to_string(in_ind)); + + auto src_iter = node_names_.find(src_name); + auto dst_iter = node_names_.find(dst_name); + if ((src_iter == node_names_.end()) || (dst_iter == node_names_.end())) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed: node not exist in graph."; + return; + } + + NodePtr src_node = node_names_[src_name]; + NodePtr dst_node = node_names_[dst_name]; + if ((src_node == nullptr) || (dst_node == nullptr)) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed: node is NULL."; + return; + } + + if (GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_ind), dst_node->GetInDataAnchor(in_ind)) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed."; + return; + } + + GELOGD("%s succ.", log_msg.c_str()); + } + + GELOGD("BuildDataLinks succ."); +} + +/// +/// @brief Build ctrl-links +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void ComputeGraphBuilder::BuildCtrlLinks(graphStatus &error_code, std::string &error_msg) { + for (auto &pair : ctrl_links_) { + std::string src_name = pair.first; + std::string dst_name = pair.second; + std::string log_msg = "Add ctrl-edge "; + log_msg.append(src_name).append("->").append(dst_name); + + auto src_iter = node_names_.find(src_name); + auto dst_iter = node_names_.find(dst_name); + if ((src_iter == node_names_.end()) || (dst_iter == node_names_.end())) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed: node not exist in graph."; + return; + } + + NodePtr src_node = node_names_[src_name]; + NodePtr dst_node = node_names_[dst_name]; + if ((src_node == nullptr) || (dst_node == nullptr)) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed: node is NULL."; + return; + } + + if (GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed."; + return; + } + + GELOGD("%s succ.", log_msg.c_str()); + } + + GELOGD("BuildCtrlLinks succ."); +} + +/// @brief Get node with name +/// @param [in] name +/// @return NodePtr +/// +NodePtr ComputeGraphBuilder::GetNode(const std::string &name) { + auto iter = node_names_.find(name); + if (iter == node_names_.end()) { + GE_LOGE("node %s not exist.", name.c_str()); + return nullptr; + } + return iter->second; +} + +/// @brief Get all nodes +/// @return std::vector +/// +std::vector ComputeGraphBuilder::GetAllNodes() { + std::vector nodes; + for (const auto &iter : node_names_) { + nodes.emplace_back(iter.second); + } + return nodes; +} + +/// +/// @brief Add node to graph +/// @param [in] op_desc +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder& CompleteGraphBuilder::AddNode(const OpDescPtr &op_desc) { + ComputeGraphBuilder::AddNode(op_desc); + return *this; +} + +/// +/// @brief Add data-link among nodes in graph +/// @param [in] src_name +/// @param [in] out_anchor_ind +/// @param [in] dst_name +/// @param [in] in_anchor_ind +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder& CompleteGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, + const std::string &dst_name, uint32_t in_anchor_ind) { + ComputeGraphBuilder::AddDataLink(src_name, out_anchor_ind, dst_name, in_anchor_ind); + return *this; +} + +/// +/// @brief Add ctrl-link among nodes in graph +/// @param [in] src_name +/// @param [in] dst_name +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder& CompleteGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { + ComputeGraphBuilder::AddControlLink(src_name, dst_name); + return *this; +} + +/// +/// @brief Set index_th input anchor for graph +/// @param [in] index +/// @param [in] node_names +/// @param [in] anchor_inds +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder& CompleteGraphBuilder::SetInput(uint32_t index, const std::vector &node_names, + const std::vector &anchor_inds) { + graph_inputs_[index] = std::make_pair(node_names, anchor_inds); + return *this; +} + +/// +/// @brief Set index_th input of graph as useless +/// @param [in] index +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder& CompleteGraphBuilder::SetUselessInput(uint32_t index) { + graph_inputs_[index] = std::make_pair(std::vector(), std::vector()); + return *this; +} + +/// +/// @brief Add output anchor for graph +/// @param [in] owner_node_name +/// @param [in] anchor_ind +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder& CompleteGraphBuilder::AddOutput(const std::string &owner_node_name, uint32_t anchor_ind) { + graph_outputs_.emplace_back(std::make_pair(owner_node_name, anchor_ind)); + return *this; +} + +/// +/// @brief Add target for graph +/// @param [in] target_name +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder& CompleteGraphBuilder::AddTarget(const std::string &target_name) { + graph_targets_.emplace_back(target_name); + return *this; +} + +/// +/// @brief Set parent-node of graph +/// @param [in] parent_node +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder& CompleteGraphBuilder::SetParentNode(const NodePtr &parent_node) { + parent_node_ = parent_node; + return *this; +} + +/// +/// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node +/// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder& CompleteGraphBuilder::SetInputMapping(const std::map &input_mapping) { + for (auto &item : input_mapping) { + input_mapping_[item.first] = item.second; + } + return *this; +} + +/// +/// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind +/// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node +/// @return CompleteGraphBuilder +/// +CompleteGraphBuilder& CompleteGraphBuilder::SetOutputMapping(const std::map &output_mapping) { + for (auto &item : output_mapping) { + output_mapping_[item.first] = item.second; + } + return *this; +} + +/// +/// @brief Build graph +/// @param [out] error_code +/// @param [out] error_msg +/// @return ComputeGraphPtr +/// +ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { + owner_graph_ = shared_ptr(new (std::nothrow) ComputeGraph(name_)); + if (owner_graph_ == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "graph is NULL."; + return nullptr; + } + + BuildNodes(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildDataLinks(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildCtrlLinks(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + AddDataNodes(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + if (retval_flag_) { + AddRetValNodes(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + BuildGraphTargets(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + } else { + AddNetOutputNode(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + } + + PostProcess(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + return owner_graph_; +} + +/// +/// @brief Add data nodes +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void CompleteGraphBuilder::AddDataNodes(graphStatus &error_code, std::string &error_msg) { + for (auto &input : graph_inputs_) { + NodePtr data_node = AddDataNode(input.first, error_code, error_msg); + if (data_node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: add node Data:" + std::to_string(input.first) + + " failed."; + return; + } + + if (owner_graph_->AddInputNode(data_node) == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: add input node Data:" + std::to_string(input.first) + + " failed."; + return; + } + + // useless input + std::vector input_names = input.second.first; + std::vector anchor_indes = input.second.second; + if (input_names.size() != anchor_indes.size()) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: num of input_names and indexs not equal."; + return; + } + if (input_names.empty()) { + continue; + } + + size_t input_num = input_names.size(); + for (size_t i = 0; i < input_num; i++) { + std::string input_name = input_names[i]; + uint32_t ind = anchor_indes[i]; + auto iter = node_names_.find(input_name); + if (iter == node_names_.end()) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: node " + input_name + " not exist in graph."; + return; + } + + NodePtr in_node = node_names_[input_name]; + if (in_node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: node " + input_name + " is NULL."; + return; + } + + if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), in_node->GetInDataAnchor(ind)) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNodes failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + + input_name + ":" + std::to_string(ind) + " failed."; + return; + } + } + + GELOGD("AddDataNodes : Add %u input succ.", input.first); + } + + GELOGD("AddDataNodes succ."); +} + +/// +/// @brief Add data node +/// @param [in] index +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +NodePtr CompleteGraphBuilder::AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { + std::string data_name = "Data_" + std::to_string(index); + OpDescBuilder op_desc_builder(data_name, "Data"); + OpDescPtr op_desc = op_desc_builder.AddInput("x") + .AddOutput("y") + .Build(); + if (op_desc == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNode failed: create op_desc " + data_name + " failed."; + return nullptr; + } + + auto index_iter = input_mapping_.find(index); + if (index_iter != input_mapping_.end()) { + if (!ge::AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, index_iter->second)) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNode failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; + return nullptr; + } + } + + NodePtr data_node = owner_graph_->AddNode(op_desc); + if (data_node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddDataNode failed: add node " + data_name + " failed."; + return nullptr; + } + node_names_[data_name] = data_node; + + return data_node; +} + +/// +/// @brief Add RetVal nodes +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &error_msg) { + size_t output_num = graph_outputs_.size(); + for (size_t i = 0; i < output_num; i++) { + int32_t index = graph_outputs_[i].second; + auto out_iter = node_names_.find(graph_outputs_[i].first); + if (out_iter == node_names_.end()) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode failed: node " + graph_outputs_[i].first + " not exist in graph."; + return; + } + NodePtr node = out_iter->second; + if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode failed: node is NULL."; + return; + } + + std::string name = node->GetName() + "_RetVal_"+ std::to_string(index); + OpDescPtr ret_val_desc = shared_ptr(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); + if (ret_val_desc == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: op_desc is NULL."; + return; + } + ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); + if ((ret_val_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) || + (ret_val_desc->AddOutputDesc(tensor) != GRAPH_SUCCESS)) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: add input_desc / output_desc failed."; + return; + } + + if (!(ge::AttrUtils::SetStr(ret_val_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_RetVal") && + ge::AttrUtils::SetInt(ret_val_desc, RETVAL_ATTR_NAME_INDEX, i))) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: set FRAMEWORK_ORIGINAL_TYPE / RETVAL_ATTR_NAME_INDEX failed."; + return; + } + auto iter = output_mapping_.find(i); + if (iter != output_mapping_.end()) { + if (!ge::AttrUtils::SetInt(ret_val_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: set attr PARENT_NODE_INDEX failed."; + return; + } + } + + NodePtr ret_val_node = owner_graph_->AddNode(ret_val_desc); + if (ret_val_node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: add node failed."; + return; + } + + if (GraphUtils::AddEdge(node->GetOutDataAnchor(index), ret_val_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: add data-edge " + + node->GetName() + ":" + std::to_string(index) + "->" + ret_val_node->GetName() + ":0 failed."; + return; + } + } + + GELOGD("AddRetValNodes succ."); +} + +/// +/// @brief Build target-nodes for graph +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void CompleteGraphBuilder::BuildGraphTargets(graphStatus &error_code, std::string &error_msg) { + std::vector target_nodes; + for (const std::string &target_name : graph_targets_) { + auto target_iter = node_names_.find(target_name); + if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) { + error_code = GRAPH_FAILED; + error_msg = "BuildGraphTargets failed: target_node " + target_name + " not exist in graph."; + return; + } + target_nodes.emplace_back(target_iter->second); + } + owner_graph_->SetGraphTargetNodesInfo(target_nodes); +} + +/// +/// @brief Add NetOutput node +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void CompleteGraphBuilder::AddNetOutputNode(graphStatus &error_code, std::string &error_msg) { + std::string log_msg = "AddNetOutputNode name:" + std::string(NODE_NAME_NET_OUTPUT) + ", type:" + NETOUTPUT; + OpDescPtr net_output_desc = shared_ptr(new (std::nothrow) OpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT)); + if (net_output_desc == nullptr) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed: op_desc is NULL."; + return; + } + + size_t output_num = graph_outputs_.size(); + std::vector peer_out_anchors(output_num); + for (size_t i = 0; i < output_num; i++) { + int32_t index = graph_outputs_[i].second; + auto out_iter = node_names_.find(graph_outputs_[i].first); + if (out_iter == node_names_.end()) { + error_code = GRAPH_FAILED; + error_msg = "AddNetOutputNode failed: node " + graph_outputs_[i].first + " not exist in graph."; + return; + } + NodePtr node = out_iter->second; + if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { + error_code = GRAPH_FAILED; + error_msg = "AddNetOutputNode failed: node is NULL."; + return; + } + + ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); + uint32_t update_index = i; + auto iter = output_mapping_.find(i); + if (iter != output_mapping_.end()) { + update_index = iter->second; + } + if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, update_index)) { + error_code = GRAPH_FAILED; + error_msg = "AddNetOutputNode failed: set attr PARENT_NODE_INDEX failed."; + return; + } + if (net_output_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = "AddNetOutputNode failed: add input_desc ailed."; + return; + } + peer_out_anchors[i] = node->GetOutDataAnchor(index); + } + + BuildNetOutputNodeWithLink(net_output_desc, peer_out_anchors, error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return; + } + + GELOGD("%s succ.", log_msg.c_str()); +} + +/// +/// @brief Build NetOutput nodes with data & ctrl edges +/// @param [in] net_output_desc +/// @param [in] peer_out_anchors +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void CompleteGraphBuilder::BuildNetOutputNodeWithLink(const OpDescPtr &net_output_desc, + const std::vector &peer_out_anchors, + graphStatus &error_code, std::string &error_msg) { + std::string log_msg = "AddNetOutputNode name:" + std::string(NODE_NAME_NET_OUTPUT) + ", type:" + NETOUTPUT; + NodePtr net_output = owner_graph_->AddNode(net_output_desc); + if (net_output == nullptr) { + error_code = GRAPH_FAILED; + error_msg = log_msg + " failed: add NetOutput node failed."; + return; + } + + size_t output_num = graph_outputs_.size(); + for (size_t i = 0; i < output_num; i++) { + if (GraphUtils::AddEdge(peer_out_anchors[i], net_output->GetInDataAnchor(i)) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = "AddNetOutputNode failed: add data-edge " + + peer_out_anchors[i]->GetOwnerNode()->GetName() + ":" + std::to_string(peer_out_anchors[i]->GetIdx()) + + "->" + NODE_NAME_NET_OUTPUT + ":" + std::to_string(i) + " failed."; + return; + } + } + for (const std::string &target_name : graph_targets_) { + auto target_iter = node_names_.find(target_name); + if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) { + error_code = GRAPH_FAILED; + error_msg = "BuildGraphTargets failed: target_node " + target_name + " not exist in graph."; + return; + } + const auto &target_node = target_iter->second; + if (GraphUtils::AddEdge(target_node->GetOutControlAnchor(), net_output->GetInControlAnchor()) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = "AddNetOutputNode failed: add ctrl-edge " + + target_node->GetName() + "->" + NODE_NAME_NET_OUTPUT + " failed."; + return; + } + } +} + +/// +/// @brief process after build +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void CompleteGraphBuilder::PostProcess(graphStatus &error_code, std::string &error_msg) { + if (parent_node_ != nullptr) { + owner_graph_->SetParentNode(parent_node_); + owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph()); + // ATTR_NAME_SESSION_GRAPH_ID + std::string graph_id; + if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { + error_code = GRAPH_FAILED; + error_msg = "Get attr session_graph_id failed."; + return; + } + if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { + error_code = GRAPH_FAILED; + error_msg = "Set attr session_graph_id failed."; + return; + } + } + + // refresh node name + for (const NodePtr &node : owner_graph_->GetDirectNode()) { + if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { + continue; + } + node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); + } +} + +/// +/// @brief Add node to graph +/// @param [in] op_desc +/// @return PartialGraphBuilder +/// +PartialGraphBuilder& PartialGraphBuilder::AddNode(const OpDescPtr &op_desc) { + ComputeGraphBuilder::AddNode(op_desc); + return *this; +} + +/// +/// @brief Add data-link among nodes in graph +/// @param [in] src_name +/// @param [in] out_anchor_ind +/// @param [in] dst_name +/// @param [in] in_anchor_ind +/// @return PartialGraphBuilder +/// +PartialGraphBuilder& PartialGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, + const std::string &dst_name, uint32_t in_anchor_ind) { + ComputeGraphBuilder::AddDataLink(src_name, out_anchor_ind, dst_name, in_anchor_ind); + return *this; +} + +/// +/// @brief Add ctrl-link among nodes in graph +/// @param [in] src_name +/// @param [in] dst_name +/// @return PartialGraphBuilder +/// +PartialGraphBuilder& PartialGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { + ComputeGraphBuilder::AddControlLink(src_name, dst_name); + return *this; +} + +/// +/// @brief Set owner graph +/// @param [in] graph +/// @return PartialGraphBuilder +/// +PartialGraphBuilder& PartialGraphBuilder::SetOwnerGraph(const ComputeGraphPtr &graph) { + owner_graph_ = graph; + return *this; +} + +/// +/// @brief Add exist node +/// @param [in] node +/// @return PartialGraphBuilder +/// +PartialGraphBuilder& PartialGraphBuilder::AddExistNode(const NodePtr &node) { + exist_nodes_.emplace_back(node); + return *this; +} + +/// +/// @brief Build partial graph +/// @param [out] error_code +/// @param [out] error_msg +/// @return ComputeGraphPtr +/// +ComputeGraphPtr PartialGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { + if (owner_graph_ == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "graph is NULL."; + return nullptr; + } + + BuildNodes(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildExistNodes(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildDataLinks(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + BuildCtrlLinks(error_code, error_msg); + if (error_code != GRAPH_SUCCESS) { + return nullptr; + } + + return owner_graph_; +} + +/// +/// @brief Build exist nodes +/// @param [out] error_code +/// @param [out] error_msg +/// @return void +/// +void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string &error_msg) { + std::string node_name; + for (auto &node : exist_nodes_) { + if (node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "Build exist nodes failed: node is NULL."; + return; + } + + node_name = node->GetName(); + if (node->GetOwnerComputeGraph() != owner_graph_) { + error_code = GRAPH_FAILED; + error_msg = "Build exist nodes failed: node " + node_name + " not belongs to this graph."; + return; + } + + GELOGD("Add exist_node name:%s.", node_name.c_str()); + node_names_[node_name] = node; + } + + GELOGD("Build exist nodes succ."); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::TopologicalSortingByName( + const ge::ComputeGraphPtr &compute_graph, vector &node_vec) { + std::vector stack_input; + std::map map_in_edge_num; + graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Sort nodes failed."); + return GRAPH_FAILED; + } + const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1; + std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index, + [](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); }); + + std::queue stack; + NodePtr cur_node = nullptr; + std::map name_node_map; + vector nodes_name; + while (!stack_input.empty() || !stack.empty()) { + if (!stack.empty()) { + cur_node = stack.front(); + stack.pop(); + } else { + cur_node = stack_input.back(); + stack_input.pop_back(); + } + node_vec.emplace_back(cur_node); + compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map); + for (const auto &iter : name_node_map) { + nodes_name.emplace_back(iter.first); + } + std::sort(nodes_name.begin(), nodes_name.end()); + for (const auto &iter : nodes_name) { + stack.push(name_node_map[iter]); + } + name_node_map.clear(); + nodes_name.clear(); + } + // If they are not equal, there is a closed loop + if (node_vec.size() != compute_graph->nodes_.size()) { + std::set itered_nodes_set; + for (auto &node : node_vec) { + itered_nodes_set.insert(node.get()); + } + GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", + compute_graph->nodes_.size(), node_vec.size()); + for (auto &node : compute_graph->nodes_) { + if (itered_nodes_set.count(node.get()) == 0) { + GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); + } + } + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +} // namespace ge diff --git a/metadef/graph/utils/mem_utils.h b/metadef/graph/utils/mem_utils.h new file mode 100644 index 00000000..7e8dd9fd --- /dev/null +++ b/metadef/graph/utils/mem_utils.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_UTILS_MEM_UTILS_H_ +#define COMMON_GRAPH_UTILS_MEM_UTILS_H_ + +#include +#include + +namespace ge { +template +static inline std::shared_ptr<_Tp> MakeShared(_Args &&... __args) { + typedef typename std::remove_const<_Tp>::type _Tp_nc; + std::shared_ptr<_Tp> ret(new (std::nothrow) _Tp_nc(std::forward<_Args>(__args)...)); + return ret; +} +} + +#endif // COMMON_GRAPH_UTILS_MEM_UTILS_H_ diff --git a/metadef/graph/utils/node_utils.cc b/metadef/graph/utils/node_utils.cc new file mode 100644 index 00000000..d3ea1215 --- /dev/null +++ b/metadef/graph/utils/node_utils.cc @@ -0,0 +1,1062 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "graph/utils/node_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/graph_utils.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/anchor.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/types.h" +#include "external/graph/operator.h" +#include "graph/ge_context.h" +#include "graph/runtime_inference_context.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/tensor_adapter.h" +#include "graph/utils/type_utils.h" + +namespace ge { +std::map> NodeUtils::map_send_info_{}; +std::map> NodeUtils::map_recv_info_{}; + +const std::set kConstOpTypes = { "Const", "Constant" }; + +const std::set kIfOpTypes = { "If", "_If", "StatelessIf" }; +const std::set kWhileOpTypes = { "While", "_While", "StatelessWhile" }; +const std::set kCaseOpTypes = { "Case" }; +const std::set kForOpTypes = { "For" }; + +bool OpShapeIsUnknown(const OpDescPtr &desc) { + for (const auto &ptr : desc->GetAllInputsDescPtr()) { + auto ge_shape = ptr->GetShape(); + for (const auto &dim : ge_shape.GetDims()) { + if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { + return true; + } + } + } + for (const auto &ptr : desc->GetAllOutputsDescPtr()) { + auto ge_shape = ptr->GetShape(); + for (const auto &dim : ge_shape.GetDims()) { + if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { + return true; + } + } + } + return false; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node, + const uint32_t &event_id) { + GE_CHECK_NOTNULL(node); + map_send_info_[node].push_back(event_id); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node, + const uint32_t &event_id) { + GE_CHECK_NOTNULL(node); + map_recv_info_[node].push_back(event_id); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector &vec_send) { + GE_CHECK_NOTNULL(node); + auto find = map_send_info_.find(node); + if (find == map_send_info_.end()) { + return GRAPH_FAILED; + } else { + vec_send = find->second; + return GRAPH_SUCCESS; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector &vec_recv) { + GE_CHECK_NOTNULL(node); + auto find = map_recv_info_.find(node); + if (find == map_recv_info_.end()) { + return GRAPH_FAILED; + } else { + vec_recv = find->second; + return GRAPH_SUCCESS; + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() { + map_send_info_.clear(); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() { + map_recv_info_.clear(); + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) { + GE_CHECK_NOTNULL(src); + NodePtr cur_ptr; + if (depth < 1) { + return GRAPH_FAILED; + } + for (int i = 0; i < depth; i++) { + if (src->GetOutDataNodes().size() != 1) { + return GRAPH_FAILED; + } + cur_ptr = src->GetOutDataNodes().at(0); + GE_CHECK_NOTNULL(cur_ptr); + } + dst = cur_ptr; + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, + InControlAnchorPtr &in_control) { + GE_CHECK_NOTNULL(node_ptr); + for (const auto &p : node_ptr->GetAllOutDataAnchors()) { + GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr"); + for (const auto &p_in : p->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr"); + out_data = p; + in_control = p_in; + return GRAPH_SUCCESS; + } + } + return GRAPH_FAILED; +} + +graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { + GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, + "node or in_data_anchor is nullptr"); + + bool find_flag = false; + uint32_t index = 0; + vector::iterator it = node_ptr->in_data_anchors_.end(); + for (const auto &tmp : node_ptr->in_data_anchors_) { + if (tmp == in_data_anchor) { + find_flag = true; + auto iter = node_ptr->in_data_anchors_.begin() + index; + if (iter != node_ptr->in_data_anchors_.end()) { + it = node_ptr->in_data_anchors_.erase(iter); + } + break; + } + index++; + } + for (; it != node_ptr->in_data_anchors_.end(); ++it) { + (*it)->SetIdx(index); + index++; + } + + if (!find_flag) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) { + GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr"); + GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed"); + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::SetAllAnchorStatus(Node &node) { + node.anchor_status_updated_ = true; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) { + GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr"); + return IsAnchorStatusSet(*node_ptr); +} + +bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; } + +graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) { + if ((origin_node == nullptr) || (new_node == nullptr)) { + return GRAPH_FAILED; + } + auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors(); + auto new_out_data_anchors = new_node->GetAllOutDataAnchors(); + if (origin_out_data_anchors.size() != new_out_data_anchors.size()) { + return GRAPH_FAILED; + } + + for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) { + for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, + "unlink peer_anchor failed"); + GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, + "linkto peer_anchor failed"); + } + + for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, + "unlink peer_anchor failed"); + GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, + "linkto peer_anchor failed"); + } + } + + auto origin_out_control_anchor = origin_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(origin_out_control_anchor); + auto new_out_control_anchor = new_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(new_out_control_anchor); + for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) { + GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, + "linkto peer_anchor failed"); + } + for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) { + GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, + "linkto peer_anchor failed"); + } + origin_out_control_anchor->UnlinkAll(); + + return GRAPH_SUCCESS; +} + +bool NodeUtils::IsConst(const Node &node) { + auto src_node_type = node.GetType(); + bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP)); + return is_const; +} + +void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) { + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "node is null"); + return; + } + UpdateIsInputConst(*node_ptr); +} + +/// +/// update is_input_const +/// @param node +/// @return void +/// +void NodeUtils::UpdateIsInputConst(Node &node) { + std::vector is_input_const; + size_t anchor_num = node.GetAllInDataAnchors().size(); + for (size_t i = 0; i < anchor_num; i++) { + auto in_anchor = node.GetInDataAnchor(static_cast(i)); + if (in_anchor == nullptr) { + is_input_const.push_back(false); + continue; + } + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + is_input_const.push_back(false); + continue; + } + auto src_node = peer_out_anchor->GetOwnerNode(); + if (src_node == nullptr) { + is_input_const.push_back(false); + continue; + } + if (IsConst(*(src_node))) { + is_input_const.push_back(true); + } else { + is_input_const.push_back(false); + } + } + if (node.GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr"); + return; + } + node.GetOpDesc()->SetIsInputConst(is_input_const); +} + +void NodeUtils::UnlinkAll(const Node &node) { + for (const auto &anchor : node.GetAllOutAnchors()) { + anchor->UnlinkAll(); + } + for (const auto &anchor : node.GetAllInAnchors()) { + anchor->UnlinkAll(); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) { + if (node_ptr == nullptr) { + GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); + return GRAPH_FAILED; + } + auto op_desc = node_ptr->GetOpDesc(); + if (op_desc == nullptr) { + return GRAPH_FAILED; + } + bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + if (is_unknown_graph) { + return GRAPH_SUCCESS; + } + for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + auto out_dims = output_tensor->GetShape().GetDims(); + auto out_dtype = output_tensor->GetDataType(); + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); + output_tensor->SetOriginShape(output_tensor->GetShape()); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + + GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", + node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + + for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { + if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); + continue; + } + auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); + if (peer_input_desc == nullptr) { + GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); + continue; + } + // check shape and dtype continuity. do not stop process + auto peer_input_dims = peer_input_desc->GetShape().GetDims(); + auto peer_input_dtype = peer_input_desc->GetDataType(); + if (out_dtype != peer_input_dtype) { + GELOGW("current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th " + "input_dtype is [%s].The two dtype should be same! Please check graph and fix it", + node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(), + peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), + TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str()); + } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) { + string out_shape_str, peer_in_shape_str; + out_shape_str += "["; + for (int64_t dim : out_dims) { + out_shape_str += std::to_string(dim) + " "; + } + out_shape_str += "]"; + peer_in_shape_str += "["; + for (int64_t dim : peer_input_dims) { + peer_in_shape_str += std::to_string(dim) + " "; + } + peer_in_shape_str += "]"; + + GELOGW("current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " + "input_shape is [%s].The two shape should be same! Please check graph and fix it", + node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(), + peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str()); + } + GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), + output_tensor->GetShape().GetDimNum(), output_tensor->GetDataType(), + output_tensor->GetOriginDataType()); + peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); + peer_input_desc->SetShape(output_tensor->GetShape()); + peer_input_desc->SetDataType(output_tensor->GetDataType()); + peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); + std::vector> shape_range; + (void) output_tensor->GetShapeRange(shape_range); + peer_input_desc->SetShapeRange(shape_range); + ge::TensorUtils::SetRealDimCnt(*peer_input_desc, + static_cast(output_tensor->GetShape().GetDims().size())); + GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), + peer_input_desc->GetShape().GetDimNum(), peer_input_desc->GetDataType(), + peer_input_desc->GetOriginDataType()); + } + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, uint32_t num) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Input node is null"); + return GRAPH_FAILED; + } + + GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); + const auto &op_desc = node->GetOpDesc(); + for (size_t i = op_desc->GetInputsSize(); i < num; ++i) { + if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add input desc failed"); + return GRAPH_FAILED; + } + } + + for (size_t i = node->in_data_anchors_.size(); i < num; ++i) { + auto anchor = ComGraphMakeShared(node, i); + if (anchor == nullptr) { + GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed."); + return GRAPH_FAILED; + } + node->in_data_anchors_.push_back(anchor); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, uint32_t num) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Input node is null"); + return GRAPH_FAILED; + } + + const auto &op_desc = node->GetOpDesc(); + while (op_desc->GetInputsSize() > num) { + if (!OpDescUtils::ClearInputDesc(op_desc, num)) { + return GRAPH_FAILED; + } + } + + auto input_names = op_desc->GetAllInputName(); + (void)op_desc->UpdateInputName(input_names); + auto is_input_const = op_desc->GetIsInputConst(); + is_input_const.resize(num); + op_desc->SetIsInputConst(is_input_const); + + while (node->in_data_anchors_.size() > num) { + node->in_data_anchors_.pop_back(); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node, uint32_t num) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Input node is null"); + return GRAPH_FAILED; + } + + GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); + const OpDescPtr &op_desc = node->GetOpDesc(); + for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) { + if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add output desc failed"); + return GRAPH_FAILED; + } + } + + for (size_t i = node->out_data_anchors_.size(); i < num; ++i) { + auto anchor = ComGraphMakeShared(node, i); + if (anchor == nullptr) { + GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed."); + return GRAPH_FAILED; + } + node->out_data_anchors_.push_back(anchor); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node, uint32_t num) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Input node is null"); + return GRAPH_FAILED; + } + + const auto &op_desc = node->GetOpDesc(); + auto output_names = op_desc->GetAllOutputName(); + while (op_desc->GetOutputsSize() > num) { + if (!OpDescUtils::ClearOutputDesc(op_desc, num)) { + return GRAPH_FAILED; + } + } + (void)op_desc->UpdateOutputName(output_names); + + while (node->out_data_anchors_.size() > num) { + node->out_data_anchors_.pop_back(); + } + + return GRAPH_SUCCESS; +} + +bool NodeUtils::IsInNodesEmpty(const Node &node) { + for (const auto &in_anchor : node.in_data_anchors_) { + if (in_anchor != nullptr) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor != nullptr) { + if (out_anchor->GetOwnerNode() != nullptr) { + return false; + } + } + } + } + + if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) { + auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors(); + for (const auto &out_control_anchor : peer_out_control_anchors) { + if (out_control_anchor != nullptr) { + if (out_control_anchor->GetOwnerNode() != nullptr) { + return false; + } + } + } + } + + return true; +} +GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) { + auto desc = node.GetOpDesc(); + if (desc == nullptr) { + return GeTensorDesc(); + } + return desc->GetOutputDesc(index); +} +GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) { + auto desc = node.GetOpDesc(); + if (desc == nullptr) { + return GeTensorDesc(); + } + return desc->GetInputDesc(index); +} +graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) { + auto desc = node.GetOpDesc(); + if (desc == nullptr) { + return GRAPH_PARAM_INVALID; + } + auto output_desc = desc->MutableOutputDesc(index); + if (output_desc == nullptr) { + return GRAPH_PARAM_INVALID; + } + output_desc->SetShape(shape); + return GRAPH_SUCCESS; +} +graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) { + auto desc = node.GetOpDesc(); + if (desc == nullptr) { + return GRAPH_PARAM_INVALID; + } + auto input_desc = desc->MutableInputDesc(index); + if (input_desc == nullptr) { + return GRAPH_PARAM_INVALID; + } + input_desc->SetShape(shape); + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { + auto desc = node.GetOpDesc(); + GE_CHECK_NOTNULL(desc); + // check self + is_unknow = OpShapeIsUnknown(desc); + if (is_unknow) { + return GRAPH_SUCCESS; + } + auto sub_graph_names = desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + return GRAPH_SUCCESS; + } else { + auto owner_graph = node.GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (root_graph == nullptr) { + GE_LOGE("Node %s gets null root graph", node.GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + for (auto &sub_graph_name : sub_graph_names) { + auto sub_graph = root_graph->GetSubgraph(sub_graph_name); + GE_CHECK_NOTNULL(sub_graph); + for (const auto &node_ptr : sub_graph->GetDirectNode()) { + auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); + if (status != GRAPH_SUCCESS) { + GE_LOGE("get node unknown shape status failed!"); + return status; + } + if (is_unknow) { + return GRAPH_SUCCESS; + } + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus NodeUtils::GetInputConstData(const ConstNodePtr& node_ptr, + const string &dst_name, + GeTensorPtr &ge_tensor) { + GE_CHECK_NOTNULL(node_ptr); + return NodeUtils::GetInputConstData(*node_ptr, dst_name, ge_tensor); +} + +graphStatus NodeUtils::GetInputConstData(const Node &node, + const string &dst_name, + GeTensorPtr &ge_tensor) { + // For inner compute graph + auto op_desc = node.GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto index = op_desc->GetInputIndexByName(dst_name); + auto in_data_anchor = node.GetInDataAnchor(index); + GE_CHECK_NOTNULL(in_data_anchor); + auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(out_data_anchor); + auto peer_node = out_data_anchor->GetOwnerNode(); + if (peer_node->GetType() == ENTER || peer_node->GetType() == REFENTER) { + auto enter_in_data_anchor = peer_node->GetInDataAnchor(0); + GE_CHECK_NOTNULL(enter_in_data_anchor); + auto enter_peer_out_data_anchor = enter_in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(enter_peer_out_data_anchor); + peer_node = enter_peer_out_data_anchor->GetOwnerNode(); + } + auto peer_op_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(peer_op_desc); + auto peer_op_type = peer_op_desc->GetType(); + if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { + if (!AttrUtils::MutableTensor(peer_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor)) { + GELOGW("get attr name %s failed.", ATTR_NAME_WEIGHTS.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } else if (peer_op_type == DATA) { + auto parent_node = NodeUtils::GetParentInput(peer_node); + while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { + parent_node = NodeUtils::GetParentInput(parent_node); + } + if ((parent_node != nullptr) + && ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { + if (!AttrUtils::MutableTensor(parent_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor)) { + GELOGW("get attr name %s failed.", ATTR_NAME_WEIGHTS.c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; + } + } + // Try get from runtime inference context + auto session_id = std::to_string(GetContext().SessionId()); + RuntimeInferenceContext *runtime_infer_ctx = nullptr; + if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { + GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); + auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), + out_data_anchor->GetIdx(), ge_tensor); + if (ret == GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + } + GELOGW("node[%s]'s input[%s]'s peer node is not const", node.GetName().c_str(), dst_name.c_str()); + return GRAPH_FAILED; +} + + +std::string NodeUtils::GetNodeType(const Node &node) { + if (node.GetType() != FRAMEWORKOP) { + return node.GetType(); + } + + std::string type; + (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); + return type; +} + +std::string NodeUtils::GetNodeType(const NodePtr &node) { + return node == nullptr ? "" : GetNodeType(*node); +} + +std::vector NodeUtils::GetAllSubgraphs(const Node &node) { + auto op_desc = node.GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to get op desc from node %s ", node.GetName().c_str()); + return {}; + } + auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (root_graph == nullptr) { + GELOGE(GRAPH_FAILED, "Failed to find root graph from node %s ", node.GetName().c_str()); + return {}; + } + return root_graph->GetAllSubgraphs(); +} + +ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { + auto op_desc = node.GetOpDesc(); + if (op_desc == nullptr) { + return nullptr; + } + auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (root_graph == nullptr) { + return nullptr; + } + return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); +} + +graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) { + if (subgraph == nullptr) { + GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); + return GRAPH_PARAM_INVALID; + } + auto op_desc = node.GetOpDesc(); + if (op_desc == nullptr) { + return GRAPH_PARAM_INVALID; + } + auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (root_graph == nullptr) { + GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); + if (ret != GRAPH_SUCCESS) { + GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); + return ret; + } + subgraph->SetParentNode(node.shared_from_this()); + subgraph->SetParentGraph(node.GetOwnerComputeGraph()); + return root_graph->AddSubgraph(subgraph); +} + +/// +/// Check if node is input of subgraph +/// @param [in] node +/// @return bool +/// +bool NodeUtils::IsSubgraphInput(const NodePtr &node) { + if ((node == nullptr) || (node->GetOpDesc() == nullptr) || + (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { + return false; + } + + auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); + if (parent_op_desc == nullptr) { + return false; + } + + // dynamic shape unknown graph false + // dynamic shape known graph with functional subgraph maybe true + if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { + if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { + return false; + } else { + if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { + return false; + } + } + } + + return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); +} + +/// +/// Check if node is output of subgraph +/// @param [in] node +/// @return bool +/// +bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { + if ((node == nullptr) || (node->GetOpDesc() == nullptr) || + (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { + return false; + } + + auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); + if (parent_op_desc == nullptr) { + return false; + } + + if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { + if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { + return false; + } else { + if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { + return false; + } + } + } + + for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { + if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { + return true; + } + } + + return false; +} + +/// +/// @brief Get subgraph original input node. +/// @param [in] node +/// @return Node +/// +NodePtr NodeUtils::GetParentInput(const Node &node) { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + return nullptr; + } + + // Subgraph Data Node, check for constant input. + const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); + GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + + const NodePtr &parent_node = graph->GetParentNode(); + GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr); + + const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index); + GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); + + const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); + + return peer_out_anchor->GetOwnerNode(); +} + +NodePtr NodeUtils::GetParentInput(const NodePtr &node) { + return node == nullptr ? node : GetParentInput(*node); +} + +/// +/// @brief Get is dynamic shape graph from node. +/// @param [in] node +/// @return bool +/// +bool NodeUtils::IsDynamicShape(const Node &node) { + const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (graph == nullptr) { + return false; + } + + bool is_dynamic_shape = false; + (void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); + return is_dynamic_shape; +} + +bool NodeUtils::IsDynamicShape(const NodePtr &node) { + return node == nullptr ? false : IsDynamicShape(*node); +} + +/// +/// @brief Check is varying_input for while node +/// @param [in] node: Data node for subgraph +/// @return bool +/// +bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { + if (node == nullptr) { + return false; + } + if (node->GetType() != DATA) { + return false; // not input_node for subgraph + } + + const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); + if (parent_node == nullptr) { + return false; // root graph + } + + if (kWhileOpTypes.count(parent_node->GetType()) == 0) { + return false; // not input_node for while subgraph + } + + uint32_t index_i = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { + GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); + return false; + } + bool varying_flag = true; + for (const auto &item : node->GetOutDataNodesAndAnchors()) { + if (item.first->GetType() != NETOUTPUT) { + continue; + } + OpDescPtr op_desc = item.first->GetOpDesc(); + uint32_t index_o = 0; + if ((op_desc == nullptr) || + !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { + continue; // input for while-cond subgraph + } + if (index_i != index_o) { + continue; // varying input for while-body subgraph + } + varying_flag = false; + break; + } + return varying_flag; +} + +/// +/// @brief Get subgraph input is constant. +/// @param [in] node +/// @param [out] string +/// @return bool +/// +bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) { + if (node == nullptr) { + return false; + } + + if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { + type = node->GetType(); + return true; + } + + if (node->GetType() != DATA) { + return false; // not subgraph input node + } + + const auto &parent = GetParentInput(node); + return GetConstOpType(parent, type); +} + +/// +/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. +/// @param [in] node +/// @return return GRAPH_SUCCESS if remove successfully, other for failed. +/// +Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + return GRAPH_SUCCESS; + } else { + auto owner_graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + auto root_graph = GraphUtils::FindRootGraph(owner_graph); + GE_CHECK_NOTNULL(root_graph); + + std::unordered_set subgraph_to_remove; + for (auto &subgraph_name : subgraph_names) { + std::deque queue; + queue.push_back(subgraph_name); + subgraph_to_remove.insert(subgraph_name); + op_desc->RemoveSubgraphInstanceName(subgraph_name); + while (!queue.empty()) { + auto graph_name = queue.front(); + queue.pop_front(); + + auto subgraph = root_graph->GetSubgraph(graph_name); + GE_CHECK_NOTNULL(subgraph); + for (const auto &sub_node : subgraph->GetDirectNode()) { + auto sub_op_desc = sub_node->GetOpDesc(); + GE_CHECK_NOTNULL(sub_op_desc); + auto sub_names = sub_op_desc->GetSubgraphInstanceNames(); + // Subgraph and all nodes in it will be removed later, + // no need to remove 'SubgraphInstanceName' in op desc here. + for (auto &name : sub_names) { + if (subgraph_to_remove.insert(name).second) { + queue.push_back(name); + } + } + } + } + } + // Remove subgraph from root_graph + for (const auto &name : subgraph_to_remove) { + GELOGI("Remove subgraph:%s.", name.c_str()); + root_graph->RemoveSubgraph(name); + } + } + + return GRAPH_SUCCESS; +} +/// +/// @brief Get subgraph input data node by index. +/// @param [in] node +/// @return Node +/// +vector NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { + vector in_data_node_vec; + auto op_desc = node.GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); + auto subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); + return in_data_node_vec; + } + auto compute_graph = node.GetOwnerComputeGraph(); + for (const std::string &instance_name : subgraph_names) { + auto subgraph = compute_graph->GetSubgraph(instance_name); + for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { + int parent_index = -1; + if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { + (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); + if (parent_index == index) { + in_data_node_vec.emplace_back(node_in_subgraph); + } + } + } + } + return in_data_node_vec; +} +/// +/// @brief Get subgraph input data node by index. +/// @param [in] node +/// @return Node +/// +vector NodeUtils::GetSubgraphOutputNodes(const Node &node) { + vector out_data_node_vec; + auto op_desc = node.GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); + auto subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); + return out_data_node_vec; + } + auto compute_graph = node.GetOwnerComputeGraph(); + for (const std::string &instance_name : subgraph_names) { + auto subgraph = compute_graph->GetSubgraph(instance_name); + for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { + if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { + out_data_node_vec.emplace_back(node_in_subgraph); + } + } + } + return out_data_node_vec; +} + +NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) { + if (node.GetInDataAnchor(index) == nullptr) { + return nullptr; + } + if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { + return nullptr; + } + return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); +} + +vector> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) { + vector> out_data_nodes; + auto out_data_anchor = node.GetOutDataAnchor(index); + if (out_data_anchor == nullptr) { + return out_data_nodes; + } + + for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + if (peer_in_anchor == nullptr) { + continue; + } + if (peer_in_anchor->GetOwnerNode() == nullptr) { + continue; + } + out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode())); + } + return out_data_nodes; +} + +ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { + return oprt.GetNode(); +} + +std::string NodeUtils::GetInConstNodeTypeCrossSubgraph(const NodePtr &node) { + NodePtr input_node = node; + while (input_node != nullptr) { + if (input_node->GetType() != DATA) { + return input_node->GetType(); + } + + auto owner_graph = input_node->GetOwnerComputeGraph(); + auto parent_node = owner_graph->GetParentNode(); + if ((parent_node == nullptr) || (kWhileOpTypes.count(parent_node->GetType()) > 0)) { + return node->GetType(); // not in subgraph or while subgraph. + } + + input_node = GetParentInput(input_node); + } + + return ""; +} +} // namespace ge diff --git a/metadef/graph/utils/op_desc_utils.cc b/metadef/graph/utils/op_desc_utils.cc new file mode 100644 index 00000000..dc18f1f2 --- /dev/null +++ b/metadef/graph/utils/op_desc_utils.cc @@ -0,0 +1,851 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/op_desc_utils.h" +#include +#include "debug/ge_attr_define.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" +#include "graph/anchor.h" +#include "graph/compute_graph.h" +#include "graph/ge_attr_value.h" +#include "utils/graph_utils.h" +#include "utils/node_utils.h" + +using std::vector; + +/*lint -e512 -e737 -e752*/ +namespace ge { +const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; + +namespace { +const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; +} + +bool OpDescUtils::ClearInputDesc(const NodePtr &node) { + GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); + GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr"); + vector index_list; + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + if (in_anchor->GetPeerOutAnchor() == nullptr) { + index_list.push_back(in_anchor->GetIdx()); + } + } + std::sort(index_list.begin(), index_list.end()); + // Node's in anchor index need shrink + for (size_t i = 0; i < index_list.size(); ++i) { + auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i]; + if (iter < node->GetOpDesc()->inputs_desc_.end()) { + (void)node->GetOpDesc()->inputs_desc_.erase(iter); + } else { + GELOGW("inputs_desc_ iterator out of range."); + } + } + + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc, + const uint32_t index) { + GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr"); + GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index); + + auto iter = op_desc->inputs_desc_.begin() + index; + if (iter < op_desc->inputs_desc_.end()) { + (void)op_desc->inputs_desc_.erase(iter); + } else { + GELOGW("inputs_desc_ iterator out of range."); + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) { + GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr"); + return op_desc->HasAttr(OP_DESC_QUANT_PARAMS); +} + +bool OpDescUtils::ClearOutputDesc(const NodePtr &node) { + GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); + GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr"); + vector index_list; + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + if (out_anchor->GetPeerInDataAnchors().empty()) { + index_list.push_back(out_anchor->GetIdx()); + } + } + std::sort(index_list.begin(), index_list.end()); + // Node's out anchor index need shrink + for (size_t i = 0; i < index_list.size(); ++i) { + auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i]; + if (iter < node->GetOpDesc()->outputs_desc_.end()) { + (void)node->GetOpDesc()->outputs_desc_.erase(iter); + } else { + GELOGW("outputs_desc_ iterator out of range."); + } + } + + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc, + uint32_t index) { + GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr"); + GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index); + + auto iter = op_desc->outputs_desc_.begin() + index; + if (iter < op_desc->outputs_desc_.end()) { + (void)op_desc->outputs_desc_.erase(iter); + } else { + GELOGW("outputs_desc_ iterator out of range."); + } + return true; +} + +bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) { + GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); + GeAttrValue attr_value; + GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, + "GetQuantizeFactorParams failed"); + return attr_value.GetValue(quant); +} + +graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) { + GeAttrValue attr_value; + GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, + "GetQuantizeFactorParams failed"); + return attr_value.GetValue(quant); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { + GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); + return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); // lint !e732 +} + +graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { + return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); // lint !e732 +} + +GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { + GeTensorPtr weight = nullptr; + if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) { + GELOGW("MutableTensor error"); + } + + return weight; +} + +GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) { + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "op_desc is null"); + return nullptr; + } + return MutableWeights(*op_desc); +} + +graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) { + if (weight == nullptr) { + GELOGE(GRAPH_FAILED, "weight is null"); + return GRAPH_FAILED; + } + return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED; +} + +graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) { + GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(weight); + return SetWeights(*op_desc, weight); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetWeights(const ge::Node &node) { + auto weights = MutableWeights(node); + vector ret(weights.size()); + std::copy(weights.begin(), weights.end(), ret.begin()); + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetWeights( + const ge::ConstNodePtr &node) { + if (node == nullptr) { + return vector(); + } + return GetWeights(*node); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputNode( + const ge::Node &node) { + vector ret; + auto in_anchors = node.GetAllInDataAnchors(); + for (const auto &in_anchor : in_anchors) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + // normally out_anchor could be null, this is ok + GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str()); + continue; + } + auto in_node = out_anchor->GetOwnerNode(); + while (true) { + if (in_node == nullptr) { + break; + } + if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { + ret.push_back(in_node); + break; + } else if (in_node->GetType() == DATA) { + if (NodeUtils::IsWhileVaryingInput(in_node)) { + break; + } + in_node = NodeUtils::GetParentInput(in_node); + } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { + bool is_constant = false; + (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); + if (!is_constant) { + break; + } + // Enter node has and only has one input + if (in_node->GetInDataNodes().size() != 1) { + GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), + in_node->GetInDataNodes().size()); + break; + } + in_node = in_node->GetInDataNodes().at(0); + } else { + break; + } + } + } + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetInputData( + const vector &input_nodes) { + vector ret; + + for (const auto &input_node : input_nodes) { + auto temp_weight = MutableWeights(input_node->GetOpDesc()); + if (temp_weight == nullptr) { + GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str()); + return vector(); + } + ret.push_back(temp_weight); + } + + return ret; +} +size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) { + if (NodeUtils::IsAnchorStatusSet(node)) { + size_t input_num = 0; + for (const auto &anchor : node.GetAllInDataAnchors()) { + if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { + input_num++; + continue; + } + } + return input_num; // lint !e712 + } else { + GE_IF_BOOL_EXEC( + node.GetInDataNodes().size() < GetConstInputs(node).size(), + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"GetNonConstInputsSize", "InDataNodes size[" + std::to_string(node.GetInDataNodes().size()) + + "] is smaller than ConstInputs[" + std::to_string(GetConstInputs(node).size()) + "]"}); + GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size()); + return 0); + return node.GetInDataNodes().size() - GetConstInputs(node).size(); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Node is nullptr"); + return 0; + } + return GetNonConstInputsSize(*node); +} + +GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) { + GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!"); + size_t i = 0; + if (NodeUtils::IsAnchorStatusSet(node)) { + for (const auto &anchor : node.GetAllInDataAnchors()) { + if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { + if (index_non_const == i) { + return node.GetOpDesc()->GetInputDesc(static_cast(anchor->GetIdx())); + } + ++i; + } + } + } else { + for (const auto &anchor : node.GetAllInDataAnchors()) { + auto peer_anchor = anchor->GetPeerOutAnchor(); + if (peer_anchor == nullptr) { + continue; + } + auto owner_node = peer_anchor->GetOwnerNode(); + if (owner_node == nullptr) { + continue; + } + if (owner_node->GetType() == CONSTANT) { + continue; + } + if (index_non_const == i) { + return node.GetOpDesc()->GetInputDesc(anchor->GetIdx()); + } + ++i; + } + } + return GeTensorDesc(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc +OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) { + CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc()); + return GetNonConstInputTensorDesc(*node, index_non_const); +} + +bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) { + bool ret = false; + size_t i = 0; + if (NodeUtils::IsAnchorStatusSet(node)) { + for (const auto &anchor : node.GetAllInDataAnchors()) { + if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { + if (index_non_const == i) { + index = static_cast(anchor->GetIdx()); + ret = true; + } + ++i; + } + } + } else { + for (const auto &anchor : node.GetAllInDataAnchors()) { + auto peer_anchor = anchor->GetPeerOutAnchor(); + if (peer_anchor == nullptr) { + continue; + } + auto owner_node = peer_anchor->GetOwnerNode(); + if (owner_node == nullptr) { + continue; + } + if (owner_node->GetType() == CONSTANT) { + continue; + } + if (index_non_const == i) { + index = static_cast(anchor->GetIdx()); + ret = true; + } + ++i; + } + } + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node, + size_t index_non_const, + size_t &index) { + CHECK_FALSE_EXEC(node != nullptr, return false); + return GetNonConstInputIndex(*node, index_non_const, index); +} + +bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) { + bool ret = false; + if (index < node.GetAllInDataAnchors().size()) { + if (NodeUtils::IsAnchorStatusSet(node)) { + ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast(index))) == ANCHOR_DATA); // lint !e712 + } else { + for (const auto &anchor : node.GetAllInDataAnchors()) { + if (anchor->GetIdx() != static_cast(index)) { + continue; + } + auto peer_anchor = anchor->GetPeerOutAnchor(); + if (peer_anchor == nullptr) { + break; + } + auto owner_node = peer_anchor->GetOwnerNode(); + if (owner_node == nullptr) { + break; + } + ret = (owner_node->GetType() != CONSTANT); + } + } + } + + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node, + size_t index) { + CHECK_FALSE_EXEC(node != nullptr, return false); + return IsNonConstInput(*node, index); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputs( + const ge::ConstNodePtr &node) { + if (node == nullptr) { + return vector(); + } + return GetConstInputs(*node); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetNonConstTensorDesc( + const ge::ConstNodePtr &node) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + return vector(); + } + vector ret; + if (NodeUtils::IsAnchorStatusSet(*node)) { + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { + ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); + } + } + } else { + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { + continue; + } + if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { + ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); + } + } + } + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputs(const ge::Node &node) { + vector ret; + auto in_anchors = node.GetAllInDataAnchors(); + for (const auto &in_anchor : in_anchors) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) continue; + + auto in_node = out_anchor->GetOwnerNode(); + if (in_node->GetType() == CONSTANT) { + ret.push_back(in_node); + } else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) { + // const --> switch --> matmul + auto switch_input = GetConstInputs(*in_node); + if (switch_input.size() > 0) { + ret.insert(ret.end(), switch_input.begin(), switch_input.end()); + } + } else if (in_node->GetType() == DATA) { + auto parent = NodeUtils::GetParentInput(in_node); + if ((parent != nullptr) && (parent->GetType() == CONSTANT)) { + ret.push_back(parent); + } + } + } + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::Node &node) { + vector ret; + auto op_desc = node.GetOpDesc(); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); + // Place holder operator, try to get the weight from parent node + // when parent node is const operator + if (node.GetType() == PLACEHOLDER) { + std::string parent_op; + (void) AttrUtils::GetStr(op_desc, "parentOpType", parent_op); + // This if judgment is necessary because the current subgraph optimization is multithreaded + // and the parent node of the PLD operation should be a stable type, such as const + if (parent_op == CONSTANT || parent_op == CONSTANTOP) { + NodePtr parent_node = nullptr; + parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); + if (parent_node != nullptr) { + op_desc = parent_node->GetOpDesc(); + GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); + } + } + } + // Const operator, take the weight directly + if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { + auto weight = MutableWeights(op_desc); + if (weight == nullptr) { + GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); + return ret; + } + ret.push_back(weight); + return ret; + } + + if (node.GetType() == DATA) { + auto parent = NodeUtils::GetParentInput(node); + if ((parent != nullptr) && NodeUtils::IsConst(*parent)) { + auto weight = MutableWeights(parent->GetOpDesc()); + if (weight == nullptr) { + GELOGI("const op has no weight, op name:%s", parent->GetName().c_str()); + return ret; + } + ret.push_back(weight); + } + return ret; + } + + // Other operators, get weights from connected constop + auto input_nodes = GetConstInputs(node); + for (const auto &input_node : input_nodes) { + auto temp_weight = MutableWeights(input_node->GetOpDesc()); + if (temp_weight == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"MutableWeights", "const op[" + input_node->GetName() + "]'s weight is null"}); + GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str()); + return vector(); + } + ret.push_back(temp_weight); + } + + return ret; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::NodePtr node) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Node is nullptr"); + return vector(); + } + return MutableWeights(*node); +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::SetWeights(ge::Node &node, const vector &weights) { + GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!"); + if (node.GetOpDesc()->GetType() == CONSTANT) { + if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) { + return SetWeights(node.GetOpDesc(), weights[0]); + } + GELOGI("const op weight size %zu should be 1", weights.size()); + return GRAPH_PARAM_INVALID; + } + + auto input_nodes = GetConstInputs(node); + if (weights.size() < input_nodes.size()) { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"SetWeights", "weights count can't be less than const input count"}); + GELOGE(GRAPH_FAILED, "weights count can't be less than const input count"); + return GRAPH_PARAM_INVALID; + } + + ge::GeAttrValue::NAMED_ATTRS named_attrs; + (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); + vector copy_weights; + (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); + + for (size_t i = 0; i < input_nodes.size(); ++i) { + if (input_nodes[i]->GetOpDesc() != nullptr) { + SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]); + } + } + + // If set more weights than constop, need to add constop + for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) { + // Use org weight before SetWeights Overwrite + auto const_opdesc = CreateConstOp(copy_weights[i]); + GE_CHECK_NOTNULL(const_opdesc); + + auto owner_graph = node.GetOwnerComputeGraph(); + if (owner_graph == nullptr) { + GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + auto const_node = owner_graph->AddNodeFront(const_opdesc); + GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, + GELOGE(GRAPH_FAILED, "graph add link failed!"); + return GRAPH_FAILED); + std::vector original_nodes; + ge::GraphUtils::RecordOriginalNames(original_nodes, const_node); + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::SetWeights(ge::Node &node, const map &weights_map) { + GE_CHECK_NOTNULL(node.GetOpDesc()); + // 1. node is const + if (node.GetOpDesc()->GetType() == CONSTANT) { + if (weights_map.size() == CONST_OP_NORMAL_WEIGHT_SIZE) { + return SetWeights(node.GetOpDesc(), weights_map.begin()->second); + } + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"SetWeights", "const op[" + node.GetName() + "] weight size[" + + std::to_string(weights_map.size()) + "] should be 1"}); + GELOGE(GRAPH_PARAM_INVALID, "const op %s weight size %zu should be 1", node.GetName().c_str(), weights_map.size()); + return GRAPH_PARAM_INVALID; + } + // 2. node is not const + for (const auto &pair:weights_map) { + auto in_data_anchor = node.GetInDataAnchor(pair.first); + if (in_data_anchor != nullptr && in_data_anchor->GetPeerOutAnchor() != nullptr) { + // a. update const input node + auto out_anchor = in_data_anchor->GetPeerOutAnchor(); + auto peer_node = out_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "op %s [%d]'s input node is null", node.GetName().c_str(), pair.first); + return GRAPH_PARAM_INVALID; + } + if (peer_node->GetType() != CONSTANT) { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"SetWeights", "op[" + node.GetName() + "] [" + std::to_string(pair.first) + + "]'s input node should be const, but real op is " + + peer_node->GetName() + ", type is " + peer_node->GetType()}); + GELOGE(GRAPH_PARAM_INVALID, + " op %s [%d]'s input node should be const, but is %s type:%s ", node.GetName().c_str(), + pair.first, peer_node->GetName().c_str(), peer_node->GetType().c_str()); + } + SetWeights(peer_node->GetOpDesc(), pair.second); + } else { + // b. create new const input node + auto const_opdesc = CreateConstOp(pair.second); + GE_CHECK_NOTNULL(const_opdesc); + auto owner_graph = node.GetOwnerComputeGraph(); + if (owner_graph == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", node.GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + auto const_node = owner_graph->AddNodeFront(const_opdesc); + if (node.AddLinkFrom(static_cast(pair.first), const_node) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "op %s add const to input index[%d] failed", node.GetName().c_str(), pair.first); + return GRAPH_FAILED; + } + } + } + NodeUtils::UpdateIsInputConst(node); + return GRAPH_SUCCESS; +} + +OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) { + GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!"); + shared_ptr const_opdesc = ComGraphMakeShared(); + if (const_opdesc == nullptr) { + GELOGE(GRAPH_FAILED, "failed to make_shared "); + return nullptr; + } + + CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr); + + const_opdesc->SetType(CONSTANT); + + thread_local int64_t const_count = 0; + const_opdesc->SetName("dynamic_const_" + std::to_string(GeLog::GetTid()) + "_" + std::to_string(const_count)); + GELOGI("add const op: %s", const_opdesc->GetName().c_str()); + ++const_count; + + (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc()); + + GELOGI("after add const op: %s", const_opdesc->GetName().c_str()); + + return const_opdesc; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) { + GE_CHECK_NOTNULL(in_anchor); + GE_CHECK_NOTNULL(tensor_ptr); + auto const_opdesc = CreateConstOp(tensor_ptr); + GE_CHECK_NOTNULL(const_opdesc); + auto in_node = in_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(in_node); + auto owner_graph = in_node->GetOwnerComputeGraph(); + if (owner_graph == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc); + GE_CHECK_NOTNULL(const_node); + if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) { + GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed."); + return GRAPH_PARAM_INVALID; + } + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +OpDescUtils::SetWeights(ge::NodePtr node, const vector &weights) { + GE_CHECK_NOTNULL(node); + return SetWeights(*node, weights); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) { + GE_CHECK_NOTNULL(node); + auto const_ops = GetConstInputs(node); + auto graph = node->GetOwnerComputeGraph(); + if (graph == nullptr) { + GELOGE(GRAPH_FAILED, "Graph is nullptr"); + return GRAPH_PARAM_INVALID; + } + for (const auto &const_op : const_ops) { + GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed", + const_op->GetName().c_str(), const_op->GetType().c_str()); + GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op), + "Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(), + const_op->GetType().c_str()); + } + return GRAPH_SUCCESS; +} + +/// +/// @brief Add input +/// @param [in] name +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddInput(const std::string &name) { + inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); + return *this; +} + +/// +/// @brief Add input +/// @param [in] name +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +OpDescBuilder& OpDescBuilder::AddInput(const std::string &name, const GeTensorDesc &tensor) { + inputs_.emplace_back(std::make_pair(name, tensor)); + return *this; +} + +/// +/// @brief Add dynamic input +/// @param [in] name +/// @param [in] num +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddDynamicInput(const std::string &name, + uint32_t num) { + for (uint32_t i = 0; i < num; i++) { + inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); + } + return *this; +} + +/// +/// @brief Add dynamic input +/// @param [in] name +/// @param [in] num +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +OpDescBuilder& OpDescBuilder::AddDynamicInput(const std::string &name, uint32_t num, const GeTensorDesc &tensor) { + for (uint32_t i = 0; i < num; i++) { + inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); + } + return *this; +} + +/// +/// @brief Add output +/// @param [in] name +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddOutput(const std::string &name) { + outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); + return *this; +} + +/// +/// @brief Add output +/// @param [in] name +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +OpDescBuilder& OpDescBuilder::AddOutput(const std::string &name, const GeTensorDesc &tensor) { + outputs_.emplace_back(std::make_pair(name, tensor)); + return *this; +} + +/// +/// @brief Add dynamic output +/// @param [in] name +/// @param [in] num +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddDynamicOutput(const std::string &name, + uint32_t num) { + for (uint32_t i = 0; i < num; i++) { + outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); + } + return *this; +} + +/// +/// @brief Add dynamic output +/// @param [in] name +/// @param [in] num +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +OpDescBuilder& OpDescBuilder::AddDynamicOutput(const std::string &name, uint32_t num, const GeTensorDesc &tensor) { + for (uint32_t i = 0; i < num; i++) { + outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); + } + return *this; +} + +/// +/// @brief Build op_desc +/// @return OpDescPtr +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { + OpDescPtr op_desc = shared_ptr(new (std::nothrow) OpDesc(name_, type_)); + if (op_desc == nullptr) { + GELOGE(GRAPH_FAILED, "OpDesc is nullptr"); + return nullptr; + } + + for (auto &input : inputs_) { + if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add input_desc failed."); + return nullptr; + } + } + + for (auto &output : outputs_) { + if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add output_desc failed."); + return nullptr; + } + } + + return op_desc; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus OpDescUtils::SetSubgraphInstanceName(const std::string &subgraph_name, + const std::string &subgraph_instance_name, + OpDescPtr &op_desc) { + const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); + auto iter = subgraph_names_to_index.find(subgraph_name); + if (iter == subgraph_names_to_index.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"SetSubgraphInstanceName", "subgraph name[" + subgraph_name + "] is not exists." + "The op is " + op_desc->GetName() + ", type is " + op_desc->GetType() + + ", subgraph is " + subgraph_instance_name}); + GELOGE(GRAPH_PARAM_INVALID, + "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists", + subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), subgraph_name.c_str()); + return GRAPH_PARAM_INVALID; + } + + return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); +} +} // namespace ge +/*lint +e512 +e737 +e752*/ diff --git a/metadef/graph/utils/string_utils.h b/metadef/graph/utils/string_utils.h new file mode 100644 index 00000000..e82afb1a --- /dev/null +++ b/metadef/graph/utils/string_utils.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_UTILS_STRING_UTILS_H_ +#define COMMON_GRAPH_UTILS_STRING_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include "securec.h" + +namespace ge { +class StringUtils { + public: + static std::string &Ltrim(std::string &s) { + (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); + return s; + } + + static std::string &Rtrim(std::string &s) { + (void)s.erase(std::find_if(s.rbegin(), s.rend(), [](int c) { return !std::isspace(c); }).base(), s.end()); + return s; + } + + /// @ingroup domi_common + /// @brief trim space + static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); } + + // split string + static std::vector Split(const std::string &str, char delim) { + std::vector elems; + + if (str.empty()) { + elems.emplace_back(""); + return elems; + } + + std::stringstream ss(str); + std::string item; + + while (getline(ss, item, delim)) { + elems.push_back(item); + } + auto str_size = str.size(); + if (str_size > 0 && str[str_size - 1] == delim) { + elems.emplace_back(""); + } + + return elems; + } +}; +} // namespace ge +#endif // COMMON_GRAPH_UTILS_STRING_UTILS_H_ diff --git a/metadef/graph/utils/tensor_utils.cc b/metadef/graph/utils/tensor_utils.cc new file mode 100644 index 00000000..a6c77ab5 --- /dev/null +++ b/metadef/graph/utils/tensor_utils.cc @@ -0,0 +1,424 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/utils/tensor_utils.h" +#include + +#include "debug/ge_log.h" +#include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" +#include "graph/ge_tensor.h" +#include "graph/types.h" +#include "graph/utils/type_utils.h" +#include "mmpa/mmpa_api.h" + +namespace ge { +namespace { +// When nc1hwc0 dim size = 5, calc element count directly. +const uint32_t kNc1hwc0CalcByDimsSize = 5; + +// Unknown shape element num +const int64_t kElementCntUnknownShape = -1; + +// Unknown shape mem size +const int64_t kMemSizeUnknownShape = -1; + +// Nchw and nhwc dim size must be 4 +const uint32_t kDimSize4d = 4; + +// C1HWNCoC0 dim size must be 6 +const uint32_t kDimSizeC1hwncoc0 = 6; + +// Cube size is 16 +const uint32_t kTheCubeSize = 16; + +// Default c0 size equals cube size. +const uint32_t kC0SizeDefault = kTheCubeSize; + +// Size equals int8 cube size is 32 +const uint32_t kC0SizeInt8 = 32; + +// NCHW dim N index +const int32_t kNchwDimIdxN = 0; +// NCHW dim C index +const int32_t kNchwDimIdxC = 1; +// NCHW dim H index +const int32_t kNchwDimIdxH = 2; +// NCHW dim W index +const int32_t kNchwDimIdxW = 3; + +const int kDataMemAlignSize = 32; +const int kNum2 = 2; +} // namespace + +/// +/// Check if a * b overflow. +/// @param a multiplier +/// @param b Multiplicand +/// @return true: overflow +/// false: not overflow +/// +static bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) { + if (a > 0) { + if (b > 0) { + if (a > (INT64_MAX / b)) { + return true; + } + } else { + if (b < (INT64_MIN / a)) { + return true; + } + } + } else { + if (b > 0) { + if (a < (INT64_MIN / b)) { + return true; + } + } else { + if ((a != 0) && (b < (INT64_MAX / a))) { + return true; + } + } + } + return false; +} + +/// +/// Calculate element num by dims directly. +/// @param dims dim info +/// @param element_cnt element count +/// @return GRAPH_SUCCESS:success +/// other:failed +/// +static graphStatus CalcElementCntByDims(const std::vector &dims, int64_t &element_cnt) { + element_cnt = 1; + for (int64_t dim : dims) { + if (CheckMultiplyOverflowInt64(element_cnt, dim)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19013", {"function", "var1", "var2"}, + {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); + GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); + return GRAPH_FAILED; + } + element_cnt *= dim; + } + return GRAPH_SUCCESS; +} + +/// +/// Calculate fixed dims element num. +/// @param dims dim info +/// @param fixed_dim_size fixed dim size +/// @param element_cnt element count +/// @return GRAPH_SUCCESS:success +/// other:failed +/// +static graphStatus CalcElementCntOfFixedDims(const std::vector &dims, Format format, uint32_t fixed_dim_size, + int64_t &element_cnt) { + if (dims.size() != fixed_dim_size) { + GELOGW("Format %d(%s) need dim size=%u but %zu, calc as ND.", + format, TypeUtils::FormatToSerialString(format).c_str(), fixed_dim_size, dims.size()); + } + return CalcElementCntByDims(dims, element_cnt); +} + +/// +/// Get dim c0 size by type +/// @param data_type data type +/// @return c0 size +/// +static uint32_t GetDimC0(DataType &data_type) { + bool is_int8_size = (data_type == DT_INT8) || (data_type == DT_UINT8) || (data_type == DT_DUAL_SUB_UINT8) || + (data_type == DT_DUAL_SUB_INT8) || (data_type == DT_BOOL) || (data_type == DT_QINT8); + return is_int8_size ? kC0SizeInt8 : kC0SizeDefault; +} + +/// +/// Calculate nc1hwc0 element num. +/// @param dims dim info +/// @param data_type data type +/// @param element_cnt element count +/// @return GRAPH_SUCCESS:success +/// other:failed +/// +static graphStatus CalcElementCntOfNc1hwc0(const std::vector &dims, DataType data_type, int64_t &element_cnt) { + // When nc1hwc0 dims size = 5, no need split dim c + if (dims.size() == kNc1hwc0CalcByDimsSize) { + return CalcElementCntByDims(dims, element_cnt); + } else if (dims.size() != kDimSize4d) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19012", {"function", "reason"}, + {"CalcElementCntOfNc1hwc0", "dims.size[" + std::to_string(dims.size()) + "] is not " + + std::to_string(kDimSize4d) + " or " + std::to_string(kNc1hwc0CalcByDimsSize)}); + GELOGE(GRAPH_FAILED, "CalcElementCntOfNc1hwc0 failed as dims.size=%zu is not %u or %u.", dims.size(), kDimSize4d, + kNc1hwc0CalcByDimsSize); + return GRAPH_FAILED; + } + + auto c0 = static_cast(GetDimC0(data_type)); + // Nc1hwc0 dims is according to nchw, dim c index is 1. + auto c1 = static_cast(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0)); + // Store dims is split c to c1 and c0. + std::vector store_dims = {dims[kNchwDimIdxN], c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0}; + return CalcElementCntByDims(store_dims, element_cnt); +} + +/// +/// Calculate FractalZ element num. +/// @param dims dim info +/// @param data_type data type +/// @param element_cnt element count +/// @return GRAPH_SUCCESS:success +/// other:failed +/// +static graphStatus CalcElementCntOfFractalZ(const std::vector &dims, DataType data_type, + int64_t &element_cnt) { + static char parser_priority[MMPA_MAX_PATH] = { 0x00 }; + INT32 res = mmGetEnv("PARSER_PRIORITY", parser_priority, MMPA_MAX_PATH); + if (res == EN_OK && string(parser_priority) == "cce") { + if (dims.size() != kDimSize4d) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19012", {"function", "reason"}, + {"CalcElementCntOfFractalZ", + "dims.size[" + std::to_string(dims.size()) + "] is not " + std::to_string(kDimSize4d)}); + GELOGE(GRAPH_FAILED, "CalcElementCntOfFractalZ failed as dims.size=%zu is not %u.", dims.size(), kDimSize4d); + return GRAPH_FAILED; + } + auto c0 = static_cast(GetDimC0(data_type)); + // FractalZ dims is according to nchw, dim c index is 1. + auto c1 = static_cast(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0)); + + // Spread NC1HWC0 as a two dimension array, n as column dimension, + // C1HWC0 as row dimension + std::vector r_count_vec = {c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0}; + + int64_t r_count = 1; + graphStatus graph_status = CalcElementCntByDims(r_count_vec, r_count); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(graph_status, "Calc [%ld, %ld, %ld, %ld] element count failed.", + c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0); + return graph_status; + } + + // Cube count in n + auto nc_cnt = static_cast(std::ceil(dims[kNchwDimIdxN] * 1.0 / kTheCubeSize)); + + // Cube count in vertical direction(C1HWC0) + int64_t vc_cnt = r_count / c0; + // Element count in each cube + int64_t cube_elem_cnt = c0 * kTheCubeSize; + + if (CheckMultiplyOverflowInt64(nc_cnt, vc_cnt)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19013", {"function", "var1", "var2"}, + {"CheckMultiplyOverflowInt64", std::to_string(nc_cnt), std::to_string(vc_cnt)}); + GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", nc_cnt, vc_cnt); + return GRAPH_FAILED; + } + // Read data times needed by cube + int64_t c_cnt = nc_cnt * vc_cnt; + + if (CheckMultiplyOverflowInt64(c_cnt, cube_elem_cnt)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19013", {"function", "var1", "var2"}, + {"CheckMultiplyOverflowInt64", std::to_string(c_cnt), std::to_string(cube_elem_cnt)}); + GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", c_cnt, cube_elem_cnt); + return GRAPH_FAILED; + } + // Element count after fractal arrangement + element_cnt = c_cnt * cube_elem_cnt; + return GRAPH_SUCCESS; + } else { + return CalcElementCntByDims(dims, element_cnt); + } +} + +/// +/// Calculate tensor element num. +/// @param dims dim info +/// @param format tensor format +/// @param data_type data type +/// @param element_cnt element count +/// @return GRAPH_SUCCESS:success +/// other:failed +/// +static graphStatus CalcTensorElementCnt(const std::vector &dims, Format format, DataType data_type, + int64_t &element_cnt) { + const string format_str = TypeUtils::FormatToSerialString(format); + // Check dims + for (size_t i = 0; i < dims.size(); ++i) { + int64_t dim = dims[i]; + if (dim < 0) { + GELOGI("It's unknown shape, as dims[%zu]=%ld negative, format=%d(%s).", i, dim, format, format_str.c_str()); + element_cnt = kElementCntUnknownShape; + return GRAPH_SUCCESS; + } else if (dim == 0) { + GELOGI("No need calc element count, as dims[%zu]=%ld, format=%d(%s).", i, dim, format, format_str.c_str()); + element_cnt = 0; + return GRAPH_SUCCESS; + } + } + + graphStatus graph_status; + switch (format) { + case FORMAT_ND: + case FORMAT_MD: + graph_status = CalcElementCntByDims(dims, element_cnt); + break; + case FORMAT_NCHW: + case FORMAT_HWCN: + case FORMAT_NHWC: + case FORMAT_CHWN: + graph_status = CalcElementCntOfFixedDims(dims, format, kDimSize4d, element_cnt); + break; + case FORMAT_C1HWNCoC0: + graph_status = CalcElementCntOfFixedDims(dims, format, kDimSizeC1hwncoc0, element_cnt); + break; + case FORMAT_NC1HWC0: + graph_status = CalcElementCntOfNc1hwc0(dims, data_type, element_cnt); + break; + case FORMAT_FRACTAL_Z: + graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); + break; + case FORMAT_FRACTAL_NZ: + case FORMAT_FRACTAL_ZZ: + case FORMAT_NDHWC: + case FORMAT_NCDHW: + case FORMAT_DHWCN: + case FORMAT_DHWNC: + case FORMAT_FRACTAL_Z_3D: + case FORMAT_FRACTAL_Z_3D_TRANSPOSE: + case FORMAT_NDC1HWC0: + case FORMAT_FRACTAL_Z_C04: + case FORMAT_FRACTAL_ZN_LSTM: + case FORMAT_NC1HWC0_C04: + graph_status = CalcElementCntByDims(dims, element_cnt); + break; + default: + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"CalcTensorElementCnt", "format[" + format_str + "] is not support"}); + GELOGE(GRAPH_FAILED, "unsupported format, format=%d(%s).", format, format_str.c_str()); + graph_status = GRAPH_FAILED; + break; + } + + const string type_str = TypeUtils::DataTypeToSerialString(data_type); + if (graph_status == GRAPH_SUCCESS) { + GELOGD( + "CalcTensorElementCnt end, format=%d(%s)," + " data_type=%d(%s), element_cnt=%ld.", + format, format_str.c_str(), data_type, type_str.c_str(), element_cnt); + } else { + GELOGE(GRAPH_FAILED, "CalcTensorElementCnt failed, format=%d(%s), data_type=%d(%s).", + format, format_str.c_str(), data_type, type_str.c_str()); + } + return graph_status; +} + +/// +/// Calculate tensor mem size. +/// @param shape tensor shape +/// @param format tensor format +/// @param data_type tensor data type +/// @param mem_size -1 means unknown shape,other means mem size +/// @return GRAPH_SUCCESS:success, other:failed +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTensorMemSize(const GeShape &shape, + Format format, + DataType data_type, + int64_t &mem_size) { + const string format_str = TypeUtils::FormatToSerialString(format); + const string type_str = TypeUtils::DataTypeToSerialString(data_type); + uint32_t type_size = 0; + bool result = TypeUtils::GetDataTypeLength(data_type, type_size); + if (!result) { + GELOGE(GRAPH_FAILED, "GetDataTypeLength failed, data_type=%d(%s).", data_type, type_str.c_str()); + return GRAPH_FAILED; + } + + std::vector dims = shape.GetDims(); + int64_t element_cnt = 0; + graphStatus status = CalcTensorElementCnt(dims, format, data_type, element_cnt); + if (status != GRAPH_SUCCESS) { + GELOGE(status, "CalcTensorElementCnt failed, status=%u format=%d(%s) data_type=%d(%s).", + status, format, format_str.c_str(), data_type, type_str.c_str()); + return status; + } + // Support unknown shape + if (element_cnt < 0) { + mem_size = kMemSizeUnknownShape; + GELOGD( + "element_cnt is unknown. " + "format=%d(%s), data_type=%d(%s), mem_size=%ld", + format, format_str.c_str(), data_type, type_str.c_str(), mem_size); + return GRAPH_SUCCESS; + } + auto type_size_int64 = static_cast(type_size); + if (CheckMultiplyOverflowInt64(element_cnt, type_size_int64)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19013", {"function", "var1", "var2"}, + {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(type_size_int64)}); + GELOGE(GRAPH_FAILED, "CalcTensorMemSize overflow, when multiplying %ld and %ld, format=%d(%s), data_type=%d(%s).", + element_cnt, type_size_int64, format, format_str.c_str(), data_type, type_str.c_str()); + return GRAPH_FAILED; + } + mem_size = element_cnt * type_size_int64; + + GELOGD( + "CalcTensorMemSize end, " + "format=%d(%s), data_type=%d(%s), mem_size=%ld", + format, format_str.c_str(), data_type, type_str.c_str(), mem_size); + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { + graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); + if (graph_status != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + // 64-byte alignment, if size is 0, align to 32 bytes + if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) { + GELOGW("The updated mem size %ld is bigger than INT64_MAX",size_temp); + } else { + size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; + } + return GRAPH_SUCCESS; +} +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { + GeShape output_shape = desc_temp.GetShape(); + Format format = desc_temp.GetFormat(); + DataType data_type = desc_temp.GetDataType(); + int64_t output_mem_size = 0; + graphStatus graph_status = CalcTensorMemSize(output_shape, format, data_type, output_mem_size); + if (graph_status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "CalcTensorMemSize failed!"); + return GRAPH_FAILED; + } + + if (output_mem_size < 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"GetTensorSizeInBytes", "output_mem_size is out of data range [0, " + std::to_string(INT64_MAX) + "]"}); + GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]", + output_mem_size, INT64_MAX); + return GRAPH_FAILED; + } + + size_temp = output_mem_size; + return GRAPH_SUCCESS; +} +} // namespace ge diff --git a/metadef/graph/utils/transformer_utils.cc b/metadef/graph/utils/transformer_utils.cc new file mode 100644 index 00000000..17f166fd --- /dev/null +++ b/metadef/graph/utils/transformer_utils.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transformer_utils.h" + +#include "external/ge/ge_api_types.h" +#include "framework/common/debug/ge_log.h" +#include "graph/utils/type_utils.h" + +namespace ge { +bool NodeShapeTransUtils::CatchFormatAndShape() { + auto inputs = op_desc_->MutableAllInputName(); + auto outputs = op_desc_->MutableAllOutputName(); + + for (auto &ele : inputs) { + auto tensor_desc_input = op_desc_->MutableInputDesc(ele.first); + if (tensor_desc_input == nullptr) { + continue; + } + auto format = tensor_desc_input->GetFormat(); + auto ori_format = tensor_desc_input->GetOriginFormat(); + if (format == ori_format) { + GELOGD("Node is %s, input tensor name is %s. ori format: %s, format: %s is same! No need to catch format&shape!", + op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(ori_format).c_str(), + TypeUtils::FormatToSerialString(format).c_str()); + continue; + } + map_format_in_.insert(std::pair(ele.first, format)); + map_ori_format_in_.insert(std::pair(ele.first, ori_format)); + map_dtype_in_.insert(std::pair(ele.first, tensor_desc_input->GetDataType())); + tensor_desc_input->SetFormat(ori_format); + tensor_desc_input->SetShape(tensor_desc_input->GetOriginShape()); + } + + for (auto &ele : outputs) { + auto tensor_desc_output = op_desc_->MutableOutputDesc(ele.first); + if (tensor_desc_output == nullptr) { + continue; + } + auto format = tensor_desc_output->GetFormat(); + auto ori_format = tensor_desc_output->GetOriginFormat(); + if (format == ori_format) { + GELOGD("Node is %s, output tensor name is %s. ori format: %s, format: %s is same! No need to catch format&shape!", + op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(ori_format).c_str(), + TypeUtils::FormatToSerialString(format).c_str()); + continue; + } + map_format_out_.insert(std::pair(ele.first, format)); + map_ori_format_out_.insert(std::pair(ele.first, ori_format)); + map_dtype_out_.insert(std::pair(ele.first, tensor_desc_output->GetDataType())); + + if (format == ori_format) { + continue; + } + tensor_desc_output->SetFormat(ori_format); + } + + return true; +} + +bool NodeShapeTransUtils::UpdateFormatAndShape() { + auto inputs = op_desc_->MutableAllInputName(); + auto outputs = op_desc_->MutableAllOutputName(); + + for (auto &ele : inputs) { + auto tensor_desc_input = op_desc_->MutableInputDesc(ele.first); + if (tensor_desc_input == nullptr) { + continue; + } + // if can not find saved info, it says format and origin format is same when catched + if (map_format_in_.find(ele.first) == map_format_in_.end()) { + GELOGD("Node is [%s], input tensor name [%s] is not been catched.Skip update action for it!", + op_desc_->GetName().c_str(), ele.first.c_str()); + tensor_desc_input->SetOriginFormat(tensor_desc_input->GetFormat()); + tensor_desc_input->SetOriginShape(tensor_desc_input->GetShape()); + continue; + } + auto ori_format = tensor_desc_input->GetFormat(); + auto ori_shape = tensor_desc_input->GetShape(); + auto curr_format = map_format_in_[ele.first]; + if (ori_format == curr_format) { + continue; + } + std::unique_ptr shape_transfer(new(std::nothrow) + common::transformer::ShapeTransferAccordingToFormat()); + if (shape_transfer == nullptr) { + GELOGE(GRAPH_FAILED, "Memory alloc failed"); + return false; + } + std::vector ori_shape_dims = ori_shape.GetDims(); + std::vector out_dims; + ge::DataType dtype = map_dtype_in_[ele.first]; + common::transformer::ShapeAndFormat shape_and_format_info {ori_shape_dims, out_dims, ori_format, curr_format, dtype, + common::transformer::EN_IMPL_CUSTOM_TBE}; + shape_transfer->GetShapeAccordingToFormat(shape_and_format_info); + tensor_desc_input->SetFormat(curr_format); + tensor_desc_input->SetShape(GeShape(out_dims)); + } + + for (auto &ele : outputs) { + auto tensor_desc_output = op_desc_->MutableOutputDesc(ele.first); + if (tensor_desc_output == nullptr) { + continue; + } + // if can not find saved info, it says format and origin format is same when catched + if (map_ori_format_out_.find(ele.first) == map_ori_format_out_.end()) { + GELOGD("Node is [%s], input tensor name [%s] is not been catched.Skip update action for it!", + op_desc_->GetName().c_str(), ele.first.c_str()); + tensor_desc_output->SetOriginFormat(tensor_desc_output->GetFormat()); + tensor_desc_output->SetOriginShape(tensor_desc_output->GetShape()); + continue; + } + auto ori_shape = tensor_desc_output->GetShape(); + auto curr_format = tensor_desc_output->GetFormat(); + if (curr_format != map_ori_format_out_[ele.first]) { + GELOGE(GRAPH_FAILED, "Node is %s, out tensor name is %s. format: %s, recorded origin format: %s is not same", + op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(curr_format).c_str(), + TypeUtils::FormatToSerialString(map_ori_format_out_[ele.first]).c_str()); + return GRAPH_FAILED; + } + tensor_desc_output->SetOriginShape(ori_shape); + auto saved_format = map_format_out_[ele.first]; + if (curr_format == saved_format) { + GELOGD("Nodeis %s, out tensor name is %s. ori format: %s, recorded format: %s is same! No need to transfer", + op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(curr_format).c_str(), + TypeUtils::FormatToSerialString(saved_format).c_str()); + continue; + } + tensor_desc_output->SetFormat(saved_format); + std::unique_ptr shape_transfer(new(std::nothrow) + common::transformer::ShapeTransferAccordingToFormat()); + if (shape_transfer == nullptr) { + GELOGE(GRAPH_FAILED, "Memory alloc failed"); + return false; + } + std::vector ori_shape_dims = ori_shape.GetDims(); + std::vector out_dims; + ge::DataType dtype = tensor_desc_output->GetDataType(); + common::transformer::ShapeAndFormat shape_and_format_info {ori_shape_dims, out_dims, curr_format, saved_format, + dtype, common::transformer::EN_IMPL_CUSTOM_TBE}; + shape_transfer->GetShapeAccordingToFormat(shape_and_format_info); + tensor_desc_output->SetShape(GeShape(out_dims)); + GELOGD("Node is %s, out tensor name is %s. Update format and shape success,ori format: %s, format: %s", + op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(curr_format).c_str(), + TypeUtils::FormatToSerialString(saved_format).c_str()); + } + GELOGD("Node is %s. Update format and shape success", op_desc_->GetName().c_str()); + return true; +} +} // namespace ge \ No newline at end of file diff --git a/metadef/graph/utils/transformer_utils.h b/metadef/graph/utils/transformer_utils.h new file mode 100644 index 00000000..10bff82c --- /dev/null +++ b/metadef/graph/utils/transformer_utils.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_ +#define COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_ +#include +#include + +#include "external/graph/types.h" +#include "graph/op_desc.h" +#include "graph/ge_tensor.h" +#include "transformer/inc/transfer_shape_according_to_format.h" + +namespace ge { +class NodeShapeTransUtils { + public: + bool CatchFormatAndShape(); + bool UpdateFormatAndShape(); + + explicit NodeShapeTransUtils(OpDescPtr op_desc) : op_desc_(op_desc) { + } + + ~NodeShapeTransUtils() { + } + + private: + std::map map_format_in_; + std::map map_ori_format_in_; + std::map map_dtype_in_; + std::map map_format_out_; + std::map map_ori_format_out_; + std::map map_dtype_out_; + + OpDescPtr op_desc_; +}; +} // namespace ge +#endif // COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_ \ No newline at end of file diff --git a/metadef/graph/utils/tuning_utils.cc b/metadef/graph/utils/tuning_utils.cc new file mode 100644 index 00000000..d591910c --- /dev/null +++ b/metadef/graph/utils/tuning_utils.cc @@ -0,0 +1,794 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/tuning_utils.h" +#include "../debug/ge_util.h" +#include "../debug/ge_op_types.h" +#include "framework/common/scope_guard.h" + +namespace ge { +namespace { +const std::string peer_node_name_attr = "_peerNodeName"; +const std::string parent_node_name_attr = "_parentNodeName"; +const std::string alias_name_attr = "_aliasName"; +const std::string parent_node_attr = "parentNode"; +const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex"; +const std::string tuning_subgraph_prefix = "/aicore_subgraph_"; +const std::string non_tuning_subgraph_prefix = "/subgraph_"; +const std::set kPartitionOpTypes = {PLACEHOLDER, END}; +const std::set kExeTypes = {DATA, NETOUTPUT}; +} +NodeNametoNodeNameMap TuningUtils::data_2_netoutput_; +NodetoNodeNameMap TuningUtils::data_node_2_netoutput_ ; +NodetoNodeMap TuningUtils::data_node_2_netoutput_node_; +NodeVec TuningUtils::netoutput_nodes_; +NodeVec TuningUtils::merged_graph_nodes_; +SubgraphCreateOutNode TuningUtils::create_output_; +std::mutex TuningUtils::mutex_; + +std::string TuningUtils::PrintCheckLog() { + std::stringstream ss; + ss << "d2n:{"; + for (const auto &pair : data_2_netoutput_) { + ss << "data:" << pair.first << "-" << "netoutput:" << pair.second; + ss << " | "; + } + ss << "}"; + ss << "netoutputs:{"; + for (const auto &node : netoutput_nodes_) { + ss << "netoutput:" << node->GetName(); + ss << " | "; + } + ss << "}"; + return ss.str(); +} + +std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) { + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Anchor is nullptr"); + return "Null"; + } + auto node = anchor->GetOwnerNode(); + return node == nullptr ? "Null" : node->GetName(); +} + +// part 1 +graphStatus TuningUtils::ConvertGraphToFile(std::vector tuning_subgraphs, + std::vector non_tuning_subgraphs, + bool exe_flag, const std::string &path, const std::string &user_path) { + int64_t i = 0; + int64_t j = 0; + std::lock_guard lock(mutex_); + for (auto &subgraph : tuning_subgraphs) { + create_output_.emplace(subgraph, nullptr); + auto help_info = HelpInfo{i, exe_flag, true, path, user_path}; + if (MakeExeGraph(subgraph, help_info) != SUCCESS) { + GELOGE(GRAPH_FAILED, "TUU:subgraph %zu generate exe graph failed", i); + return GRAPH_FAILED; + } + i++; + } + + for (auto &subgraph : non_tuning_subgraphs) { + create_output_.emplace(subgraph, nullptr); + auto help_info = HelpInfo{j, true, false, path, user_path}; + if (MakeExeGraph(subgraph, help_info) != SUCCESS) { + GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %zu generate exe graph failed", j); + return GRAPH_FAILED; + } + j++; + } + create_output_.clear(); + return SUCCESS; +} + +// +---------------+ +// | pld pld | +// | \ / | +// | relu relu | +// | \ / | +// | add | +// | | | +// | end | +// +---------------+ +// | +// | +// V +// +---------------+ +// | data data | +// | \ / | +// | relu relu | +// | \ / | +// | add | +// | | | +// | netoutput | +// +---------------+ +graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, + const HelpInfo& help_info) { + GE_CHECK_NOTNULL(exe_graph); + graphStatus ret = exe_graph->TopologicalSortingGraph(true); + if (ret != SUCCESS) { + GraphUtils::DumpGEGraphToOnnx(*exe_graph, "black_box"); + GELOGE(ret, "Graph[%s] topological sort failed, saved to file black_box ret:%d.", + exe_graph->GetName().c_str(), ret); + return ret; + } + // clear graph id + GELOGI("TUU:clear [%s] session_graph_id %s", exe_graph->GetName().c_str(), + (AttrUtils::SetStr(*exe_graph, ATTR_NAME_SESSION_GRAPH_ID, "") ? "success" : "not success")); + // if not make exe, just dump and return + if (!help_info.exe_flag) { + DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); + GELOGI("TUU:just return, dump original sub_graph[%s]index[%d]", exe_graph->GetName().c_str(), help_info.index); + return SUCCESS; + } + // modify sub graph + for (NodePtr &node : exe_graph->GetDirectNode()) { + // 1.handle pld + if (node->GetType() == PLACEHOLDER) { + if (HandlePld(node) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), + exe_graph->GetName().c_str()); + return FAILED; + } + } + // 2.handle end + if (node->GetType() == END) { + if (HandleEnd(node) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), + exe_graph->GetName().c_str()); + return FAILED; + } + } + } + ret = exe_graph->TopologicalSortingGraph(true); + if (ret != SUCCESS) { + GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret); + return ret; + } + // dump subgraphs which modified by us + if (help_info.user_path.empty()) { + DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); + } else { + GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path); + } + return SUCCESS; +} + +void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, + bool is_tuning_graph, std::string path) { + if (!path.empty()) { + if (is_tuning_graph) { + GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); + } else { + GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); + } + } else { + path = "./"; + if (is_tuning_graph) { + GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); + } else { + GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); + } + } +} + +graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) { + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + auto data_op_desc = ComGraphMakeShared(node->GetName(), DATA); + GE_CHECK_NOTNULL(data_op_desc); + auto pld_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(pld_op_desc); + auto output_desc = pld_op_desc->GetOutputDesc(0); // only one output for pld and data + // data inputdesc & outputdesc set as same + if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) { + GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); + return FAILED; + } + if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) { + GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); + return FAILED; + } + data_node = graph->AddNode(data_op_desc); + GE_CHECK_NOTNULL(data_node); + if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); + return FAILED; + } + return SUCCESS; +} + +graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) { + auto op_desc = data_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + auto pld_desc = pld->GetOpDesc(); + GE_CHECK_NOTNULL(pld_desc); + // inherit + // a. set `end's input node type` as attr + std::string parent_op_type; + if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) { + GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str()); + return FAILED; + } + (void) AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type); + // b. set `end's input node name` as attr + std::string parent_op_name; + if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) { + GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str()); + return FAILED; + } + (void) AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name); + // c. set `end's input node's out anchor index` as attr + int parent_node_anchor_index; + if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) { + GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str()); + return FAILED; + } + (void) AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index); + GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", + pld->GetName().c_str(), pld->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); + // d. set `end node name` as attr + std::string peer_end_name; + if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) { + GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str()); + return FAILED; + } + (void) AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name); + GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", + pld->GetName().c_str(), pld->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); + return SUCCESS; +} + +graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) { + auto type_pld = node->GetType(); + auto type_data = data_node->GetType(); + if (type_pld != PLACEHOLDER || type_data != DATA) { + GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s", + node->GetName().c_str(), type_pld.c_str(), type_data.c_str()); + return FAILED; + } + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + std::vector output_map(node->GetAllOutDataAnchorsSize()); + for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { + output_map[i] = static_cast(i); + } + + auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error node %u", + node->GetName().c_str(), data_node->GetName().c_str(), ret); + return FAILED; + } + + NodeUtils::UnlinkAll(*node); + + ret = GraphUtils::RemoveNodeWithoutRelink(graph, node); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); + return FAILED; + } + + GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", + node->GetName().c_str(), node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); + return ret; +} + +graphStatus TuningUtils::HandlePld(NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + if (HandleContinuousInputNodeNextData(node) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "TUU:Failed to handle continuous node next to data node:%s", + node->GetName().c_str()); + return GRAPH_FAILED; + } + + NodePtr data_node = nullptr; + // 1. create data node + if (CreateDataNode(node, data_node) != SUCCESS) { + GELOGE(FAILED, + "TUU:Failed to handle node %s from graph %s", + node->GetName().c_str(), + graph->GetName().c_str()); + return FAILED; + } + // 2. add necessary info to data_node for recovery whole graph + if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) { + GELOGE(FAILED, + "TUU:Failed to handle node %s from graph %s", + node->GetName().c_str(), + graph->GetName().c_str()); + return FAILED; + } + // 3. replace pld node by data node created before + if (ChangePld2Data(node, data_node) != SUCCESS) { + GELOGE(FAILED, + "TUU:Failed to handle node %s from graph %s", + node->GetName().c_str(), + graph->GetName().c_str()); + return FAILED; + } + GELOGD("TUU:pld[%s] handle success", node->GetName().c_str()); + return SUCCESS; +} + +graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) { + GE_CHECK_NOTNULL(node); + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + auto search = create_output_.find(graph); + if (search == create_output_.end()) { + GELOGE(FAILED, + "TUU:node %s's owner sub graph %s not exist in create_output map", + node->GetName().c_str(), + graph->GetName().c_str()); + return FAILED; + } + if (search->second != nullptr) { + out_node = search->second; + GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str()); + return SUCCESS; + } + auto out_op_desc = ComGraphMakeShared(node->GetName(), NETOUTPUT); + GE_CHECK_NOTNULL(out_op_desc); + out_node = graph->AddNode(out_op_desc); + GE_CHECK_NOTNULL(out_node); + if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); + return FAILED; + } + create_output_[graph] = out_node; + return SUCCESS; +} + +graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) { + GE_CHECK_NOTNULL(end); + GE_CHECK_NOTNULL(out_node); + auto op_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::vector alias_names = {}; + (void) AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names); + alias_names.push_back(end->GetName()); + (void) AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names); + return SUCCESS; +} + +graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { + GE_CHECK_NOTNULL(end_node); + GE_CHECK_NOTNULL(out_node); + // get end in node is control node or normal node + AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr) + ? Anchor::DynamicAnchorCast(end_node->GetInControlAnchor()) + : Anchor::DynamicAnchorCast(end_node->GetInDataAnchor(0)); + GE_CHECK_NOTNULL(end_in_anchor); + auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1 + GE_CHECK_NOTNULL(src_anchor); + if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", + GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), + GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), + end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str()); + return FAILED; + } + // add edge between `end in node` and `out_node` + if (src_anchor->IsTypeOf()) { + std::shared_ptr + anchor = ComGraphMakeShared(out_node, out_node->GetAllInDataAnchors().size()); + GE_CHECK_NOTNULL(anchor); + out_node->in_data_anchors_.push_back(anchor); + if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", + GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), + GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), + end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str()); + return FAILED; + } + auto end_op_desc = end_node->GetOpDesc(); + GE_CHECK_NOTNULL(end_op_desc); + auto out_node_op_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(out_node_op_desc); + // end node always has one input + if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str()); + return FAILED; + } + } else if (src_anchor->IsTypeOf()) { + auto anchor = out_node->GetInControlAnchor(); + if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", + GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), + GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), + end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str()); + return FAILED; + } + } else { + GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed", + end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + +graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { + GE_CHECK_NOTNULL(end_node); + GE_CHECK_NOTNULL(out_node); + auto type_end = end_node->GetType(); + auto type_out = out_node->GetType(); + if (type_end != END || type_out != NETOUTPUT) { + GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s", + end_node->GetName().c_str(), type_end.c_str(), type_out.c_str()); + return FAILED; + } + // link all `end nodes's in node` to this out_node + if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) { + GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", end_node->GetName().c_str()); + return FAILED; + } + // remove `end node` + NodeUtils::UnlinkAll(*end_node); + auto graph = end_node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + if (GraphUtils::RemoveNodeWithoutRelink(graph, end_node) != SUCCESS) { + GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str()); + return FAILED; + } + return SUCCESS; +} + +graphStatus TuningUtils::HandleEnd(NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + NodePtr out_node = nullptr; + + // 1. create net_output node , add only one NetOutput node to one subgraph + if (CreateNetOutput(node, out_node) != SUCCESS) { + GELOGE(FAILED, + "TUU:Failed to handle node %s from graph %s", + node->GetName().c_str(), + graph->GetName().c_str()); + return FAILED; + } + // 2. add necessary info to out_node for recovery whole graph + if (AddAttrToNetOutputForMergeGraph(node, out_node) != SUCCESS) { + GELOGE(FAILED, + "TUU:Failed to handle node %s from graph %s", + node->GetName().c_str(), + graph->GetName().c_str()); + return FAILED; + } + // 3. replace all end nodes by one output node created before + if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) { + GELOGE(FAILED, + "TUU:Failed to handle node %s from graph %s", + node->GetName().c_str(), + graph->GetName().c_str()); + return FAILED; + } + GELOGD("TUU:end[%s] handle success", node->GetName().c_str()); + return SUCCESS; +} + +// part 2 +graphStatus TuningUtils::ConvertFileToGraph(const map &options, ge::Graph &graph) { + std::function callback = [&]() { + data_2_netoutput_.clear(); + data_node_2_netoutput_.clear(); + data_node_2_netoutput_node_.clear(); + netoutput_nodes_.clear(); + merged_graph_nodes_.clear(); + }; + GE_MAKE_GUARD(release, callback); + // 1. get all subgraph object + std::vector graphs; + // options format like {index:"subgraph_path"} + for (const auto &pair : options) { + ComputeGraphPtr compute_graph = ComGraphMakeShared(std::to_string(pair.first)); + if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) { + GELOGE(FAILED, "TUU:load graph from file failed"); + } + graphs.push_back(compute_graph); + } + // 2. merge graph + ComputeGraphPtr merged_graph = ComGraphMakeShared("whole_graph_after_tune"); + GE_CHECK_NOTNULL(merged_graph); + if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) { + GELOGE(FAILED, "TUU:MergeGraph failed"); + return FAILED; + } + // 3. set parent graph + for (const auto &node : merged_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str()); + return FAILED; + } + } + graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph); + return SUCCESS; +} + +// +----------------------------------+ +// | const const | +// | \ / | +// | netoutput(end,end) | +// +----------------------------------+ +// + +// +----------------------------------+ +// | data(pld) data(pld) | +// | \ / | +// | relu relu | +// | \ / | +// | \ / | +// | add | +// | | | +// | netoutput(end) | +// +----------------------------------+ +// + +// +----------------------------------+ +// | data(pld) | +// | / | +// | netoutput | +// +----------------------------------+ +// | +// | +// V +// +----------------------------------+ +// | const const | +// | \ / | +// | relu relu | +// | \ / | +// | \ / | +// | add | +// | | | +// | netoutput | +// +----------------------------------+ +graphStatus TuningUtils::MergeAllSubGraph(std::vector &subgraphs, + ComputeGraphPtr &output_merged_compute_graph) { + GE_CHECK_NOTNULL(output_merged_compute_graph); + // 1. handle all subgraphs + for (auto &subgraph : subgraphs) { + Status ret_status = MergeSubGraph(subgraph); + if (ret_status != SUCCESS) { + GELOGE(ret_status, "TUU:subgraph %s merge failed", subgraph->GetName().c_str()); + return ret_status; + } + } + + for (const auto &node: merged_graph_nodes_) { + (void) output_merged_compute_graph->AddNode(node); + GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str()); + + vector recover_attr_name; + (void) ge::AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_NEED_RECOVER_ATTR, recover_attr_name); + if (!recover_attr_name.empty()) { + for (const auto &attr_name : recover_attr_name) { + if (!ge::AttrUtils::SetBool(node->GetOpDesc(), attr_name, true)) { + GELOGE(GRAPH_FAILED, "Recover attr %s for node:%s failed.", attr_name.c_str(), node->GetName().c_str()); + return GRAPH_FAILED; + } + } + } + } + + // 2. remove data and output node added by us + if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) { + GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str()); + return FAILED; + } + graphStatus ret = output_merged_compute_graph->TopologicalSorting(); + if (ret != SUCCESS) { + GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(), ret); + return ret; + } + GELOGD("TUU:Print-%s", PrintCheckLog().c_str()); + GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str()); + return SUCCESS; +} + +graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) { + for (auto &node : subgraph->GetDirectNode()) { + if (kPartitionOpTypes.count(node->GetType()) > 0) { + GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type"); + return FAILED; + } + // handle data converted from pld node + if (node->GetType() == DATA) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::string peer_out_name; + bool has_valid_str = + (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty()); + if (has_valid_str) { + std::lock_guard lock(mutex_); + data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name); + data_node_2_netoutput_.emplace(node, peer_out_name); + continue; + } + } + // handle netoutput converted from end node + if (node->GetType() == NETOUTPUT) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::vector out_alias_name; + bool has_valid_str = + (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty()); + if (has_valid_str) { + std::lock_guard lock(mutex_); + netoutput_nodes_.emplace_back(node); + } + } + { + std::lock_guard lock(mutex_); + merged_graph_nodes_.emplace_back(node); + } + GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str()); + } + GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str()); + return SUCCESS; +} + +graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + // 1. traverse + for (auto &pair: data_node_2_netoutput_) { + auto data_node = pair.first; + GE_CHECK_NOTNULL(data_node); + auto netoutput_name = pair.second; + auto netoutput_node = graph->FindNode(netoutput_name); + GE_CHECK_NOTNULL(netoutput_node); + data_node_2_netoutput_node_.emplace(data_node, netoutput_node); + // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor` + AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr) + ? Anchor::DynamicAnchorCast(data_node->GetOutControlAnchor()) + : Anchor::DynamicAnchorCast(data_node->GetOutDataAnchor(0)); + AnchorPtr net_output_in_anchor = nullptr; + AnchorPtr src_out_anchor = nullptr; + if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed", + netoutput_node->GetName().c_str(), data_node->GetName().c_str()); + return FAILED; + } + // 3. relink + // unlink netoutput_node with it's input in stage 4 + GE_CHECK_NOTNULL(data_out_anchor); + for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) { + if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", + GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(), + GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), + data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", + GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), + GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), + data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); + return FAILED; + } + } + } + // 4. remove out nodes added by us + for (auto &node: netoutput_nodes_) { + NodeUtils::UnlinkAll(*node); + if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { + GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); + return FAILED; + } + GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str()); + } + return SUCCESS; +} + +graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, + NodePtr &out_node, + AnchorPtr &dest_in_anchor, + AnchorPtr &src_out_anchor) { + // 1. get `data parent node name`, i.e. `netoutput input node name` + std::string netoutput_input_name; + auto op_desc = data_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (!AttrUtils::GetStr(op_desc, parent_node_name_attr, netoutput_input_name)) { + GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", op_desc->GetName().c_str()); + return FAILED; + } + // 2. find index + int parent_node_anchor_index; + if (!AttrUtils::GetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index)) { + GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", op_desc->GetName().c_str()); + return FAILED; + } + // 3.find in data or ctrl anchor by 1&2 step + for (auto &in_anchor: out_node->GetAllInAnchors()) { + GE_CHECK_NOTNULL(in_anchor); + for (auto &src_anchor :in_anchor->GetPeerAnchors()) { // get all peer anchors for ctrl + GE_CHECK_NOTNULL(src_anchor); + auto src_node = src_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + std::string src_node_name = src_node->GetName(); + if (src_node_name.find(netoutput_input_name) != src_node_name.npos && + src_anchor->GetIdx() == parent_node_anchor_index) { + dest_in_anchor = in_anchor; + src_out_anchor = src_anchor; + GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s", + out_node->GetName().c_str(), dest_in_anchor->GetIdx(), netoutput_input_name.c_str(), + parent_node_anchor_index, data_node->GetName().c_str()); + break; + } + } + } + GE_CHECK_NOTNULL(dest_in_anchor); + GE_CHECK_NOTNULL(src_out_anchor); + return SUCCESS; +} + +graphStatus TuningUtils::HandleContinuousInputNodeNextData(NodePtr &node) { + GE_CHECK_NOTNULL(node); + for (const auto &out_anchor : node->GetAllOutAnchors()) { + for (const auto &peer_in_anchor : out_anchor->GetPeerAnchors()) { + auto next_node = peer_in_anchor->GetOwnerNode(); + vector remove_attr_names; + bool is_no_padding_continuous_input = false; + bool is_continuous_input = false; + bool is_no_task = false; + (void) ge::AttrUtils::GetBool(next_node->GetOpDesc(), ATTR_NAME_CONTINUOUS_INPUT, is_continuous_input); + (void) ge::AttrUtils::GetBool(next_node->GetOpDesc(), + ATTR_NAME_NOPADDING_CONTINUOUS_INPUT, + is_no_padding_continuous_input); + (void) ge::AttrUtils::GetBool(next_node->GetOpDesc(), ATTR_NAME_NOTASK, is_no_task); + if (is_continuous_input) { + if (!ge::AttrUtils::SetBool(next_node->GetOpDesc(), ATTR_NAME_CONTINUOUS_INPUT, false)) { + GELOGE(GRAPH_FAILED, "Remove attr ATTR_NAME_CONTINUOUS_INPUT for node:%s failed.", + next_node->GetName().c_str()); + return GRAPH_FAILED; + } + remove_attr_names.emplace_back(ATTR_NAME_CONTINUOUS_INPUT); + } + if (is_no_padding_continuous_input) { + if (!ge::AttrUtils::SetBool(next_node->GetOpDesc(), ATTR_NAME_NOPADDING_CONTINUOUS_INPUT, false)) { + GELOGE(GRAPH_FAILED, "Remove attr ATTR_NAME_NOPADDING_CONTINUOUS_INPUT for node:%s failed.", + next_node->GetName().c_str()); + return GRAPH_FAILED; + } + remove_attr_names.emplace_back(ATTR_NAME_NOPADDING_CONTINUOUS_INPUT); + } + if ((is_continuous_input || is_no_padding_continuous_input) && is_no_task) { + if (!ge::AttrUtils::SetBool(next_node->GetOpDesc(), ATTR_NAME_NOTASK, false)) { + GELOGE(GRAPH_FAILED, "Remove attr ATTR_NAME_NOTASK for node:%s failed.", + next_node->GetName().c_str()); + return GRAPH_FAILED; + } + remove_attr_names.emplace_back(ATTR_NAME_NOTASK); + } + if (!remove_attr_names.empty()) { + if (!ge::AttrUtils::SetListStr(next_node->GetOpDesc(), + ATTR_NAME_NEED_RECOVER_ATTR, + remove_attr_names)) { + GELOGE(GRAPH_FAILED, "Set attr ATTR_NAME_NEED_RECOVER_ATTR for node:%s failed.", + next_node->GetName().c_str()); + return GRAPH_FAILED; + } + } + } + } + return GRAPH_SUCCESS; +} +} diff --git a/metadef/graph/utils/type_utils.cc b/metadef/graph/utils/type_utils.cc new file mode 100644 index 00000000..cd2f8545 --- /dev/null +++ b/metadef/graph/utils/type_utils.cc @@ -0,0 +1,509 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/utils/type_utils.h" + +#include + +#include "debug/ge_util.h" +#include "common/util/error_manager/error_manager.h" + +using domi::domiTensorFormat_t; + +namespace ge { +namespace{ +const std::map kFormatToStringMap = { + {FORMAT_NCHW, "NCHW"}, + {FORMAT_NHWC, "NHWC"}, + {FORMAT_ND, "ND"}, + {FORMAT_NC1HWC0, "NC1HWC0"}, + {FORMAT_FRACTAL_Z, "FRACTAL_Z"}, + {FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, + {FORMAT_NHWC1C0, "NHWC1C0"}, + {FORMAT_FSR_NCHW, "FSR_NCHW"}, + {FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, + {FORMAT_C1HWNC0, "C1HWNC0"}, + {FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, + {FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, + {FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, + {FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, + {FORMAT_CHWN, "CHWN"}, + {FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, + {FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, + {FORMAT_BN_WEIGHT, "BN_WEIGHT"}, + {FORMAT_FILTER_HWCK, "FILTER_HWCK"}, + {FORMAT_HWCN, "HWCN"}, + {FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, + {FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, + {FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, + {FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, + {FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, + {FORMAT_MD, "MD"}, + {FORMAT_NDHWC, "NDHWC"}, + {FORMAT_NCDHW, "NCDHW"}, + {FORMAT_DHWCN, "DHWCN"}, + {FORMAT_DHWNC, "DHWNC"}, + {FORMAT_NDC1HWC0, "NDC1HWC0"}, + {FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, + {FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, + {FORMAT_C1HWNCoC0, "C1HWNCoC0"}, + {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, + {FORMAT_CN, "CN"}, + {FORMAT_NC, "NC"}, + {FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, + {FORMAT_FRACTAL_Z_G, "FRACTAL_Z_G"}, + {FORMAT_RESERVED, "FORMAT_RESERVED"}, + {FORMAT_ALL, "ALL"}}; + +const std::map kDomiFormatToGeFormat = { + {domi::DOMI_TENSOR_NCHW, FORMAT_NCHW}, + {domi::DOMI_TENSOR_NHWC, FORMAT_NHWC}, + {domi::DOMI_TENSOR_ND, FORMAT_ND}, + {domi::DOMI_TENSOR_NC1HWC0, FORMAT_NC1HWC0}, + {domi::DOMI_TENSOR_FRACTAL_Z, FORMAT_FRACTAL_Z}, + {domi::DOMI_TENSOR_NC1C0HWPAD, FORMAT_NC1C0HWPAD}, + {domi::DOMI_TENSOR_NHWC1C0, FORMAT_NHWC1C0}, + {domi::DOMI_TENSOR_FSR_NCHW, FORMAT_FSR_NCHW}, + {domi::DOMI_TENSOR_FRACTAL_DECONV, FORMAT_FRACTAL_DECONV}, + {domi::DOMI_TENSOR_BN_WEIGHT, FORMAT_BN_WEIGHT}, + {domi::DOMI_TENSOR_CHWN, FORMAT_CHWN}, + {domi::DOMI_TENSOR_FILTER_HWCK, FORMAT_FILTER_HWCK}, + {domi::DOMI_TENSOR_NDHWC, FORMAT_NDHWC}, + {domi::DOMI_TENSOR_NCDHW, FORMAT_NCDHW}, + {domi::DOMI_TENSOR_DHWCN, FORMAT_DHWCN}, + {domi::DOMI_TENSOR_DHWNC, FORMAT_DHWNC}, + {domi::DOMI_TENSOR_RESERVED, FORMAT_RESERVED} +}; + +const std::unordered_set kInternalFormat = { + "NC1HWC0", + "FRACTAL_Z", + "NC1C0HWPAD", + "NHWC1C0", + "FRACTAL_DECONV", + "C1HWNC0", + "FRACTAL_DECONV_TRANSPOSE", + "FRACTAL_DECONV_SP_STRIDE_TRANS", + "NC1HWC0_C04", + "FRACTAL_Z_C04", + "FRACTAL_DECONV_SP_STRIDE8_TRANS", + "NC1KHKWHWC0", + "C1HWNCoC0", + "FRACTAL_ZZ", + "FRACTAL_NZ", + "NDC1HWC0", + "FORMAT_FRACTAL_Z_3D", + "FORMAT_FRACTAL_Z_3D_TRANSPOSE", + "FORMAT_FRACTAL_ZN_LSTM", + "FORMAT_FRACTAL_Z_G" +}; + +const std::map kDataFormatMap = { + {"NCHW", FORMAT_NCHW}, + {"NHWC", FORMAT_NHWC}, + {"NDHWC", FORMAT_NDHWC}, + {"NCDHW", FORMAT_NCDHW}, + {"ND", FORMAT_ND}}; + +const std::map kStringToFormatMap = { + {"NCHW", FORMAT_NCHW}, + {"NHWC", FORMAT_NHWC}, + {"ND", FORMAT_ND}, + {"NC1HWC0", FORMAT_NC1HWC0}, + {"FRACTAL_Z", FORMAT_FRACTAL_Z}, + {"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, + {"NHWC1C0", FORMAT_NHWC1C0}, + {"FSR_NCHW", FORMAT_FSR_NCHW}, + {"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, + {"C1HWNC0", FORMAT_C1HWNC0}, + {"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, + {"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, + {"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, + {"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, + {"CHWN", FORMAT_CHWN}, + {"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, + {"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, + {"BN_WEIGHT", FORMAT_BN_WEIGHT}, + {"FILTER_HWCK", FORMAT_FILTER_HWCK}, + {"HWCN", FORMAT_HWCN}, + {"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, + {"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, + {"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, + {"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, + {"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, + {"MD", FORMAT_MD}, + {"C1HWNCoC0", FORMAT_C1HWNCoC0}, + {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, + {"NDHWC", FORMAT_NDHWC}, + {"NCDHW", FORMAT_NCDHW}, + {"DHWCN", FORMAT_DHWCN}, + {"DHWNC", FORMAT_DHWNC}, + {"NDC1HWC0", FORMAT_NDC1HWC0}, + {"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, + {"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, + {"CN", FORMAT_CN}, + {"NC", FORMAT_NC}, + {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, + {"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, + {"FORMAT_RESERVED", FORMAT_RESERVED}, + {"ALL", FORMAT_ALL}, + {"NULL", FORMAT_NULL}}; + +const std::map kDataTypeToStringMap = { + {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. + {DT_FLOAT, "DT_FLOAT"}, // float type + {DT_FLOAT16, "DT_FLOAT16"}, // fp16 type + {DT_INT8, "DT_INT8"}, // int8 type + {DT_INT16, "DT_INT16"}, // int16 type + {DT_UINT16, "DT_UINT16"}, // uint16 type + {DT_UINT8, "DT_UINT8"}, // uint8 type + {DT_INT32, "DT_INT32"}, // uint32 type + {DT_INT64, "DT_INT64"}, // int64 type + {DT_UINT32, "DT_UINT32"}, // unsigned int32 + {DT_UINT64, "DT_UINT64"}, // unsigned int64 + {DT_BOOL, "DT_BOOL"}, // bool type + {DT_DOUBLE, "DT_DOUBLE"}, // double type + {DT_DUAL, "DT_DUAL"}, // dual output type + {DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type + {DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type + {DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type + {DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type + {DT_QINT8, "DT_QINT8"}, // qint8 type + {DT_QINT16, "DT_QINT16"}, // qint16 type + {DT_QINT32, "DT_QINT32"}, // qint32 type + {DT_QUINT8, "DT_QUINT8"}, // quint8 type + {DT_QUINT16, "DT_QUINT16"}, // quint16 type + {DT_RESOURCE, "DT_RESOURCE"}, // resource type + {DT_STRING_REF, "DT_STRING_REF"}, // string ref type + {DT_STRING, "DT_STRING"}, // string type +}; + +const std::map kStringTodataTypeMap = { + {"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. + {"DT_FLOAT", DT_FLOAT}, // float type + { + "DT_FLOAT16", DT_FLOAT16, + }, // fp16 type + {"DT_INT8", DT_INT8}, // int8 type + {"DT_INT16", DT_INT16}, // int16 type + {"DT_UINT16", DT_UINT16}, // uint16 type + {"DT_UINT8", DT_UINT8}, // uint8 type + {"DT_INT32", DT_INT32}, // uint32 type + {"DT_INT64", DT_INT64}, // int64 type + {"DT_UINT32", DT_UINT32}, // unsigned int32 + {"DT_UINT64", DT_UINT64}, // unsigned int64 + {"DT_BOOL", DT_BOOL}, // bool type + {"DT_DOUBLE", DT_DOUBLE}, // double type + {"DT_DUAL", DT_DUAL}, // dual output type + {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type + {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type + {"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type + {"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type + {"DT_QINT8", DT_QINT8}, // qint8 type + {"DT_QINT16", DT_QINT16}, // qint16 type + {"DT_QINT32", DT_QINT32}, // qint32 type + {"DT_QUINT8", DT_QUINT8}, // quint8 type + {"DT_QUINT16", DT_QUINT16}, // quint16 type + {"DT_RESOURCE", DT_RESOURCE}, // resource type + {"DT_STRING_REF", DT_STRING_REF}, // string ref type + {"DT_STRING", DT_STRING}, // string type + // add for json input + {"DT_FLOAT32", DT_FLOAT} +}; + +const std::map kDataTypeToLength = { + {DT_BOOL, sizeof(bool)}, + {DT_INT64, sizeof(int64_t)}, + {DT_UINT64, sizeof(int64_t)}, + {DT_FLOAT, sizeof(float)}, + {DT_INT32, sizeof(int32_t)}, + {DT_UINT32, sizeof(int32_t)}, + {DT_INT8, sizeof(char)}, + {DT_UINT8, sizeof(char)}, + {DT_INT16, sizeof(int16_t)}, + {DT_UINT16, sizeof(int16_t)}, + {DT_FLOAT16, sizeof(int16_t)}, + {DT_DOUBLE, sizeof(double)}, + {DT_DUAL, sizeof(float) + sizeof(int8_t)}, + {DT_DUAL_SUB_INT8, sizeof(int8_t)}, + {DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, + {DT_COMPLEX64, sizeof(int64_t)}, + {DT_COMPLEX128, sizeof(int64_t) * 2}, + {DT_QINT8, sizeof(int8_t)}, + {DT_QINT16, sizeof(int16_t)}, + {DT_QINT32, sizeof(int32_t)}, + {DT_QUINT8, sizeof(uint8_t)}, + {DT_QUINT16, sizeof(uint16_t)}, + {DT_STRING_REF, sizeof(uint64_t) * 2}, + {DT_STRING, sizeof(uint64_t)}, + {DT_RESOURCE, sizeof(uint64_t)}, +}; + +const std::map kFmkTypeToString = { + {domi::CAFFE, "caffe"}, + {domi::MINDSPORE, "mindspore"}, + {domi::TENSORFLOW, "tensorflow"}, + {domi::ANDROID_NN, "android_nn"}, + {domi::ONNX, "onnx"}, + {domi::FRAMEWORK_RESERVED, "framework_reserved"}, +}; + +const std::map kImplyTypeToString = { + {domi::ImplyType::BUILDIN, "buildin"}, + {domi::ImplyType::TVM, "tvm"}, + {domi::ImplyType::CUSTOM, "custom"}, + {domi::ImplyType::AI_CPU, "ai_cpu"}, + {domi::ImplyType::CCE, "cce"}, + {domi::ImplyType::GELOCAL, "gelocal"}, + {domi::ImplyType::HCCL, "hccl"}, + {domi::ImplyType::INVALID, "invalid"} +}; +} + + +std::string TypeUtils::ImplyTypeToSerialString(domi::ImplyType imply_type) { + auto it = kImplyTypeToString.find(imply_type); + if (it != kImplyTypeToString.end()) { + return it->second; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"ImplyTypeToSerialString", + "imply_type[" + std::to_string(static_cast(imply_type)) + "] is not support"}); + GELOGE(GRAPH_FAILED, "ImplyTypeToSerialString: imply_type not support %u", imply_type); + return "UNDEFINED"; + } +} + +bool TypeUtils::IsDataTypeValid(DataType dt) { + uint32_t num = static_cast(dt); + GE_CHK_BOOL_EXEC((num <= DT_UNDEFINED), return false, "The DataType is invalid"); + return true; +} + +std::string TypeUtils::DataTypeToSerialString(DataType data_type) { + auto it = kDataTypeToStringMap.find(data_type); + if (it != kDataTypeToStringMap.end()) { + return it->second; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"DataTypeToSerialString", "data_type[" + std::to_string(data_type) + "] is not support"}); + GELOGE(GRAPH_FAILED, "DataTypeToSerialString: datatype not support %u", data_type); + return "UNDEFINED"; + } +} + +DataType TypeUtils::SerialStringToDataType(const std::string &str) { + auto it = kStringTodataTypeMap.find(str); + if (it != kStringTodataTypeMap.end()) { + return it->second; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"SerialStringToDataType", "data_type[" + str + "] is not support"}); + GELOGE(GRAPH_FAILED, "SerialStringToDataType: datatype not support %s", str.c_str()); + return DT_UNDEFINED; + } +} + +bool TypeUtils::IsFormatValid(Format format) { + uint32_t num = static_cast(format); + GE_CHK_BOOL_EXEC((num <= FORMAT_RESERVED), return false, "The Format is invalid"); + return true; +} + +bool TypeUtils::IsDataTypeValid(std::string dt) { + transform(dt.begin(), dt.end(), dt.begin(), ::toupper); + std::string key = "DT_" + dt; + auto it = kStringTodataTypeMap.find(key); + if (it == kStringTodataTypeMap.end()) { + return false; + } + return true; +} + +bool TypeUtils::IsFormatValid(std::string format) { + transform(format.begin(), format.end(), format.begin(), ::toupper); + auto it = kStringToFormatMap.find(format); + if (it == kStringToFormatMap.end()) { + return false; + } + return true; +} + +bool TypeUtils::IsInternalFormat(Format format) { + std::string serial_format = FormatToSerialString(format); + auto iter = kInternalFormat.find(serial_format); + bool result = (iter == kInternalFormat.end()) ? false : true; + return result; +} + +std::string TypeUtils::FormatToSerialString(Format format) { + auto it = kFormatToStringMap.find(format); + if (it != kFormatToStringMap.end()) { + return it->second; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"FormatToSerialString", "Format[" + std::to_string(format) + "] is not support"}); + GELOGE(GRAPH_FAILED, "Format not support %u", format); + return "RESERVED"; + } +} +Format TypeUtils::SerialStringToFormat(const std::string &str) { + auto it = kStringToFormatMap.find(str); + if (it != kStringToFormatMap.end()) { + return it->second; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"SerialStringToFormat", "Format[" + str + "] is not support"}); + GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str()); + return FORMAT_RESERVED; + } +} + +Format TypeUtils::DataFormatToFormat(const std::string &str) { + auto it = kDataFormatMap.find(str); + if (it != kDataFormatMap.end()) { + return it->second; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"FormatToSerialString", "Format[" + str + "] is not support"}); + GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str()); + return FORMAT_RESERVED; + } +} + +Format TypeUtils::DomiFormatToFormat(domi::domiTensorFormat_t domi_format) { + auto it = kDomiFormatToGeFormat.find(domi_format); + if (it != kDomiFormatToGeFormat.end()) { + return it->second; + } + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"FormatToSerialString", "do not find domi Format[" + std::to_string(domi_format) + "] from map"}); + GELOGE(GRAPH_FAILED, "do not find domi Format %d from map", domi_format); + return FORMAT_RESERVED; +} + +std::string TypeUtils::FmkTypeToSerialString(domi::FrameworkType fmk_type) { + auto it = kFmkTypeToString.find(fmk_type); + if (it != kFmkTypeToString.end()) { + return it->second; + } else { + GELOGW("Framework type not support %d.", fmk_type); + return ""; + } +} + +static inline void CopyDataFromBuffer(vector &data, const Buffer &buffer) { + data.clear(); + if (buffer.GetData() != nullptr && buffer.GetSize() != 0) { + data.assign(buffer.GetData(), buffer.GetData() + buffer.GetSize()); + } +} + +graphStatus Usr2DefQuantizeFactor(const UsrQuantizeFactor &usr, QuantizeFactor &def) { + def.scale_mode = uint32_t(usr.scale_mode); + def.set_scale_value(usr.scale_value.data(), usr.scale_value.size()); + def.scale_offset = usr.scale_offset; + def.set_offset_data_value(usr.offset_data_value.data(), usr.offset_data_value.size()); + def.offset_data_offset = usr.offset_data_offset; + def.set_offset_weight_value(usr.offset_weight_value.data(), usr.offset_weight_value.size()); + def.offset_weight_offset = usr.offset_weight_offset; + def.set_offset_pad_value(usr.offset_pad_value.data(), usr.offset_pad_value.size()); + def.offset_pad_offset = usr.offset_pad_offset; + return GRAPH_SUCCESS; +} +graphStatus Def2UsrQuantizeFactor(const QuantizeFactor &def, UsrQuantizeFactor &usr) { + usr.scale_mode = UsrQuantizeScaleMode(def.scale_mode); + CopyDataFromBuffer(usr.scale_value, def.scale_value); + usr.scale_offset = def.scale_offset; + CopyDataFromBuffer(usr.offset_data_value, def.offset_data_value); + usr.offset_data_offset = def.offset_data_offset; + CopyDataFromBuffer(usr.offset_weight_value, def.offset_weight_value); + usr.offset_weight_offset = def.offset_weight_offset; + CopyDataFromBuffer(usr.offset_pad_value, def.offset_pad_value); + usr.offset_pad_offset = def.offset_pad_offset; + return GRAPH_SUCCESS; +} +graphStatus Usr2DefUsrQuantizeCalcFactor(const UsrQuantizeCalcFactor &usr, QuantizeCalcFactor &def) { + def.set_offsetw(usr.offsetw.data(), usr.offsetw.size()); + def.offsetw_offset = usr.offsetw_offset; + def.set_offsetd(usr.offsetd.data(), usr.offsetd.size()); + def.offsetd_offset = usr.offsetd_offset; + def.set_scalereq(usr.scalereq.data(), usr.scalereq.size()); + def.scaledreq_offset = usr.scaledreq_offset; + def.set_offsetdnext(usr.offsetdnext.data(), usr.offsetdnext.size()); + def.offsetdnext_offset = usr.offsetdnext_offset; + return GRAPH_SUCCESS; +} +graphStatus Def2UsrQuantizeCalcFactor(const QuantizeCalcFactor &def, UsrQuantizeCalcFactor &usr) { + CopyDataFromBuffer(usr.offsetw, def.offsetw); + usr.offsetw_offset = def.offsetw_offset; + CopyDataFromBuffer(usr.offsetd, def.offsetd); + usr.offsetd_offset = def.offsetd_offset; + CopyDataFromBuffer(usr.scalereq, def.scalereq); + usr.scaledreq_offset = def.scaledreq_offset; + CopyDataFromBuffer(usr.offsetdnext, def.offsetdnext); + usr.offsetdnext_offset = def.offsetdnext_offset; + return GRAPH_SUCCESS; +} +graphStatus TypeUtils::Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def) { + def.quantize_algo = uint32_t(usr.quantize_algo); + def.scale_type = uint32_t(usr.scale_type); + GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.quantize_param, def.quantize_param), + "Usr2DefQuantizeFactor quantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.dequantize_param, def.dequantize_param), + "Usr2DefQuantizeFactor dequantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.requantize_param, def.requantize_param), + "Usr2DefQuantizeFactor requantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefUsrQuantizeCalcFactor(usr.quantizecalc_param, def.quantizecalc_param), + "Usr2DefQuantizeFactor quantizecalc_param failed"); + return GRAPH_SUCCESS; +} +graphStatus TypeUtils::Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr) { + usr.quantize_algo = UsrQuantizeAlgorithm(def.quantize_algo); + usr.scale_type = UsrQuantizeScaleType(def.scale_type); + GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.quantize_param, usr.quantize_param), + "Def2UsrQuantizeFactor quantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.dequantize_param, usr.dequantize_param), + "Def2UsrQuantizeFactor dequantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.requantize_param, usr.requantize_param), + "Def2UsrQuantizeFactor requantize_param failed"); + GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeCalcFactor(def.quantizecalc_param, usr.quantizecalc_param), + "Def2UsrQuantizeCalcFactor quantizecalc_param failed"); + return GRAPH_SUCCESS; +} +bool TypeUtils::GetDataTypeLength(ge::DataType data_type, uint32_t &length) { + auto it = kDataTypeToLength.find(data_type); + if (it != kDataTypeToLength.end()) { + length = it->second; + return true; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"GetDataTypeLength", "data_type[" + std::to_string(data_type) + "] is not support"}); + GELOGE(GRAPH_FAILED, "data_type not support %d", data_type); + return false; + } +} +bool TypeUtils::CheckUint64MulOverflow(uint64_t a, uint32_t b) { + // Not overflow + if (a == 0) { + return false; + } + if (b <= (ULLONG_MAX / a)) { + return false; + } + return true; +} +} // namespace ge diff --git a/metadef/inc/common/blocking_queue.h b/metadef/inc/common/blocking_queue.h new file mode 100644 index 00000000..8d6c4ef2 --- /dev/null +++ b/metadef/inc/common/blocking_queue.h @@ -0,0 +1,141 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef INC_COMMON_BLOCKING_QUEUE_H_ +#define INC_COMMON_BLOCKING_QUEUE_H_ + +#include +#include +#include +#include + +static const int kDefaultMaxQueueSize = 2048; + +template +class BlockingQueue { + public: + explicit BlockingQueue(uint32_t max_size = kDefaultMaxQueueSize) : max_size_(max_size), is_stoped_(false) {} + + ~BlockingQueue() {} + + bool Pop(T &item) { + std::unique_lock lock(mutex_); + + while (queue_.empty() && !is_stoped_) { + empty_cond_.wait(lock); + } + + if (is_stoped_) { + return false; + } + + item = std::move(queue_.front()); + queue_.pop_front(); + + full_cond_.notify_one(); + + return true; + } + + bool Push(const T &item, bool is_wait = true) { + std::unique_lock lock(mutex_); + + while (queue_.size() >= max_size_ && !is_stoped_) { + if (!is_wait) { + return false; + } + full_cond_.wait(lock); + } + + if (is_stoped_) { + return false; + } + + queue_.push_back(item); + + empty_cond_.notify_one(); + + return true; + } + + bool Push(T &&item, bool is_wait = true) { + std::unique_lock lock(mutex_); + + while (queue_.size() >= max_size_ && !is_stoped_) { + if (!is_wait) { + return false; + } + full_cond_.wait(lock); + } + + if (is_stoped_) { + return false; + } + + queue_.emplace_back(std::move(item)); + + empty_cond_.notify_one(); + + return true; + } + + void Stop() { + { + std::unique_lock lock(mutex_); + is_stoped_ = true; + } + + full_cond_.notify_all(); + empty_cond_.notify_all(); + } + + void Restart() { + std::unique_lock lock(mutex_); + is_stoped_ = false; + } + + // if the queue is stoped ,need call this function to release the unprocessed items + std::list GetRemainItems() { + std::unique_lock lock(mutex_); + + if (!is_stoped_) { + return std::list(); + } + + return queue_; + } + + bool IsFull() { + std::unique_lock lock(mutex_); + return queue_.size() >= max_size_; + } + + void Clear() { + std::unique_lock lock(mutex_); + queue_.clear(); + } + + private: + std::list queue_; + std::mutex mutex_; + std::condition_variable empty_cond_; + std::condition_variable full_cond_; + uint32_t max_size_; + + bool is_stoped_; +}; + +#endif // INC_COMMON_BLOCKING_QUEUE_H_ diff --git a/metadef/inc/common/dynamic_aipp.h b/metadef/inc/common/dynamic_aipp.h new file mode 100644 index 00000000..9ada1ef5 --- /dev/null +++ b/metadef/inc/common/dynamic_aipp.h @@ -0,0 +1,104 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef INC_COMMON_DYNAMIC_AIPP_H_ +#define INC_COMMON_DYNAMIC_AIPP_H_ + +#include + +/** +* @ingroup dnn +* @brief struct define of dynamic aipp batch parameter. +*/ +typedef struct tagAippDynamicBatchPara { + int8_t cropSwitch; // crop switch + int8_t scfSwitch; // resize switch + int8_t paddingSwitch; // 0: unable padding + // 1: padding config value,sfr_filling_hblank_ch0 ~ sfr_filling_hblank_ch2 + // 2: padding source picture data, single row/collumn copy + // 3: padding source picture data, block copy + // 4: padding source picture data, mirror copy + int8_t rotateSwitch; // rotate switch,0: non-ratate, + // 1: ratate 90° clockwise,2: ratate 180° clockwise,3: ratate 270° clockwise + int8_t reserve[4]; + int32_t cropStartPosW; // the start horizontal position of cropping + int32_t cropStartPosH; // the start vertical position of cropping + int32_t cropSizeW; // crop width + int32_t cropSizeH; // crop height + + int32_t scfInputSizeW; // input width of scf + int32_t scfInputSizeH; // input height of scf + int32_t scfOutputSizeW; // output width of scf + int32_t scfOutputSizeH; // output height of scf + + int32_t paddingSizeTop; // top padding size + int32_t paddingSizeBottom; // bottom padding size + int32_t paddingSizeLeft; // left padding size + int32_t paddingSizeRight; // right padding size + + int16_t dtcPixelMeanChn0; // mean value of channel 0 + int16_t dtcPixelMeanChn1; // mean value of channel 1 + int16_t dtcPixelMeanChn2; // mean value of channel 2 + int16_t dtcPixelMeanChn3; // mean value of channel 3 + + uint16_t dtcPixelMinChn0; // min value of channel 0 + uint16_t dtcPixelMinChn1; // min value of channel 1 + uint16_t dtcPixelMinChn2; // min value of channel 2 + uint16_t dtcPixelMinChn3; // min value of channel 3 + uint16_t dtcPixelVarReciChn0; // sfr_dtc_pixel_variance_reci_ch0 + uint16_t dtcPixelVarReciChn1; // sfr_dtc_pixel_variance_reci_ch1 + uint16_t dtcPixelVarReciChn2; // sfr_dtc_pixel_variance_reci_ch2 + uint16_t dtcPixelVarReciChn3; // sfr_dtc_pixel_variance_reci_ch3 + + int8_t reserve1[16]; // 32B assign, for ub copy +} kAippDynamicBatchPara; + +/** +* @ingroup dnn +* @brief struct define of dynamic aipp parameter. lite:64+96*batchNum byte ; tiny:64+64*batchNum byte +*/ +typedef struct tagAippDynamicPara { + uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8 + int8_t cscSwitch; // csc switch + int8_t rbuvSwapSwitch; // rb/ub swap switch + int8_t axSwapSwitch; // RGBA->ARGB, YUVA->AYUV swap switch + int8_t batchNum; // batch parameter number + int8_t reserve1[3]; + int32_t srcImageSizeW; // source image width + int32_t srcImageSizeH; // source image height + int16_t cscMatrixR0C0; // csc_matrix_r0_c0 + int16_t cscMatrixR0C1; // csc_matrix_r0_c1 + int16_t cscMatrixR0C2; // csc_matrix_r0_c2 + int16_t cscMatrixR1C0; // csc_matrix_r1_c0 + int16_t cscMatrixR1C1; // csc_matrix_r1_c1 + int16_t cscMatrixR1C2; // csc_matrix_r1_c2 + int16_t cscMatrixR2C0; // csc_matrix_r2_c0 + int16_t cscMatrixR2C1; // csc_matrix_r2_c1 + int16_t cscMatrixR2C2; // csc_matrix_r2_c2 + int16_t reserve2[3]; + uint8_t cscOutputBiasR0; // output Bias for RGB to YUV, element of row 0, unsigned number + uint8_t cscOutputBiasR1; // output Bias for RGB to YUV, element of row 1, unsigned number + uint8_t cscOutputBiasR2; // output Bias for RGB to YUV, element of row 2, unsigned number + uint8_t cscInputBiasR0; // input Bias for YUV to RGB, element of row 0, unsigned number + uint8_t cscInputBiasR1; // input Bias for YUV to RGB, element of row 1, unsigned number + uint8_t cscInputBiasR2; // input Bias for YUV to RGB, element of row 2, unsigned number + uint8_t reserve3[2]; + int8_t reserve4[16]; // 32B assign, for ub copy + + kAippDynamicBatchPara aippBatchPara; // allow transfer several batch para. +} kAippDynamicPara; + +#endif // INC_COMMON_DYNAMIC_AIPP_H_ diff --git a/metadef/inc/common/npu_error_define.h b/metadef/inc/common/npu_error_define.h new file mode 100644 index 00000000..aba70f99 --- /dev/null +++ b/metadef/inc/common/npu_error_define.h @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef INC_COMMON_NPU_ERROR_DEFINE_H_ +#define INC_COMMON_NPU_ERROR_DEFINE_H_ + +typedef enum tagHiAiNpuLocal { + HIAI_HOST = 1, + HIAI_DEVICE = 2, +} HiAiNpuLocal; + +typedef enum tagHiAiNpuCodeType { + ERROR_CODE = 1, + EXCEPTION_CODE = 2, +} HiAiNpuCodeType; + +typedef enum tagHiAiNpuErrLevel { + NONE_LEVEL = 0, + SUGGESTION_LEVEL = 1, + NORMAL_LEVEL = 2, + SERIOUS_LEVEL = 3, + CRITICAL_ERROR = 4, +} HiAiNpuErrLevel; + +typedef enum tagHiAiNpuModuleId { + HIAI_DRIVER = 1, + HIAI_CTRLCPU = 2, + HIAI_TS = 3, + HIAI_RUNTIME = 4, + HIAI_AICPU = 5, + HIAI_CCE = 6, + HIAI_TVM = 7, + HIAI_FRAMEWORK = 8, + HiAI_ENGINE = 9, + HIAI_DVPP = 10, + HIAI_AIPP = 11, + HIAI_LOWPOWER = 12, + HIAI_MDC = 13, + HIAI_COMPILE = 14, + HIAI_TOOLCHIAN = 15, + HIAI_ALG = 16, + HIAI_PROFILING = 17, + HIAI_HCCL = 18, + HIAI_SIMULATION = 19, + HIAI_BIOS = 20, + HIAI_SEC = 21, + HIAI_TINY = 22, + HIAI_DP = 23, +} HiAiNpuModuleId; + +/* bit 31-bit30 to be hiai local */ +#define HIAI_NPULOCAL_MASK 0xC0000000 +#define SHIFT_LOCAL_MASK 30 +#define HIAI_NPULOCAL_VAL_MASK 0x3 +/* bit 29 -bit28 to be hiai aicpu code type */ +#define HIAI_CODE_TYPE_MASK 0x30000000 +#define SHIFT_CODE_MASK 28 +#define HIAI_CODE_TYPE_VAL_MASK 0x3 +/* bit 27 -bit25 to be hiai error level */ +#define HIAI_ERROR_LEVEL_MASK 0x0E000000 +#define SHIFT_ERROR_LVL_MASK 25 +#define HIAI_ERROR_LEVEL_VAL_MASK 0x7 +/* bit 24 -bit17 to be hiai mod */ +#define HIAI_MODE_ID_MASK 0x01FE0000 +#define SHIFT_MODE_MASK 17 +#define HIAI_MODE_ID_VAL_MASK 0xFF + +#define HIAI_NPU_LOC_BIT(a) \ + (HIAI_NPULOCAL_MASK & ((unsigned int)((HiAiNpuLocal)(a)) & HIAI_NPULOCAL_VAL_MASK) << SHIFT_LOCAL_MASK) +#define HIAI_NPU_CODE_TYPE_BIT(a) \ + (HIAI_CODE_TYPE_MASK & ((unsigned int)((HiAiNpuCodeType)(a)) & HIAI_CODE_TYPE_VAL_MASK) << SHIFT_CODE_MASK) +#define HIAI_NPU_ERR_LEV_BIT(a) \ + (HIAI_ERROR_LEVEL_MASK & ((unsigned int)((HiAiNpuErrLevel)(a)) & HIAI_ERROR_LEVEL_VAL_MASK) << SHIFT_ERROR_LVL_MASK) +#define HIAI_NPU_MOD_ID_BIT(a) \ + (HIAI_MODE_ID_MASK & ((unsigned int)((HiAiNpuModuleId)(a)) & HIAI_MODE_ID_VAL_MASK) << SHIFT_MODE_MASK) + +#define HIAI_NPU_ERR_CODE_HEAD(npuLocal, codeType, errLevel, moduleId) \ + (HIAI_NPU_LOC_BIT(npuLocal) + HIAI_NPU_CODE_TYPE_BIT(codeType) + HIAI_NPU_ERR_LEV_BIT(errLevel) + \ + HIAI_NPU_MOD_ID_BIT(moduleId)) + +#endif // INC_COMMON_NPU_ERROR_DEFINE_H_ diff --git a/metadef/inc/common/opskernel/ge_task_info.h b/metadef/inc/common/opskernel/ge_task_info.h new file mode 100644 index 00000000..145f3f27 --- /dev/null +++ b/metadef/inc/common/opskernel/ge_task_info.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ +#define INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ + +#include +#include +#include +#include + +using std::string; +namespace ge { +// when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD +struct GETaskKernelHcclInfo { + string input_name; + string hccl_type; + void *inputDataAddr; + void *outputDataAddr; + void *workSpaceAddr; + int32_t count; + int32_t dataType; + int32_t opType; + int64_t rootId; + uint64_t workSpaceMemSize; + std::vector dims; + std::vector hcclStreamList; +}; + +struct GETaskInfo { + uint32_t id; + uint16_t type; + uint32_t streamID; + void *stream; // rtKernelLaunch input argument + void *event; + void *privateDef; + uint32_t privateDefLen; + void *opsKernelStorePtr; + + std::vector kernelHcclInfo; +}; + +struct HcomOpertion { + std::string hcclType; + void *inputPtr; + void *outputPtr; + uint64_t count; + int32_t dataType; + int32_t opType; + int32_t root; +}; + +struct HcomRemoteAccessAddrInfo +{ + uint32_t remotetRankID; + uint64_t remoteAddr; // host embedding table address + uint64_t localAddr; // device HBM address + uint64_t length; // memory Length in Bytes +}; + + +} // namespace ge +#endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ diff --git a/metadef/inc/common/opskernel/ops_kernel_builder.h b/metadef/inc/common/opskernel/ops_kernel_builder.h new file mode 100644 index 00000000..169f27f8 --- /dev/null +++ b/metadef/inc/common/opskernel/ops_kernel_builder.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef INC_COMMON_OPSKERNELUTILS_OPS_KERNEL_INFO_UTILS_H_ +#define INC_COMMON_OPSKERNELUTILS_OPS_KERNEL_INFO_UTILS_H_ + +#include "external/ge/ge_api_error_codes.h" +#include "cce/aicpu_engine_struct.h" +#include "common/opskernel/ops_kernel_info_types.h" +#include "graph/node.h" +#include "proto/task.pb.h" + +namespace ge { +class OpsKernelBuilder { + public: + OpsKernelBuilder() = default; + virtual ~OpsKernelBuilder() = default; + + // initialize OpsKernelBuilder + virtual Status Initialize(const std::map &options) = 0; + + // finalize OpsKernelBuilder + virtual Status Finalize() = 0; + + // memory allocation requirement + virtual Status CalcOpRunningParam(Node &node) = 0; + + // generate task for op + virtual Status GenerateTask(const Node &node, RunContext &context, + std::vector &tasks) = 0; + + // only call aicpu interface to generate task struct + virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { + return FAILED; + } + + // only call aicpu interface to generate task struct + virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { + return FAILED; + } +}; +} // namespace ge +#endif // INC_COMMON_OPSKERNELUTILS_OPS_KERNEL_INFO_UTILS_H_ diff --git a/metadef/inc/common/opskernel/ops_kernel_info_store.h b/metadef/inc/common/opskernel/ops_kernel_info_store.h new file mode 100644 index 00000000..330c67b6 --- /dev/null +++ b/metadef/inc/common/opskernel/ops_kernel_info_store.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ +#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ + +#include +#include +#include +#include +#include "./ge_task_info.h" +#include "./ops_kernel_info_types.h" +#include "cce/aicpu_engine_struct.h" +#include "cce/fwk_adpt_struct.h" +#include "common/ge_inner_error_codes.h" +#include "graph/node.h" +#include "proto/task.pb.h" +using std::map; +using std::string; +using std::to_string; +using std::vector; + +namespace ge { +class OpDesc; + +class OpsKernelInfoStore { + public: + OpsKernelInfoStore() {} + + virtual ~OpsKernelInfoStore() {} + + // initialize opsKernelInfoStore + virtual Status Initialize(const map &options) = 0; /*lint -e148*/ + + // close opsKernelInfoStore + virtual Status Finalize() = 0; /*lint -e148*/ + + virtual Status CreateSession(const std::map &session_options) { return SUCCESS; } + + virtual Status DestroySession(const std::map &session_options) { return SUCCESS; } + + // get all opsKernelInfo + virtual void GetAllOpsKernelInfo(map &infos) const = 0; + + // whether the opsKernelInfoStore is supported based on the operator attribute + virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; + + virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, + bool realQuery = false) const { + return CheckSupported(opDescPtr, un_supported_reason); + } + // opsFlag opsFlag[0] indicates constant folding is supported or not + virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag) {}; + + // only call fe engine interface to compile single op + virtual Status CompileOp(vector &node_vec) { return SUCCESS; } + virtual Status CompileOpRun(vector &node_vec) { return SUCCESS; } + // load task for op + virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } +}; +} // namespace ge +#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ diff --git a/metadef/inc/common/opskernel/ops_kernel_info_types.h b/metadef/inc/common/opskernel/ops_kernel_info_types.h new file mode 100644 index 00000000..8207151b --- /dev/null +++ b/metadef/inc/common/opskernel/ops_kernel_info_types.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ +#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ + +#include +#include +#include +#include "graph/buffer.h" +#include "runtime/rt_model.h" + +using std::string; + +namespace ge { +/*lint -e148*/ +struct RunContext { + rtModel_t model; + rtStream_t stream; + uint64_t sessionId; + uint64_t dataMemSize; + uint8_t *dataMemBase; + std::map mem_type_data_mem_size; + std::map mem_type_data_mem_base; + uint64_t weightMemSize; + uint8_t *weightMemBase; + ge::Buffer weightsBuffer; + std::vector graphStreamList; // all streams of graph, order by ge stream id(0,1,...) + std::vector graphEventList; // all events of graph, order by ge event id(0,1,...) + std::vector graphLabelList; // all labels of graph, order by ge label id(0,1,...) +}; + +/*lint +e148*/ +struct Task { + uint32_t id; + uint16_t type; + void *stream; + void *event; +}; + +struct OpInfo { + string engine; // which engin + /*lint -e148*/ + string opKernelLib; // which opsKernelStore + int computeCost; // compute cost + bool flagPartial; // whether to support is related to shape + bool flagAsync; // Whether to support asynchronous + bool isAtomic; // whether to support atomic addr clean + string opFileName; // op file name + string opFuncName; // op function name +}; +} // namespace ge + +#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ diff --git a/metadef/inc/common/optimizer/graph_optimizer.h b/metadef/inc/common/optimizer/graph_optimizer.h new file mode 100644 index 00000000..9865df14 --- /dev/null +++ b/metadef/inc/common/optimizer/graph_optimizer.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ +#define INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ + +#include +#include +#include "./graph_optimizer_types.h" +#include "common/ge_inner_error_codes.h" +#include "common/opskernel/ops_kernel_info_types.h" +#include "graph/compute_graph.h" + +using std::map; +using std::string; + +/*lint -e148*/ +namespace ge { +class GraphOptimizer { + public: + virtual ~GraphOptimizer() {} + + // initialize graphOptimizer + virtual Status Initialize(const map &options) = 0; + + // close graphOptimizer + virtual Status Finalize() = 0; + + // optimize original graph for FE quant optimize + virtual Status OptimizeGraphPrepare(ComputeGraph& graph) { + return SUCCESS; + } + + // optimize graph before build for RTS + virtual Status OptimizeGraphBeforeBuild(ComputeGraph& graph) { + return SUCCESS; + } + + // optimize original graph, using in graph preparation stage + virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; + + // optimize original graph, using for conversion operator insert in graph preparation stage + virtual Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) { + return SUCCESS; + } + + // optimize fused graph + virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; + + // optimize whole graph, using after graph merged stage + virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; + + // get attribute of graph optimizer + virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; + + // optimize streamed Graph + virtual Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) { return SUCCESS; } + + // op compile + virtual Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) { return SUCCESS; } +}; +} // namespace ge +/*lint +e148*/ +#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ diff --git a/metadef/inc/common/optimizer/graph_optimizer_types.h b/metadef/inc/common/optimizer/graph_optimizer_types.h new file mode 100644 index 00000000..9e1ec96b --- /dev/null +++ b/metadef/inc/common/optimizer/graph_optimizer_types.h @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ +#define INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ + +#include +#include +namespace ge { +enum OPTIMIZER_SCOPE { + UNIT = 0, + ENGINE, +}; + +struct GraphOptimizerAttribute { + std::string engineName; + OPTIMIZER_SCOPE scope; +}; +} // namespace ge + +#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ diff --git a/metadef/inc/common/proto/dump_task.proto b/metadef/inc/common/proto/dump_task.proto new file mode 100644 index 00000000..b1e346cd --- /dev/null +++ b/metadef/inc/common/proto/dump_task.proto @@ -0,0 +1,111 @@ +syntax = "proto3"; +package toolkit.dumpdata; + +enum OutputDataType { + DT_UNDEFINED = 0; + DT_FLOAT = 1; + DT_FLOAT16 = 2; + DT_INT8 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_UINT16 = 6; + DT_INT32 = 7; + DT_INT64 = 8; + DT_UINT32 = 9; + DT_UINT64 = 10; + DT_BOOL = 11; + DT_DOUBLE = 12; + DT_STRING = 13; + DT_DUAL_SUB_INT8 = 14; + DT_DUAL_SUB_UINT8 = 15; + DT_COMPLEX64 = 16; + DT_COMPLEX128 = 17; + DT_QINT8 = 18; + DT_QINT16 = 19; + DT_QINT32 = 20; + DT_QUINT8 = 21; + DT_QUINT16 = 22; + DT_RESOURCE = 23; + DT_STRING_REF = 24; + DT_DUAL = 25; +} + +enum OutputFormat { + FORMAT_NCHW = 0; + FORMAT_NHWC = 1; + FORMAT_ND = 2; + FORMAT_NC1HWC0 = 3; + FORMAT_FRACTAL_Z = 4; + FORMAT_NC1C0HWPAD = 5; + FORMAT_NHWC1C0 = 6; + FORMAT_FSR_NCHW = 7; + FORMAT_FRACTAL_DECONV = 8; + FORMAT_C1HWNC0 = 9; + FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; + FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; + FORMAT_NC1HWC0_C04 = 12; + FORMAT_FRACTAL_Z_C04 = 13; + FORMAT_CHWN = 14; + FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; + FORMAT_HWCN = 16; + FORMAT_NC1KHKWHWC0 = 17; + FORMAT_BN_WEIGHT = 18; + FORMAT_FILTER_HWCK = 19; + FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; + FORMAT_HASHTABLE_LOOKUP_KEYS = 21; + FORMAT_HASHTABLE_LOOKUP_VALUE = 22; + FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; + FORMAT_HASHTABLE_LOOKUP_HITS=24; + FORMAT_C1HWNCoC0 = 25; + FORMAT_MD = 26; + FORMAT_NDHWC = 27; + FORMAT_FRACTAL_ZZ = 28; + FORMAT_FRACTAL_NZ = 29; + FORMAT_RESERVED = 30; +} + +message OriginalOp { + string name = 1; + uint32 output_index = 2; + OutputDataType data_type = 3; + OutputFormat format = 4; +} + +message Shape { + repeated uint64 dim = 1; +} + +message OpOutput { + OutputDataType data_type = 1; + OutputFormat format = 2; + Shape shape = 3; + OriginalOp original_op = 4; // the original op corresponding to the output + bytes data = 5; + uint64 size = 6; +} + +message OpInput { + OutputDataType data_type = 1; + OutputFormat format = 2; + Shape shape = 3; + bytes data = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + bytes data = 2; + uint64 size = 3; +} + +message DumpData{ + string version = 1; + uint64 dump_time = 2; + repeated OpOutput output = 3; + repeated OpInput input = 4; + repeated OpBuffer buffer = 5; +} diff --git a/metadef/inc/common/proto/fusion_model.proto b/metadef/inc/common/proto/fusion_model.proto new file mode 100644 index 00000000..c92c5581 --- /dev/null +++ b/metadef/inc/common/proto/fusion_model.proto @@ -0,0 +1,21 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +import "om.proto"; + +package domi; + +message FusionModelDef { + string version = 1; + repeated OpDef fusion_op = 2; +} \ No newline at end of file diff --git a/metadef/inc/common/proto/fwk_adapter.proto b/metadef/inc/common/proto/fwk_adapter.proto new file mode 100644 index 00000000..9335c926 --- /dev/null +++ b/metadef/inc/common/proto/fwk_adapter.proto @@ -0,0 +1,37 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package aicpu.FWKAdapter; +option cc_enable_arenas = true; + + +// Defines an struct for input and output. +message TensorDataInfo { + + // value DataType + uint32 dtype = 1; + + // shape dim + repeated int64 dim = 2; + + // data point addr + int64 data_addr = 3; +} + +message KernelRunParam { + // input + repeated TensorDataInfo input = 1; + // output + repeated TensorDataInfo output = 2; +} + diff --git a/metadef/inc/common/proto/ge_ir.proto b/metadef/inc/common/proto/ge_ir.proto new file mode 100644 index 00000000..e7bfe0cb --- /dev/null +++ b/metadef/inc/common/proto/ge_ir.proto @@ -0,0 +1,190 @@ +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. + DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs +{ + string name = 1; + map attr = 2; +} + +// Shape / dimension description, using row-major order +message ShapeDef +{ + repeated int64 dim = 1; // Size of each dimension +} + +// Multidimensional data description +message TensorDescriptor +{ + string name = 1; // Optional parameter, tensor name + + DataType dtype = 2; // tensor datatype + ShapeDef shape = 3; // Shape / dimension + string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" + + bool has_out_attr = 9; + int64 size = 10; + int64 weight_size = 11; + bool reuse_input = 12; + bool output_tensor = 13; + string device_type = 14; + bool input_tensor =15; + int64 real_dim_cnt = 16; + int64 reuse_input_index = 17; + int64 data_offset = 18; + int64 cmps_size = 19; + string cmps_tab = 20; + int64 cmps_tab_offset = 21; + + map attr = 5; // Set of extra parameter fields +} + +// GeTensor definition +message TensorDef +{ + TensorDescriptor desc = 1; // Tensor description + bytes data = 2; // Tensor data +} + + +// Operator description +message OpDef +{ + string name = 1; // name + string type = 2; // type + + repeated string input = 5; // input original op name + outgoing index. op_name:index + + map attr = 10; // Set of operator parameter fields + + bool has_out_attr = 20; + int64 id = 21; + int64 stream_id =22; + repeated string input_name = 23; + repeated string src_name = 24; + repeated int64 src_index = 25; + repeated string dst_name = 26; + repeated int64 dst_index = 27; + repeated int64 input_i = 28; + repeated int64 output_i = 29; + repeated int64 workspace = 30; + repeated int64 workspace_bytes = 31; + repeated bool is_input_const = 32; + repeated TensorDescriptor input_desc = 33; + repeated TensorDescriptor output_desc = 34; + repeated string subgraph_name = 35; +} + +// Graph definition +message GraphDef +{ + string name = 1; // name + + repeated string input = 4; // Graph input + repeated string output = 5; // Graph output + + repeated OpDef op = 6; // List of operators + + map attr = 11; // Extended field +} + +// model definition +message ModelDef +{ + string name = 1; // name + uint32 version = 2; // IR Proto verion + string custom_version = 3; // User model version number, passed in by user + + repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef + + map attr = 11; // Extended field +} + diff --git a/metadef/inc/common/proto/insert_op.proto b/metadef/inc/common/proto/insert_op.proto new file mode 100644 index 00000000..bf918b20 --- /dev/null +++ b/metadef/inc/common/proto/insert_op.proto @@ -0,0 +1,139 @@ +syntax = "proto3"; + +package domi; + +message InsertNewOps { + repeated AippOpParams aipp_op = 1; + repeated MultiShapeOpParams multi_shape_op = 2; +} + +message AippOpParams { + enum InputFormat { + UNDEFINED = 0; + YUV420SP_U8 = 1; + XRGB8888_U8 = 2; + RGB888_U8 = 3; + YUV400_U8 = 4; + NC1HWC0DI_FP16 = 5; + NC1HWC0DI_S8 = 6; + ARGB8888_U8 = 7; + YUYV_U8 = 8; + YUV422SP_U8 = 9; + AYUV444_U8 = 10; + RAW10 = 11; + RAW12 = 12; + RAW16 = 13; + RAW24 = 14; + RGB16 = 15; + RGB20 = 16; + RGB24 = 17; + RGB8_IR = 18; + RGB16_IR = 19; + RGB24_IR = 20; + } + + enum AippMode { + undefined = 0; + static = 1; + dynamic = 2; + } + + // AIPPģʽ־̬AIPPͶ̬AIPP + AippMode aipp_mode = 1; + + // related_input_rankΪΪͣ÷Χ>=0, <=DataӵĸĬֵΪ0 + // ʶģ͵ĵڼAIPPģ룬ҪԵ2AIPPrelated_input_rankΪ1 + uint32 related_input_rank = 2; + + // related_input_name is optional and the top name of data node which inserts aipp + string related_input_name = 6; + + // input_edge_idxΪѡΪͣ÷ΧΪ>=0 + // øòãڶDataӲͬͬAIPPòûãĬ϶related_input_rankָģAIPP + // ֵ <= Dataߵĸ + repeated uint32 input_edge_idx = 3; + + // [Begin] ̬AIPPþ̬AIPPʱЧ + uint32 max_src_image_size = 4; + + // Ƿ֧תĬϲ֧֣֧תʱжĿռʧ + bool support_rotation = 5; + + // [End] ̬AIPP + + + // [Begin] ̬AIPPö̬AIPPʱЧ + InputFormat input_format = 51; + bool csc_switch = 52; + float cpadding_value = 53; + bool rbuv_swap_switch = 54; + bool ax_swap_switch = 55; + bool single_line_mode = 56; + + int32 src_image_size_w = 57; + int32 src_image_size_h = 58; + + bool crop = 59; + int32 load_start_pos_w = 60; + int32 load_start_pos_h = 61; + int32 crop_size_w = 62; + int32 crop_size_h = 63; + + bool resize = 64; + int32 resize_output_w = 65; + int32 resize_output_h = 66; + + bool padding = 67; + int32 left_padding_size = 68; + int32 right_padding_size = 69; + int32 top_padding_size = 70; + int32 bottom_padding_size = 71; + + int32 mean_chn_0 = 10; + int32 mean_chn_1 = 11; + int32 mean_chn_2 = 12; + int32 mean_chn_3 = 19; + float min_chn_0 = 13; + float min_chn_1 = 14; + float min_chn_2 = 15; + float min_chn_3 = 20; + repeated float var_reci_chn_0 = 16; + repeated float var_reci_chn_1 = 17; + repeated float var_reci_chn_2 = 18; + repeated float var_reci_chn_3 = 21; + + repeated int32 matrix_r0c0 = 30; + repeated int32 matrix_r0c1 = 31; + repeated int32 matrix_r0c2 = 32; + repeated int32 matrix_r1c0 = 33; + repeated int32 matrix_r1c1 = 34; + repeated int32 matrix_r1c2 = 35; + repeated int32 matrix_r2c0 = 36; + repeated int32 matrix_r2c1 = 37; + repeated int32 matrix_r2c2 = 38; + repeated int32 output_bias_0 = 39; + repeated int32 output_bias_1 = 40; + repeated int32 output_bias_2 = 41; + repeated int32 input_bias_0 = 42; + repeated int32 input_bias_1 = 43; + repeated int32 input_bias_2 = 44; + + // [End] ̬AIPP + + // The n number that is used for raw/rgbir data into f16 transformation. + // The transformation equation is x/(2^n). If set to 0, no transform is performed. + uint32 raw_rgbir_to_f16_n = 45; +} + +message MultiShapeOpParams { + enum MultiShapeMode { + batch = 0; //̬batch + resolution = 1; //ֱ̬ʣչ + } + + MultiShapeMode mode = 1; //ģʽ + uint32 related_input_rank = 2; //Ӳ뵽ĸ + + + repeated uint32 batch_list = 11; //batch_listֵbatch_listĸ28֮ +} diff --git a/metadef/inc/common/proto/om.proto b/metadef/inc/common/proto/om.proto new file mode 100644 index 00000000..e15e5f80 --- /dev/null +++ b/metadef/inc/common/proto/om.proto @@ -0,0 +1,396 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. + CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/metadef/inc/common/proto/op_mapping_info.proto b/metadef/inc/common/proto/op_mapping_info.proto new file mode 100644 index 00000000..e23b7ebe --- /dev/null +++ b/metadef/inc/common/proto/op_mapping_info.proto @@ -0,0 +1,73 @@ +syntax = "proto3"; +package aicpu.dump; + +message Shape { + repeated uint64 dim = 1; +} + +message Output { + int32 data_type = 1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + string original_name = 5; + int32 original_output_index = 6; + int32 original_output_data_type = 7; + int32 original_output_format = 8; + uint64 size = 9; +} + +message Input { + int32 data_type =1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + uint64 address = 2; + uint64 size = 3; +} + +message Op { + string op_name = 1; + string op_type = 2; +} + +message Task { + uint32 task_id = 1; + uint32 stream_id = 2; + Op op = 3; + repeated Output output = 4; + bool end_graph = 5; + repeated Input input = 6; + repeated OpBuffer buffer = 7; +} + +message OpMappingInfo { + string dump_path = 1; + oneof model_name_param { + string model_name = 2; + } + oneof model_id_param { + uint32 model_id = 3; + } + oneof step_id { + uint64 step_id_addr = 4; + } + oneof iterations_per_loop { + uint64 iterations_per_loop_addr = 5; + } + oneof loop_cond { + uint64 loop_cond_addr = 6; + } + uint32 flag = 7; // 0x01 load, 0x00 unload + repeated Task task = 8; + string dump_step = 9; +} \ No newline at end of file diff --git a/metadef/inc/common/proto/proto_inner/ge_onnx.proto b/metadef/inc/common/proto/proto_inner/ge_onnx.proto new file mode 100644 index 00000000..4cd77f3a --- /dev/null +++ b/metadef/inc/common/proto/proto_inner/ge_onnx.proto @@ -0,0 +1,563 @@ +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto3"; + +package ge.onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION = 0x0000000000000006; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value= 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} diff --git a/metadef/inc/common/proto/task.proto b/metadef/inc/common/proto/task.proto new file mode 100644 index 00000000..d0c09840 --- /dev/null +++ b/metadef/inc/common/proto/task.proto @@ -0,0 +1,165 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/metadef/inc/common/util/ai_core/aicore_manager/aicore_util_manager.h b/metadef/inc/common/util/ai_core/aicore_manager/aicore_util_manager.h new file mode 100644 index 00000000..df0728e1 --- /dev/null +++ b/metadef/inc/common/util/ai_core/aicore_manager/aicore_util_manager.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef AICORE_UTIL_MANAGER_H_ +#define AICORE_UTIL_MANAGER_H_ + +#include +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +namespace fe { +class AICoreUtilManager { + public: + static AICoreUtilManager &Instance(); + /* + * to initialize the aicore configuration + * param[in] the options of init + * param[in] engine Name + * param[in] socVersion soc version from ge + * return Status(SUCCESS/FAILED) + */ + Status Initialize(const std::map &options, std::string &soc_version); + + /* + * to release the source of fusion manager + * return Status(SUCCESS/FAILED) + */ + Status Finalize(); + + private: + AICoreUtilManager(); + ~AICoreUtilManager(); + bool is_init_; +}; +} // namespace fe +#endif // AICORE_UTIL_MANAGER_H \ No newline at end of file diff --git a/metadef/inc/common/util/ai_core/common/aicore_util_attr_define.h b/metadef/inc/common/util/ai_core/common/aicore_util_attr_define.h new file mode 100644 index 00000000..237d9d0e --- /dev/null +++ b/metadef/inc/common/util/ai_core/common/aicore_util_attr_define.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_ATTR_DEFINE_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_ATTR_DEFINE_H_ + +#include + +namespace fe { +static const std::string SCOPE_ID_ATTR = "fusion_scope"; + +static const std::string FE_IMPLY_TYPE = "_fe_imply_type"; + +static const std::string PARENT_OP_TYPE = "parentOpType"; + +static const std::string ATTR_NAME_TASK_L2_FUSION_INFO_EXTEND_PTR = "task_l2_fusion_info_extend_content"; + +static const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; + +static const std::string ATTR_NAME_L2_FUSION_EXTEND_PTR = "l2_fusion_extend_content"; + +static const std::string L1_OPTIMIZED = "l1_optimized"; + +static const std::string L2_OPTIMIZED = "l2_optimized"; + +static const std::string ATTR_NAME_UNKNOWN_SHAPE = "_unknown_shape"; + +static const std::string ATTR_NAME_IS_UNKNOWN_GRAPH = "_fe_is_unknown_graph"; + +static const std::string ATTR_NAME_IS_UNKNOWN_SHAPE_OP = "_fe_is_unknown_shape_op"; + +static const std::string ATTR_NAME_TVM_CACHE_READ_MODE = "tvm_cache_read_mode"; + +static const std::string ATTR_NAME_TBE_KERNEL_SIZE = "_tbeKernelSize"; +} // namespace fe +#endif diff --git a/metadef/inc/common/util/ai_core/common/aicore_util_constants.h b/metadef/inc/common/util/ai_core/common/aicore_util_constants.h new file mode 100644 index 00000000..1deced90 --- /dev/null +++ b/metadef/inc/common/util/ai_core/common/aicore_util_constants.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_CONSTANTS_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_CONSTANTS_H_ + +#include + +namespace fe { +static const std::string CORE_TYPE = "_coretype"; +/* engine name of AI core and vector core */ +static const std::string AI_CORE_NAME = "AIcoreEngine"; +static const std::string VECTOR_CORE_NAME = "VectorEngine"; + +static const int64_t IS_UNKNOWN_SHAPE_VALUE = 1; + +static const int64_t SHAPE_UNKNOWN_DIM = -1; + +static const int64_t SHAPE_UNKNOWN_DIM_NUM = -2; + +static const std::string SOC_VERSION_ASCEND310 = "Ascend310"; +static const std::string SOC_VERSION_ASCEND610 = "Ascend610"; +static const std::string SOC_VERSION_ASCEND615 = "Ascend615"; +static const std::string SOC_VERSION_ASCEND710 = "Ascend710"; +static const std::string SOC_VERSION_ASCEND710P = "Ascend710Pro"; +static const std::string SOC_VERSION_ASCEND910A = "Ascend910A"; +static const std::string SOC_VERSION_ASCEND910B = "Ascend910B"; +static const std::string SOC_VERSION_ASCEND910PROA = "Ascend910ProA"; +static const std::string SOC_VERSION_ASCEND910PROB = "Ascend910ProB"; +static const std::string SOC_VERSION_ASCEND910PREMIUMA = "Ascend910PremiumA"; +static const std::string SOC_VERSION_HI3796CV300ES = "Hi3796CV300ES"; +static const std::string SOC_VERSION_HI3796CV300CS = "Hi3796CV300CS"; + +static const std::vector SOC_VERSION_CLOUD_LIST = { + SOC_VERSION_ASCEND910A, SOC_VERSION_ASCEND910B, SOC_VERSION_ASCEND910PROA, + SOC_VERSION_ASCEND910PROB, SOC_VERSION_ASCEND910PREMIUMA +}; + +static const std::vector SOC_VERSION_DC_LIST = {SOC_VERSION_ASCEND610, SOC_VERSION_ASCEND615, + SOC_VERSION_ASCEND710, SOC_VERSION_ASCEND710P}; +} // namespace fe +#endif diff --git a/metadef/inc/common/util/ai_core/common/aicore_util_types.h b/metadef/inc/common/util/ai_core/common/aicore_util_types.h new file mode 100644 index 00000000..eeebb653 --- /dev/null +++ b/metadef/inc/common/util/ai_core/common/aicore_util_types.h @@ -0,0 +1,147 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_TYPES_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_TYPES_H_ + +#include "graph/anchor.h" +#include "graph/types.h" +#include "runtime/kernel.h" +#include +#include +#include + +namespace fe { +struct FusionOpSrc { + uint32_t src_op_id; + ge::AnchorPtr src_anchor; + int32_t fusion_src_index; + int32_t fusion_dst_index; +}; + +struct FusionOpDst { + uint32_t dst_op_id; + ge::AnchorPtr dst_anchor; +}; + +struct FusionDataFlow { + std::pair edge; + std::pair node_dataindex_pair; +}; + +typedef struct tag_l2_fusion_data { + uint32_t l2Index; + uint64_t l2Addr; + uint64_t l2PageNum; +} L2FusionData_t; +typedef std::map L2FusionDataMap_t; + +typedef struct tag_fe_sm_desc { + rtL2Ctrl_t l2ctrl; + std::string node_name[8]; + uint8_t output_index[8]; +} fe_sm_desc_t; + +typedef struct TagTaskL2FusionInfo { + std::string node_name; + fe_sm_desc_t l2_info; + L2FusionDataMap_t input; + L2FusionDataMap_t output; + uint32_t is_used; +} TaskL2FusionInfo_t; + +using L2FusionInfoPtr = std::shared_ptr; + +typedef struct ToOpStruct { + int64_t op_l1_space = 0; + std::vector op_l1_fusion_type; + int64_t op_l1_workspace_flag = 0; // for workspace flag + int64_t op_l1_workspace_size = 0; + std::vector> valid_input_shape; + std::vector> valid_output_shape; + std::vector> + slice_input_offset; // conv & pooling & ReadSelect + std::vector> slice_output_offset; // WriteSelect + std::vector total_shape; + uint32_t split_index = 0; + ToOpStruct() { + // set invalid value for essential variable + op_l1_space = -1; + op_l1_workspace_size = -1; + } +} ToOpStruct_t; + +enum SlicePattern { + ELEMENT_WISE = 0, + ELEMENT_WISE_BROADCAST, + BROADCAST, + SLIDING_WINDOW, + SLIDING_WINDOW_DECONV, + CUBE_MATMUL, + SLICE_PATTERN_REDUCE, + SLICE_PATTERN_RESIZE, + SLICE_PATTERN_SCATTER, + SLICE_PATTERN_SEGMENT, + PATTERN_RESERVED +}; + +enum OpImplType { + EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op + EN_IMPL_CUSTOM_TIK, // custom tik op + EN_IMPL_CUSTOM_TBE, // custom tbe op + EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op + EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op + EN_IMPL_HW_TIK, // Huawei built-in tik op + EN_IMPL_HW_TBE, // Huawei built-in tbe op + EN_IMPL_RL, // RL op + EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op + EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op + EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op + EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op + EN_RESERVED // reserved value +}; + +// Dont change the order, only add new mode in the end +enum L2Mode { EN_L2_CLOSE = 0, EN_L2_BUFFER_OPTIMIZE, EN_L2_CACHE_NORMAL, EN_L2_CACHE_RC }; +enum BufferFusionMode { EN_OPTIMIZE_DISABLE = 0, EN_L2_BUFFER, EN_L2_FUSION}; + +static const std::map DATATYPE_SIZE_MAP{ + {ge::DT_FLOAT, sizeof(float)}, + {ge::DT_FLOAT16, sizeof(int16_t)}, + {ge::DT_INT8, sizeof(int8_t)}, + {ge::DT_INT32, sizeof(int32_t)}, + {ge::DT_UINT8, sizeof(uint8_t)}, + {ge::DT_UINT32, sizeof(uint32_t)}, + {ge::DT_INT16, sizeof(int16_t)}, + {ge::DT_UINT16, sizeof(uint16_t)}, + {ge::DT_INT64, sizeof(int64_t)}, + {ge::DT_UINT64, sizeof(uint64_t)}, + {ge::DT_DOUBLE, sizeof(double)}, + {ge::DT_BOOL, sizeof(bool)}, + {ge::DT_DUAL, sizeof(float) + sizeof(int8_t)}, + {ge::DT_DUAL_SUB_UINT8, sizeof(int8_t)}, + {ge::DT_DUAL_SUB_INT8, sizeof(int8_t)} +}; + +enum OpReduceType { + REDUCE_MEAN = 0, + REDUCE_ADD, + REDUCE_MAX, + REDUCE_MIN, +}; + +} +#endif diff --git a/metadef/inc/common/util/ai_core/common/graph_comm.h b/metadef/inc/common/util/ai_core/common/graph_comm.h new file mode 100644 index 00000000..abde4437 --- /dev/null +++ b/metadef/inc/common/util/ai_core/common/graph_comm.h @@ -0,0 +1,128 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_GRAPH_COMMON_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_GRAPH_COMMON_H_ + +#include "graph/compute_graph.h" +#include "common/aicore_util_types.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +#include +#include +#include +#include + +namespace fe { + +using k_scope_node_map_t = std::map>; +using k_scope_node_pair_t = std::pair>; + +class GraphCommImpl; +using GraphCommImplPtr = std::unique_ptr; + +class GraphComm { +public: + GraphComm(const string &engine_name); + virtual ~GraphComm(); + GraphComm(const GraphComm &in) = delete; + GraphComm &operator=(const GraphComm &in) = delete; + + Status GetscopeNodeMap(ge::ComputeGraph &graph, k_scope_node_map_t &fusion_map); + + Status CopyFusionOpNodes(vector &fus_input_edge_list, + vector &fus_output_edge_list, + vector &fus_nodelist, + ge::OpDescPtr fusion_op_desc, + ge::ComputeGraphPtr fusion_graph); + + Status CopyFusionOpEdges(ge::OpDescPtr fusion_op_desc, + ge::ComputeGraph &orig_graph, + ge::ComputeGraphPtr fusion_graph); + + Status GetNodeDataFlowMap( + const ge::NodePtr &fus_node, + std::map> + &fusion_op_anchors_map, + ge::kFusionDataFlowVec_t &fus_dataflow_list, const int &map_type); + + Status GetFusionNodeEdgeList(std::vector &fus_nodelist, + std::vector &fus_input_edge_list, + std::vector &fus_output_edge_list); + void ClearFusionSrc(); + + void ClearFusionDst(); + + void + AddFusionOutputSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, + const int32_t &fusion_src_index, + std::pair &node_dataindex_pair); + + void AddFusionInputSrc(const uint32_t &src_op_id, + const ge::AnchorPtr &src_anchor, + const int32_t &fusion_dst_index, + std::pair &node_dataindex_pair); + + void SaveFusionDst(const uint32_t &dst_op_id, ge::AnchorPtr dst_anchor); + + bool IsFusionDstExist(const uint32_t &dst_op_id, + const ge::AnchorPtr &dst_anchor); + + bool GetFusionSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, + int32_t &fusion_src_index, int32_t &fusion_dst_index); + + Status + GetFusionNodeCtrlEdgeList(vector &fus_nodelist, + vector &fus_input_ctrl_edge_list, + vector &fus_output_ctrl_edge_list); + + Status MergeFusionNodeEdgeList(ge::NodePtr &fus_node, + vector &fus_nodelist, + vector &fus_input_edge_list, + vector &fus_output_edge_list); + + Status MergeFusionNodeCtrlEdgeList(ge::NodePtr &fus_node, + vector &fus_nodelist, + vector &fus_input_edge_list, + vector &fus_output_edge_list); + + string GetEngineName(); + +private: + Status + MergeFusionNodeInputEdgeList(ge::NodePtr fus_node, + std::vector &fus_nodelist, + std::vector &fus_input_edge_list); + Status + MergeFusionNodeOutputEdgeList(ge::NodePtr fus_node, + std::vector &fus_nodelist, + std::vector &fus_output_edge_list); + + string engine_name_; + + std::vector exist_fusion_src_list_; + std::vector exist_fusion_dst_list_; + + // std::vector> + ge::kFusionDataFlowVec_t fusion_input_dataflow_list_; + + // std::vector> + ge::kFusionDataFlowVec_t fusion_output_dataflow_list_; + + GraphCommImplPtr graph_comm_impl_ptr_; +}; +} // namespace fe +#endif diff --git a/metadef/inc/common/util/ai_core/common/json_util.h b/metadef/inc/common/util/ai_core/common/json_util.h new file mode 100644 index 00000000..5bca3b98 --- /dev/null +++ b/metadef/inc/common/util/ai_core/common/json_util.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PROJECT_JSON_UTIL_H +#define PROJECT_JSON_UTIL_H + +#include "graph/compute_graph.h" + +#include "common/aicore_util_types.h" +#include "fusion_engine/graph_tuner/graph_tuner_errorcode.h" + +const std::string L1_FUSION_EXTEND_CONTENT = "_l1_fusion_extend_content"; +const std::string L2_FUSION_EXTEND_CONTENT = "l2_fusion_extend_content"; +const std::string TASK_L2_FUSION_INFO_EXTEND_CONTENT = "task_l2_fusion_info_extend_content"; +const std::string L1_FUSION_TO_OP_STRUCT = "_l1fusion_ToOpStruct"; +const std::string L2_FUSION_TO_OP_STRUCT = "_l2fusion_ToOpStruct"; +const std::string TASK_L2_FUSION_INFO = "_task_L2FusionInfo"; + +namespace tune { +using ToOpStructPtr = std::shared_ptr; +using L2FusionInfoPtr = std::shared_ptr; + +Status GetL1InfoFromJson(ge::OpDescPtr opDescPtr); + +Status GetL2InfoFromJson(ge::OpDescPtr opDescPtr); + +Status GetTaskL2FusionInfoFromJson(ge::OpDescPtr opDescPtr); + +Status ReadGraphInfoFromJson(ge::ComputeGraph &graph); + +Status WriteGraphInfoToJson(ge::ComputeGraph &graph); + +void GetL2ToOpStructFromJson(ge::OpDescPtr &opDescPtr, ToOpStructPtr &l2InfoPtr); + +void GetL1ToOpStructFromJson(ge::OpDescPtr &opDescPtr, ToOpStructPtr &l1InfoPtr); + +L2FusionInfoPtr GetL2FusionInfoFromJson(ge::OpDescPtr &opDescPtr); + +void SetL2FusionInfoToNode(ge::OpDescPtr &opDescPtr, L2FusionInfoPtr &l2FusionInfoPtr); +} // namespace tune +#endif //PROJECT_JSON_UTIL_H diff --git a/metadef/inc/common/util/ai_core/common/l2_stream_info.h b/metadef/inc/common/util/ai_core/common/l2_stream_info.h new file mode 100644 index 00000000..8a64eb8e --- /dev/null +++ b/metadef/inc/common/util/ai_core/common/l2_stream_info.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef L2_STREAM_INFO_H_ +#define L2_STREAM_INFO_H_ + +#include +#include +#include +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" +#include "runtime/base.h" +#include "cce/l2fusion_struct.hpp" + +namespace fe { +class StreamL2Info { + public: + StreamL2Info(const StreamL2Info &) = delete; + StreamL2Info &operator=(const StreamL2Info &) = delete; + static StreamL2Info& Instance(); + Status GetStreamL2Info(rtStream_t stream_id, string node_name, fusion::TaskL2Info_t *&l2_data); + Status SetStreamL2Info(const rtStream_t &stream_id, fusion::TaskL2InfoFEMap_t &l2_alloc_res); + + private: + StreamL2Info(); + ~StreamL2Info(); + mutable std::mutex stream_l2_mutex_; + std::map stream_l2_map_; +}; +} // namespace fe + +#endif // L2_STREAM_INFO_H_ \ No newline at end of file diff --git a/metadef/inc/common/util/ai_core/common/scope_allocator.h b/metadef/inc/common/util/ai_core/common/scope_allocator.h new file mode 100644 index 00000000..e81282b3 --- /dev/null +++ b/metadef/inc/common/util/ai_core/common/scope_allocator.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_SCOPE_ALLOCATOR_H_ +#define INC_COMMON_UTILS_AI_CORE_COMMON_SCOPE_ALLOCATOR_H_ + +#include "graph/op_desc.h" + +namespace fe { +class ScopeAllocator { + public: + ScopeAllocator(); + virtual ~ScopeAllocator(); + ScopeAllocator(const ScopeAllocator& in) = delete; + ScopeAllocator& operator = (const ScopeAllocator& in) = delete; + + public: + void Init(); + int64_t GetCurrentScopeId(); + int64_t AllocateScopeId(void); + bool HasScopeAttr(ge::ConstOpDescPtr opdef); + bool GetScopeAttr(ge::ConstOpDescPtr opdef, int64_t &scope_id); + bool SetScopeAttr(ge::OpDescPtr opdef, int64_t scope_id); + bool ResetScopeId(int64_t scope_id); + private: + int64_t scope_id; +}; +} // namespace fe +#endif diff --git a/metadef/inc/common/util/ai_core/param_calculate/tensorsize_calculator.h b/metadef/inc/common/util/ai_core/param_calculate/tensorsize_calculator.h new file mode 100644 index 00000000..9131b1ba --- /dev/null +++ b/metadef/inc/common/util/ai_core/param_calculate/tensorsize_calculator.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TENSORSIZE_CALCULATOR_H +#define TENSORSIZE_CALCULATOR_H + +#include "graph_optimizer/graph_optimize_register_error_codes.h" + +#include +#include +#include "graph/compute_graph.h" +#include "graph/op_desc.h" + +namespace fe { +class TensorSizeCalculator { + public: + /** + * Calculate the tensor size of input and output of each opdesc + * @param op_desc opdesc object + * @param op_impl_type op impl type + * @return status SUCCESS or FAILED + */ + static Status CalculateOpTensorSize(ge::OpDesc &op_desc); + + private: + static Status CalcInputOpTensorSize(ge::OpDesc &op_desc, + int32_t &output_real_calc_flag); + + static Status CalcOutputOpTensorSize(ge::OpDesc &op_desc, + int32_t &output_real_calc_flag); +}; +} // namespace fe + +#endif // TENSORSIZE_CALCULATOR_H diff --git a/metadef/inc/common/util/compress/compress.h b/metadef/inc/common/util/compress/compress.h new file mode 100644 index 00000000..b702324e --- /dev/null +++ b/metadef/inc/common/util/compress/compress.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPRESS_H +#define COMPRESS_H + +#include + +enum CmpStatus { + RET_SUCCESS = 0, + RET_ERROR = -1 +}; + +struct CompressConfig { + size_t inputSize; // length of data to compress + size_t engineNum; // how many decompress engines + size_t maxRatio; // how much size of a basic compression block, only 64 supported now (8x: 64 4x: 32) + size_t channel; // channels of L2 or DDR. For load balance + size_t fractalSize; // size of compressing block + bool isTight; // whether compose compressed data tightly + size_t init_offset; +}; + +CmpStatus CompressWeights(char* input, + const CompressConfig& compressConfig, + char* indexs, + char* output, + size_t& compressedLength); + + +#endif // COMPRESS_H diff --git a/metadef/inc/common/util/compress/compress_weight.h b/metadef/inc/common/util/compress/compress_weight.h new file mode 100644 index 00000000..36521a3a --- /dev/null +++ b/metadef/inc/common/util/compress/compress_weight.h @@ -0,0 +1,35 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPRESS_WEIGHT_H +#define COMPRESS_WEIGHT_H + +#include "compress.h" + +const int SHAPE_SIZE_WEIGHT = 4; + +struct CompressOpConfig { + int64_t wShape[SHAPE_SIZE_WEIGHT]; + size_t compressTilingK; + size_t compressTilingN; + struct CompressConfig compressConfig; +}; + +extern "C" CmpStatus CompressWeightsConv2D(const char *const input, + char *const zipBuffer, + char *const infoBuffer, + CompressOpConfig *const param); +#endif // COMPRESS_WEIGHT_H diff --git a/metadef/inc/common/util/error_manager/error_manager.h b/metadef/inc/common/util/error_manager/error_manager.h new file mode 100644 index 00000000..37c1f96c --- /dev/null +++ b/metadef/inc/common/util/error_manager/error_manager.h @@ -0,0 +1,126 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ERROR_MANAGER_H_ +#define ERROR_MANAGER_H_ + +#include +#include +#include +#include + +class ErrorManager { + public: + /// + /// @brief Obtain ErrorManager instance + /// @return ErrorManager instance + /// + static ErrorManager &GetInstance(); + + /// + /// @brief init + /// @return int 0(success) -1(fail) + /// + int Init(); + + /// + /// @brief init + /// @param [in] path: current so path + /// @return int 0(success) -1(fail) + /// + int Init(std::string path); + + /// + /// @brief Report error message + /// @param [in] error_code: error code + /// @param [in] args_map: parameter map + /// @return int 0(success) -1(fail) + /// + int ReportErrMessage(std::string error_code, const std::map &args_map); + + /// + /// @brief output error message + /// @param [in] handle: print handle + /// @return int 0(success) -1(fail) + /// + int OutputErrMessage(int handle); + + /// + /// @brief output message + /// @param [in] handle: print handle + /// @return int 0(success) -1(fail) + /// + int OutputMessage(int handle); + + /// + /// @brief Report error message + /// @param [in] key: vector parameter key + /// @param [in] value: vector parameter value + /// + void ATCReportErrMessage(std::string error_code, const std::vector &key = {}, + const std::vector &value = {}); + + /// + /// @brief report graph compile failed message such as error code and op_name in mstune case + /// @param [in] msg: failed message map, key is error code, value is op_name + /// @return int 0(success) -1(fail) + /// + int ReportMstuneCompileFailedMsg(const std::map &msg); + + /// + /// @brief save graph compile failed message from thread local map to global map + /// @param [in] graph_name: graph name + /// + void SaveMstuneCompileFailedMsg(const std::string &graph_name); + + /// + /// @brief get graph compile failed message in mstune case + /// @param [in] graph_name: graph name + /// @param [out] msg_map: failed message map, key is error code, value is op_name list + /// @return int 0(success) -1(fail) + /// + int GetMstuneCompileFailedMsg(const std::string &graph_name, + std::map> &msg_map); + + private: + struct ErrorInfo { + std::string error_id; + std::string error_message; + std::vector arg_list; + }; + + ErrorManager() {} + ~ErrorManager() {} + + ErrorManager(const ErrorManager &) = delete; + ErrorManager(ErrorManager &&) = delete; + ErrorManager &operator=(const ErrorManager &) = delete; + ErrorManager &operator=(ErrorManager &&) = delete; + + int ParseJsonFile(std::string path); + + int ReadJsonFile(const std::string &file_path, void *handle); + + bool is_init_ = false; + std::mutex mutex_; + std::map error_map_; + std::vector error_messages_; + std::vector warning_messages_; + std::map>> compile_failed_msg_map_; +}; + +#endif // ERROR_MANAGER_H_ diff --git a/metadef/inc/common/util/platform_info.h b/metadef/inc/common/util/platform_info.h new file mode 100644 index 00000000..ab80f830 --- /dev/null +++ b/metadef/inc/common/util/platform_info.h @@ -0,0 +1,101 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PLATFORM_INFO_H +#define PLATFORM_INFO_H + +#include +#include +#include +#include "platform_info_def.h" + +using std::map; +using std::vector; +using std::string; + +namespace fe { +class PlatformInfoManager { + public: + PlatformInfoManager(const PlatformInfoManager &) = delete; + PlatformInfoManager &operator=(const PlatformInfoManager &) = delete; + + static PlatformInfoManager &Instance(); + uint32_t InitializePlatformInfo(); + uint32_t Finalize(); + + uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platform_info, OptionalInfo &opti_compilation_info); + + uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platform_info, OptionalInfo &opti_compilation_info); + + void SetOptionalCompilationInfo(OptionalInfo &opti_compilation_info); + + private: + PlatformInfoManager(); + ~PlatformInfoManager(); + + uint32_t LoadIniFile(string ini_file_real_path); + + void Trim(string &str); + + uint32_t LoadConfigFile(string real_path); + + string RealPath(const std::string &path); + + string GetSoFilePath(); + + void ParseVersion(map &version_map, string &soc_version, PlatformInfo &platform_info_temp); + + void ParseSocInfo(map &soc_info_map, PlatformInfo &platform_info_temp); + + void ParseCubeOfAICoreSpec(map &ai_core_spec_map, PlatformInfo &platform_info_temp); + + void ParseBufferOfAICoreSpec(map &ai_core_spec_map, PlatformInfo &platform_info_temp); + + void ParseUBOfAICoreSpec(map &ai_core_spec_map, PlatformInfo &platform_info_temp); + + void ParseUnzipOfAICoreSpec(map &ai_core_spec_map, PlatformInfo &platform_info_temp); + + void ParseAICoreSpec(map &ai_core_spec_map, PlatformInfo &platform_info_temp); + + void ParseBufferOfAICoreMemoryRates(map &ai_core_memory_rates_map, PlatformInfo &platform_info_temp); + + void ParseAICoreMemoryRates(map &ai_core_memory_rates_map, PlatformInfo &platform_info_temp); + + void ParseUBOfAICoreMemoryRates(map &ai_core_memory_rates_map, PlatformInfo &platform_info_temp); + + void ParseAICoreintrinsicDtypeMap(map &ai_coreintrinsic_dtype_map, PlatformInfo &platform_info_temp); + + void ParseVectorCoreSpec(map &vector_core_spec_map, PlatformInfo &platform_info_temp); + + void ParseVectorCoreMemoryRates(map &vector_core_memory_rates_map, PlatformInfo &platform_info_temp); + + void ParseCPUCache(map &CPUCacheMap, PlatformInfo &platform_info_temp); + + void ParseVectorCoreintrinsicDtypeMap(map &vector_coreintrinsic_dtype_map, + PlatformInfo &platform_info_temp); + + uint32_t ParsePlatformInfoFromStrToStruct(map> &content_info_map, string &soc_version, + PlatformInfo &platform_info_temp); + + uint32_t AssemblePlatformInfoVector(map> &content_info_map); + + private: + bool init_flag_; + map platform_info_map_; + OptionalInfo opti_compilation_info_; +}; +} // namespace fe +#endif diff --git a/metadef/inc/common/util/platform_info_def.h b/metadef/inc/common/util/platform_info_def.h new file mode 100644 index 00000000..b17319e0 --- /dev/null +++ b/metadef/inc/common/util/platform_info_def.h @@ -0,0 +1,142 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PLATFORM_INFO_DEF_H +#define PLATFORM_INFO_DEF_H + +#include +#include +#include + +using std::map; +using std::vector; +using std::string; + +namespace fe { +enum MemoryType { DDR = 0, HBM }; + +enum L2Type { Cache = 0, Buff }; + +typedef struct tag_str_info { + string aic_version; + string ccec_aic_version; + string ccec_aiv_version; + string is_support_ai_cpu_compiler; +} StrInfo; + +typedef struct tag_so_c_info { + uint32_t ai_core_cnt; + uint32_t vector_core_cnt; + uint32_t ai_cpu_cnt; + MemoryType memory_type; + uint64_t memory_size; + L2Type l2_type; + uint64_t l2_size; + uint32_t l2PageNum; +} SoCInfo; + +typedef struct tag_ai_core_spec { + double cube_freq; + uint64_t cube_m_size; + uint64_t cube_n_size; + uint64_t cube_k_size; + uint64_t vec_calc_size; + uint64_t l0_a_size; + uint64_t l0_b_size; + uint64_t l0_c_size; + uint64_t l1_size; + uint64_t smask_buffer; + uint64_t ub_size; + uint64_t ubblock_size; + uint64_t ubbank_size; + uint64_t ubbank_num; + uint64_t ubburst_in_one_block; + uint64_t ubbank_group_num; + uint32_t unzip_engines; + uint32_t unzip_max_ratios; + uint32_t unzip_channels; + uint8_t unzip_is_tight; + uint8_t cube_vector_split; +} AiCoreSpec; + +typedef struct tag_ai_core_memory_rates { + double ddr_rate; + double ddr_read_rate; + double ddr_write_rate; + double l2_rate; + double l2_read_rate; + double l2_write_rate; + double l1_to_l0_a_rate; + double l1_to_l0_b_rate; + double l1_to_ub_rate; + double l0_c_to_ub_rate; + double ub_to_l2_rate; + double ub_to_ddr_rate; + double ub_to_l1_rate; +} AiCoreMemoryRates; + +typedef struct tag_vector_core_spec { + double vec_freq; + uint64_t vec_calc_size; + uint64_t smask_buffer; + uint64_t ub_size; + uint64_t ubblock_size; + uint64_t ubbank_size; + uint64_t ubbank_num; + uint64_t ubburst_in_one_block; + uint64_t ubbank_group_num; + uint64_t vector_reg_size; + uint64_t predicate_reg_size; + uint64_t address_reg_size; + uint64_t alignment_reg_size; +} VectorCoreSpec; + +typedef struct tag_vector_core_memory_rates { + double ddr_rate; + double ddr_read_rate; + double ddr_write_rate; + double l2_rate; + double l2_read_rate; + double l2_write_rate; + double ub_to_l2_rate; + double ub_to_ddr_rate; +} VectorCoreMemoryRates; + +typedef struct tag_cpu_cache { + uint32_t AICPUSyncBySW; + uint32_t TSCPUSyncBySW; +} CPUCache; + +typedef struct tag_platform_info { + StrInfo str_info; + SoCInfo soc_info; + AiCoreSpec ai_core_spec; + AiCoreMemoryRates ai_core_memory_rates; + map> ai_core_intrinsic_dtype_map; + VectorCoreSpec vector_core_spec; + VectorCoreMemoryRates vector_core_memory_rates; + CPUCache cpucache; + map> vector_core_intrinsic_dtype_map; +} PlatformInfo; + +typedef struct tag_optional_info { + string soc_version; + string core_type; + uint32_t ai_core_num; + string l1_fusion_flag; +} OptionalInfo; +} // namespace fe +#endif diff --git a/metadef/inc/external/graph/ascend_string.h b/metadef/inc/external/graph/ascend_string.h new file mode 100644 index 00000000..f7be6c33 --- /dev/null +++ b/metadef/inc/external/graph/ascend_string.h @@ -0,0 +1,64 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_ASCEND_STRING_H_ +#define INC_EXTERNAL_GRAPH_ASCEND_STRING_H_ + +#include +#include +#include + +namespace ge { +class AscendString { + public: + AscendString() = default; + + ~AscendString() = default; + + AscendString(const char* name); + + const char* GetString() const; + + bool operator<(const AscendString& d) const; + + bool operator>(const AscendString& d) const; + + bool operator<=(const AscendString& d) const; + + bool operator>=(const AscendString& d) const; + + bool operator==(const AscendString& d) const; + + bool operator!=(const AscendString& d) const; + + private: + std::shared_ptr name_; +}; +} // namespace ge + +namespace std { +template <> +struct hash { + size_t operator()(const ge::AscendString &name) const { + std::string str_name; + if (name.GetString() != nullptr) { + str_name = name.GetString(); + } + return hash()(str_name); + } +}; +} +#endif // INC_EXTERNAL_GRAPH_ASCEND_STRING_H_ diff --git a/metadef/inc/external/graph/attr_value.h b/metadef/inc/external/graph/attr_value.h new file mode 100644 index 00000000..35c0c997 --- /dev/null +++ b/metadef/inc/external/graph/attr_value.h @@ -0,0 +1,78 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ +#define INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ + +#include +#include +#include +#include + +#include "./ge_error_codes.h" +#include "ascend_string.h" + +using std::make_shared; +using std::map; +using std::pair; +using std::string; +using std::to_string; +using std::unique_ptr; +using std::vector; + +namespace ge { +class AttrValueImpl; +/*lint -e148*/ +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { + public: + using INT = int64_t; + using FLOAT = float; + using STR = std::string; + + AttrValue(); + ~AttrValue() = default; + + // GetValue, not list type + template + graphStatus GetValue(DT &val) const { + T valGet; + auto status = GetValue(valGet); + if (status != GRAPH_SUCCESS) { + return status; + } + val = DT(valGet); + return GRAPH_SUCCESS; + } + + template + static T CreateFrom(DT &&val) { + return val; + } + + graphStatus GetValue(AscendString &val); + + std::shared_ptr impl; + + private: +#define VALUE_SET_GET_DEC(DT) graphStatus GetValue(DT &val) const; + VALUE_SET_GET_DEC(AttrValue::STR) + VALUE_SET_GET_DEC(AttrValue::INT) + VALUE_SET_GET_DEC(AttrValue::FLOAT) +#undef VALUE_SET_GET_DEC +}; +/*lint +e148*/ +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ diff --git a/metadef/inc/external/graph/ge_error_codes.h b/metadef/inc/external/graph/ge_error_codes.h new file mode 100644 index 00000000..a7e39dd1 --- /dev/null +++ b/metadef/inc/external/graph/ge_error_codes.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ +#define INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ + +namespace ge { +#if(defined(HOST_VISIBILITY)) && (defined(__GNUC__)) +#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_HOST_VISIBILITY +#endif +#if(defined(DEV_VISIBILITY)) && (defined(__GNUC__)) +#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_DEV_VISIBILITY +#endif +#ifdef __GNUC__ +#define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead."))) +#else +#define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead.")) +#endif + +using graphStatus = uint32_t; +const graphStatus GRAPH_FAILED = 0xFFFFFFFF; +const graphStatus GRAPH_SUCCESS = 0; +const graphStatus GRAPH_NOT_CHANGED = 1343242304; +const graphStatus GRAPH_PARAM_INVALID = 50331649; +const graphStatus GRAPH_NODE_WITHOUT_CONST_INPUT = 50331648; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ diff --git a/metadef/inc/external/graph/gnode.h b/metadef/inc/external/graph/gnode.h new file mode 100644 index 00000000..90f030a7 --- /dev/null +++ b/metadef/inc/external/graph/gnode.h @@ -0,0 +1,129 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_NODE_H_ +#define INC_EXTERNAL_GRAPH_NODE_H_ + +#include +#include + +#include "./ge_error_codes.h" +#include "./types.h" +#include "./tensor.h" +#include "./ascend_string.h" + +namespace ge { +class AttrValue; +class GNode; +class OpDesc; +class Graph; +class ComputeGraph; +using GNodePtr = std::shared_ptr; +using GraphPtr = std::shared_ptr; +using OpBytes = std::vector; +using OpDescPtr = std::shared_ptr; +using ComputeGraphPtr = std::shared_ptr; + +class NodeImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GNode { + public: + GNode(); + + ~GNode() = default; + + graphStatus GetType(AscendString &type) const; + + graphStatus GetName(AscendString &name) const; + + std::pair GetInDataNodesAndPortIndexs(const int32_t index) const; + + std::vector GetInControlNodes() const; + + std::vector> GetOutDataNodesAndPortIndexs(const int32_t index) const; + + std::vector GetOutControlNodes() const; + + graphStatus GetInputConstData(const int32_t index, Tensor &data) const; + + graphStatus GetInputIndexByName(const AscendString &name, int32_t &index); + + graphStatus GetOutputIndexByName(const AscendString &name, int32_t &index); + + size_t GetInputsSize() const; + + size_t GetOutputsSize() const; + + graphStatus GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const; + + graphStatus UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc); + + graphStatus GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const; + + graphStatus UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc); + + graphStatus GetAttr(const AscendString &name, int64_t &attr_value) const; + graphStatus GetAttr(const AscendString &name, int32_t &attr_value) const; + graphStatus GetAttr(const AscendString &name, uint32_t &attr_value) const; + graphStatus GetAttr(const AscendString &name, float &attr_value) const; + graphStatus GetAttr(const AscendString &name, AscendString &attr_value) const; + graphStatus GetAttr(const AscendString &name, bool &attr_value) const; + graphStatus GetAttr(const AscendString &name, Tensor &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_values) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, OpBytes &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector> &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, ge::DataType &attr_value) const; + graphStatus GetAttr(const AscendString &name, AttrValue &attr_value) const; + + graphStatus SetAttr(const AscendString &name, int64_t &attr_value) const; + graphStatus SetAttr(const AscendString &name, int32_t &attr_value) const; + graphStatus SetAttr(const AscendString &name, uint32_t &attr_value) const; + graphStatus SetAttr(const AscendString &name, float &attr_value) const; + graphStatus SetAttr(const AscendString &name, AscendString &attr_value) const; + graphStatus SetAttr(const AscendString &name, bool &attr_value) const; + graphStatus SetAttr(const AscendString &name, Tensor &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_values) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, OpBytes &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector> &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, ge::DataType &attr_value) const; + graphStatus SetAttr(const AscendString &name, AttrValue &attr_value) const; + + bool HasAttr(const AscendString &name); + + graphStatus GetSubgraph(uint32_t index, GraphPtr &graph) const; + + graphStatus GetALLSubgraphs(std::vector &graph_list) const; + + private: + std::shared_ptr impl_; + friend class NodeAdapter; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_NODE_H_ diff --git a/metadef/inc/external/graph/graph.h b/metadef/inc/external/graph/graph.h new file mode 100644 index 00000000..651e0cf9 --- /dev/null +++ b/metadef/inc/external/graph/graph.h @@ -0,0 +1,126 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_GRAPH_H_ +#define INC_EXTERNAL_GRAPH_GRAPH_H_ + +#include +#include +#include +#include + +#include "./operator.h" +#include "./gnode.h" + +namespace ge { +class Graph; +class GraphImpl; + +using GraphImplPtr = std::shared_ptr; +using GraphPtr = std::shared_ptr; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { + friend class GraphUtils; + + public: + ATTRIBUTED_DEPRECATED(Graph(const char *)) + explicit Graph(const std::string &name); + + explicit Graph(const char *name); + + Graph() = default; + + ~Graph() = default; + + Graph &SetInputs(const std::vector &inputs); + + Graph &SetOutputs(const std::vector &outputs); + + Graph &SetOutputs(const std::vector>> &output_indexs); + + ATTRIBUTED_DEPRECATED(Graph &SetOutputs(const std::vector> &outputs); + + Graph &SetOutputs(const std::vector> &outputs); + + Graph &SetTargets(const std::vector &targets); + + bool IsValid() const; + + graphStatus AddOp(const ge::Operator &op); + + ATTRIBUTED_DEPRECATED(graphStatus FindOpByName(const char *, ge::Operator &)) + graphStatus FindOpByName(const std::string &name, ge::Operator &op) const; + + graphStatus FindOpByName(const char *name, ge::Operator &op) const; + + ATTRIBUTED_DEPRECATED(graphStatus FindOpByType(const char *, std::vector &)) + graphStatus FindOpByType(const std::string &type, std::vector &ops) const; + + graphStatus FindOpByType(const char *type, std::vector &ops) const; + + ATTRIBUTED_DEPRECATED(graphStatus GetAllOpName(std::vector &) const) + graphStatus GetAllOpName(std::vector &op_name) const; + + graphStatus GetAllOpName(std::vector &names) const; + + ATTRIBUTED_DEPRECATED(graphStatus SaveToFile(const char *file_name) const) + graphStatus SaveToFile(const std::string &file_name) const; + + graphStatus SaveToFile(const char *file_name) const; + + ATTRIBUTED_DEPRECATED(graphStatus LoadFromFile(const char *)) + graphStatus LoadFromFile(const std::string &file_name); + + graphStatus LoadFromFile(const char *file_name); + + ATTRIBUTED_DEPRECATED(graphStatus GetName(AscendString &) const) + const std::string &GetName() const; + + graphStatus GetName(AscendString &name) const; + + /// + /// Set is need train iteration. + /// If set true, it means this graph need to be run iteration some + /// times(according variant "npu_runconfig/iterations_per_loop"). + /// @param need_iteration need_iteration:whether to set iteration or not + /// + void SetNeedIteration(bool need_iteration); + + std::vector GetAllNodes() const; + + std::vector GetDirectNode () const; + + graphStatus RemoveNode(GNode &node); + + graphStatus RemoveEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node, const int32_t dst_port_index); + + GNode AddNodeByOp(const Operator &op); + + graphStatus AddDataEdge(GNode &src_node, const int32_t src_port_index, + GNode &dst_node, const int32_t dst_port_index); + + graphStatus AddControlEdge(GNode &src_node, GNode &dst_node); + + static GraphPtr ConstructFromInputs(const std::vector &inputs, const AscendString &name); + + private: + + GraphImplPtr impl_{nullptr}; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_GRAPH_H_ diff --git a/metadef/inc/external/graph/inference_context.h b/metadef/inc/external/graph/inference_context.h new file mode 100644 index 00000000..7c2cac2e --- /dev/null +++ b/metadef/inc/external/graph/inference_context.h @@ -0,0 +1,82 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ +#define INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ + +#include +#include +#include + +#include "./tensor.h" +#include "./types.h" +#include "ascend_string.h" + +namespace ge { +class InferenceContext; +using InferenceContextPtr = std::shared_ptr; + +class ShapeAndTypeImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ShapeAndType { + public: + ShapeAndType(); + ~ShapeAndType() = default; + + ShapeAndType(const Shape &shape, DataType dataType); + + void SetShape(const Shape &shape); + + void SetType(DataType dataType); + + Shape GetShape() const; + + DataType GetDataType() const; + + private: + std::shared_ptr shape_and_type_impl_; +}; + +class InferenceContextImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { + public: + ~InferenceContext() = default; + InferenceContext(const InferenceContext &context) = delete; + InferenceContext(const InferenceContext &&context) = delete; + InferenceContext &operator=(const InferenceContext &context) = delete; + InferenceContext &operator=(const InferenceContext &&context) = delete; + + void SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types); + const std::vector> &GetInputHandleShapesAndTypes() const; + const std::vector> &GetOutputHandleShapesAndTypes() const; + void SetOutputHandleShapesAndTypes(const std::vector> &shapes_and_types); + void SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types); + + ATTRIBUTED_DEPRECATED(void SetMarks(const std::vector &)) + void SetMarks(const std::vector &marks); + void SetMarks(const std::vector &marks); + + ATTRIBUTED_DEPRECATED(void GetMarks(std::vector &) const) + const std::vector &GetMarks() const; + void GetMarks(std::vector &marks) const; + + static std::unique_ptr Create(); + + private: + explicit InferenceContext(std::unique_ptr &impl); + std::shared_ptr inference_context_impl_; +}; +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ diff --git a/metadef/inc/external/graph/operator.h b/metadef/inc/external/graph/operator.h new file mode 100644 index 00000000..81c55757 --- /dev/null +++ b/metadef/inc/external/graph/operator.h @@ -0,0 +1,455 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_OPERATOR_H_ +#define INC_EXTERNAL_GRAPH_OPERATOR_H_ + +#include +#include +#include +#include +#include + +#include "./ge_error_codes.h" +#include "./inference_context.h" +#include "./tensor.h" + +#ifndef USER_GE_LOGI +#define USER_GE_LOGI(...) +#endif // USER_GE_LOGI + +#ifndef USER_GE_LOGW +#define USER_GE_LOGW(...) +#endif // USER_GE_LOGW + +#ifndef USER_GE_LOGE +#define USER_GE_LOGE(...) +#endif // USER_GE_LOGE + +#define DYNAMIC_OUTPUT_TD_NUM(name) ("__dynamic_output_" + name + "_cnt") +#define DYNAMIC_INPUT_TD_NUM(name) ("__dynamic_input_" + name + "_cnt") + +namespace ge { +class Operator; +class OperatorImpl; +class NodeUtils; +class NamedAttrs; +class Graph; +class AttrValue; +class Node; + +using SubgraphBuilder = std::function; +using OperatorImplPtr = std::shared_ptr; +using OperatorPtr = std::shared_ptr; + +class OpIO; +using OutHandler = std::shared_ptr; +using InHandler = std::shared_ptr; + +using std::function; +using std::shared_ptr; +using std::string; + +/*lint -e148*/ +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { + public: + friend class OperatorImpl; + friend class GraphBuilderImpl; + friend class NodeUtils; + + using OpInt = int64_t; + using OpFloat = float; + using OpString = string; + using OpAscendString = AscendString; + using OpBool = bool; + using OpTensor = Tensor; + using OpType = ge::DataType; + using OpNamedAttrs = ge::NamedAttrs; + using OpListInt = std::vector; + using OpListFloat = std::vector; + using OpListString = std::vector; + using OpListAcendString = std::vector; + using OpListBool = std::vector; + using OpListTensor = std::vector; + using OpBytes = std::vector; + using OpListListInt = std::vector>; + using OpListType = std::vector; + using OpListNamedAttrs = std::vector; + + Operator() {} + ATTRIBUTED_DEPRECATED(Operator(const char *)) + explicit Operator(const string &type); + + explicit Operator(const char *type); + + ATTRIBUTED_DEPRECATED(Operator(const char *, const char *)) + Operator(const string &name, const string &type); // lint !e148 + + Operator(const AscendString &name, const AscendString &type); + + Operator(const char *name, const char *type); + + virtual ~Operator() = default; + + bool IsEmpty() const; + + ATTRIBUTED_DEPRECATED(graphStatus GetName(AscendString &) const) + string GetName() const; + + graphStatus GetName(AscendString &name) const; + + ATTRIBUTED_DEPRECATED(graphStatus GetOpType(AscendString &) const) + string GetOpType() const; + + graphStatus GetOpType(AscendString &type) const; + + // Only has one output index = 0 + ATTRIBUTED_DEPRECATED(Operator &SetInput(const char *, const Operator &)) + Operator &SetInput(const string &dst_name, const Operator &src_oprt); + + Operator &SetInput(const char *dst_name, const Operator &src_oprt); + + ATTRIBUTED_DEPRECATED(Operator &SetInput(const char *, const Operator &, const char *)) + Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148 + + Operator &SetInput(const char *dst_name, const Operator &src_oprt, const char *name); + + ATTRIBUTED_DEPRECATED(Operator &SetInput(const char *, const Operator &, uint32_t)) + Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index); + + Operator &SetInput(const char *dst_name, const Operator &src_oprt, uint32_t index); + + Operator &AddControlInput(const Operator &src_oprt); + + ATTRIBUTED_DEPRECATED(graphStatus GetInputConstData(const char *, Tensor &) const) + graphStatus GetInputConstData(const string &dst_name, Tensor &data) const; + + graphStatus GetInputConstData(const char *dst_name, Tensor &data) const; + + ATTRIBUTED_DEPRECATED(TensorDesc GetInputDescByName(const char *) const) + TensorDesc GetInputDesc(const string &name) const; + + TensorDesc GetInputDescByName(const char *name) const; + + TensorDesc GetInputDesc(uint32_t index) const; + + ATTRIBUTED_DEPRECATED(int GetDynamicOutputNum(const char *) const) + int GetDynamicOutputNum(const string &name) const; + + int GetDynamicOutputNum(const char *name) const; + + ATTRIBUTED_DEPRECATED(int GetDynamicInputNum(const char *)) + int GetDynamicInputNum(const string &name) const; + + int GetDynamicInputNum(const char *name) const; + + ATTRIBUTED_DEPRECATED(graphStatus TryGetInputDesc(const char *, TensorDesc &) const) + graphStatus TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const; + + graphStatus TryGetInputDesc(const char *name, TensorDesc &tensor_desc) const; + + ATTRIBUTED_DEPRECATED(graphStatus UpdateInputDesc(const char *, const TensorDesc &)) + graphStatus UpdateInputDesc(const string &name, const TensorDesc &tensor_desc); + + graphStatus UpdateInputDesc(const char *name, const TensorDesc &tensor_desc); + + ATTRIBUTED_DEPRECATED(TensorDesc GetOutputDescByName(const char *) const) + TensorDesc GetOutputDesc(const string &name) const; + + TensorDesc GetOutputDescByName(const char *name) const; + + TensorDesc GetOutputDesc(uint32_t index) const; + + ATTRIBUTED_DEPRECATED(graphStatus UpdateOutputDesc(const char *, const TensorDesc &tensor_desc)) + graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148 + + graphStatus UpdateOutputDesc(const char *name, const TensorDesc &tensor_desc); + + ATTRIBUTED_DEPRECATED(TensorDesc GetDynamicInputDesc(const char *, uint32_t) const) + TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; + + TensorDesc GetDynamicInputDesc(const char *name, uint32_t index) const; + + ATTRIBUTED_DEPRECATED(graphStatus UpdateDynamicInputDesc(const char *, uint32_t, const TensorDesc &)) + graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 + + graphStatus UpdateDynamicInputDesc(const char *name, uint32_t index, const TensorDesc &tensor_desc); + + ATTRIBUTED_DEPRECATED(TensorDesc GetDynamicOutputDesc(const char *, uint32_t) const) + TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; + + TensorDesc GetDynamicOutputDesc(const char *name, uint32_t index) const; + + ATTRIBUTED_DEPRECATED(graphStatus UpdateDynamicOutputDesc(const char *, uint32_t, const TensorDesc &)) + graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 + + graphStatus UpdateDynamicOutputDesc(const char *name, uint32_t index, const TensorDesc &tensor_desc); + + graphStatus InferShapeAndType(); // lint !e148 + + void SetInferenceContext(const InferenceContextPtr &inference_context); + InferenceContextPtr GetInferenceContext() const; + + graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148 + + size_t GetInputsSize() const; + + size_t GetOutputsSize() const; + + ATTRIBUTED_DEPRECATED(graphStatus GetAllAttrNamesAndTypes(std::map &) const) + const std::map GetAllAttrNamesAndTypes() const; + + graphStatus GetAllAttrNamesAndTypes(std::map &attr_name_types) const; + + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, int64_t)) + Operator &SetAttr(const string &name, int64_t attr_value); + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, int32_t)) + Operator &SetAttr(const string &name, int32_t attr_value); + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, uint32_t)) + Operator &SetAttr(const string &name, uint32_t attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, int64_t &) const) + graphStatus GetAttr(const string &name, int64_t &attr_value) const; + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, int32_t &) const) + graphStatus GetAttr(const string &name, int32_t &attr_value) const; + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, uint32_t &) const) + graphStatus GetAttr(const string &name, uint32_t &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) + Operator &SetAttr(const string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) + Operator &SetAttr(const string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) + Operator &SetAttr(const string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, std::initializer_list &&)) + Operator &SetAttr(const string &name, std::initializer_list &&attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *name, std::vector &) const) + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *name, std::vector &) const) + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const string &, std::vector &) const) + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, float attr_value)) + Operator &SetAttr(const string &name, float attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, float &) const) + graphStatus GetAttr(const string &name, float &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) + Operator &SetAttr(const string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector &) const) + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, AttrValue &&)) + Operator &SetAttr(const string &name, AttrValue &&attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, AttrValue &) const) + graphStatus GetAttr(const string &name, AttrValue &attr_value) const; + Operator &SetAttr(const string &name, const string &attr_value); + graphStatus GetAttr(const string &name, string &attr_value) const; + Operator &SetAttr(const string &name, const std::vector &attr_value); + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, bool)) + Operator &SetAttr(const string &name, bool attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, bool &) const) + graphStatus GetAttr(const string &name, bool &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) + Operator &SetAttr(const string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector &) const) + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const Tensor &)) + Operator &SetAttr(const string &name, const Tensor &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, Tensor &) const) + graphStatus GetAttr(const string &name, Tensor &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) + Operator &SetAttr(const string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector &) const) + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + + // Bytes type + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const OpBytes &)) + Operator &SetAttr(const string &name, const OpBytes &attr_value); + // Bytes type + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, OpBytes &) const) + graphStatus GetAttr(const string &name, OpBytes &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector> &)) + Operator &SetAttr(const string &name, const std::vector> &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector> &) const) + graphStatus GetAttr(const string &name, std::vector> &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) + Operator &SetAttr(const string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector &) const) + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const ge::DataType &)) + Operator &SetAttr(const string &name, const ge::DataType &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, ge::DataType &) const) + graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; + + // func type + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const ge::NamedAttrs &)) + Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, ge::NamedAttrs &) const) + graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) + Operator &SetAttr(const string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector &) const) + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + + Operator &SetAttr(const char *name, int64_t attr_value); + Operator &SetAttr(const char *name, int32_t attr_value); + Operator &SetAttr(const char *name, uint32_t attr_value); + graphStatus GetAttr(const char *name, int64_t &attr_value) const; + graphStatus GetAttr(const char *name, int32_t &attr_value) const; + graphStatus GetAttr(const char *name, uint32_t &attr_value) const; + Operator &SetAttr(const char *name, const std::vector &attr_value); + Operator &SetAttr(const char *name, const std::vector &attr_value); + Operator &SetAttr(const char *name, const std::vector &attr_value); + Operator &SetAttr(const char *name, std::initializer_list &&attr_value); + graphStatus GetAttr(const char *name, std::vector &attr_value) const; + graphStatus GetAttr(const char *name, std::vector &attr_value) const; + graphStatus GetAttr(const char *name, std::vector &attr_value) const; + + Operator &SetAttr(const char *name, float attr_value); + graphStatus GetAttr(const char *name, float &attr_value) const; + Operator &SetAttr(const char *name, const std::vector &attr_value); + graphStatus GetAttr(const char *name, std::vector &attr_value) const; + Operator &SetAttr(const char *name, AttrValue &&attr_value); + graphStatus GetAttr(const char *name, AttrValue &attr_value) const; + + Operator &SetAttr(const char *name, const char *attr_value); + Operator &SetAttr(const char *name, const AscendString &attr_value); + graphStatus GetAttr(const char *name, AscendString &attr_value) const; + Operator &SetAttr(const char *name, const std::vector &attr_values); + graphStatus GetAttr(const char *name, std::vector &attr_values) const; + + Operator &SetAttr(const char *name, bool attr_value); + graphStatus GetAttr(const char *name, bool &attr_value) const; + Operator &SetAttr(const char *name, const std::vector &attr_value); + graphStatus GetAttr(const char *name, std::vector &attr_value) const; + + Operator &SetAttr(const char *name, const Tensor &attr_value); + graphStatus GetAttr(const char *name, Tensor &attr_value) const; + Operator &SetAttr(const char *name, const std::vector &attr_value); + graphStatus GetAttr(const char *name, std::vector &attr_value) const; + + // Bytes type + Operator &SetAttr(const char *name, const OpBytes &attr_value); + // Bytes type + graphStatus GetAttr(const char *name, OpBytes &attr_value) const; + + Operator &SetAttr(const char *name, const std::vector> &attr_value); + graphStatus GetAttr(const char *name, std::vector> &attr_value) const; + + Operator &SetAttr(const char *name, const std::vector &attr_value); + graphStatus GetAttr(const char *name, std::vector &attr_value) const; + + Operator &SetAttr(const char *name, const ge::DataType &attr_value); + graphStatus GetAttr(const char *name, ge::DataType &attr_value) const; + + // func type + Operator &SetAttr(const char *name, const ge::NamedAttrs &attr_value); + graphStatus GetAttr(const char *name, ge::NamedAttrs &attr_value) const; + Operator &SetAttr(const char *name, const std::vector &attr_value); + graphStatus GetAttr(const char *name, std::vector &attr_value) const; + + void BreakConnect() const; + + size_t GetSubgraphNamesCount() const; + ATTRIBUTED_DEPRECATED(graphStatus GetSubgraphNames(std::vector &) const) + std::vector GetSubgraphNames() const; + graphStatus GetSubgraphNames(std::vector &names) const; + ATTRIBUTED_DEPRECATED(SubgraphBuilder GetSubgraphBuilder(const char *) const) + SubgraphBuilder GetSubgraphBuilder(const string &name) const; + SubgraphBuilder GetSubgraphBuilder(const char *name) const; + ATTRIBUTED_DEPRECATED(Graph GetSubgraph(const char *) const) + Graph GetSubgraph(const string &name) const; + Graph GetSubgraph(const char *name) const; + ATTRIBUTED_DEPRECATED(SubgraphBuilder GetDynamicSubgraphBuilder(const char *, uint32_t) const) + SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; + SubgraphBuilder GetDynamicSubgraphBuilder(const char *name, uint32_t index) const; + ATTRIBUTED_DEPRECATED(Graph GetDynamicSubgraph(const char *, uint32_t) const) + Graph GetDynamicSubgraph(const string &name, uint32_t index) const; + Graph GetDynamicSubgraph(const char *name, uint32_t index) const; + + protected: + void AttrRegister(const string &name, float attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, int64_t attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, const string &attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, bool attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, const Tensor &attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, const OpBytes &attr_value); + void AttrRegister(const string &name, const std::vector> &attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, const ge::DataType &attr_value); + void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + void AttrRegister(const string &name, const AscendString &attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); + + explicit Operator(OperatorImplPtr &&op_impl); + + void InputRegister(const string &name); + + void OptionalInputRegister(const string &name); + + void InferFuncRegister(const std::function &func); + + void VerifierFuncRegister(const std::function &func); + + void InferFormatFuncRegister(const std::function &func); + + void OutputRegister(const string &name); + + void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); + + void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); + + void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); + + void RequiredAttrRegister(const string &name); + + graphStatus VerifyAll(); // lint !e148 + + // Only has one output index = 0 + Operator &SetInput(const string &dst_name, uint32_t dst_index, + const Operator &src_oprt); + + Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, + const string &name); // lint !e148 + + void SubgraphRegister(const string &ir_name, bool dynamic); + void SubgraphCountRegister(const string &ir_name, uint32_t count); + void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); + Graph GetSubgraphImpl(const string &name) const; + + private: + Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148 + + OutHandler GetOutput(const string &name) const; + + OutHandler GetOutput(uint32_t index) const; + + OperatorImplPtr GetOperatorImplPtr() const; + + OperatorImplPtr operator_impl_{nullptr}; + + graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const; + + std::shared_ptr GetNode() const; +}; +/*lint +e148*/ +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_OPERATOR_H_ diff --git a/metadef/inc/external/graph/operator_factory.h b/metadef/inc/external/graph/operator_factory.h new file mode 100644 index 00000000..82326572 --- /dev/null +++ b/metadef/inc/external/graph/operator_factory.h @@ -0,0 +1,86 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ +#define INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ + +#include +#include +#include +#include + +#include "./operator.h" +#include "./ge_error_codes.h" + +namespace ge { +using OpCreator = std::function; +using OpCreatorV2 = std::function; +using InferShapeFunc = std::function; +using InferFormatFunc = std::function; +using VerifyFunc = std::function; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactory { + public: + ATTRIBUTED_DEPRECATED(static Operator CreateOperator(const char *, const char *)) + static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); + + static Operator CreateOperator(const char *operator_name, const char *operator_type); + + ATTRIBUTED_DEPRECATED(graphStatus GetOpsTypeList(std::vector &)) + static graphStatus GetOpsTypeList(std::vector &all_ops); + + static graphStatus GetOpsTypeList(std::vector &all_ops); + + ATTRIBUTED_DEPRECATED(bool IsExistOp(const char *)) + static bool IsExistOp(const string &operator_type); + + static bool IsExistOp(const char *operator_type); +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorCreatorRegister { + public: + ATTRIBUTED_DEPRECATED(OperatorCreatorRegister(const char *, OpCreatorV2 const &)) + OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator); + OperatorCreatorRegister(const char *operator_type, OpCreatorV2 const &op_creator); + ~OperatorCreatorRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferShapeFuncRegister { + public: + ATTRIBUTED_DEPRECATED(InferShapeFuncRegister(const char *, const InferShapeFunc &)) + InferShapeFuncRegister(const std::string &operator_type, const InferShapeFunc &infer_shape_func); + InferShapeFuncRegister(const char *operator_type, const InferShapeFunc &infer_shape_func); + ~InferShapeFuncRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferFormatFuncRegister { + public: + ATTRIBUTED_DEPRECATED(InferFormatFuncRegister(const char *, const InferFormatFunc &)) + InferFormatFuncRegister(const std::string &operator_type, const InferFormatFunc &infer_format_func); + InferFormatFuncRegister(const char *operator_type, const InferFormatFunc &infer_format_func); + ~InferFormatFuncRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY VerifyFuncRegister { + public: + ATTRIBUTED_DEPRECATED(VerifyFuncRegister(const char *, const VerifyFunc &)) + VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func); + VerifyFuncRegister(const char *operator_type, const VerifyFunc &verify_func); + ~VerifyFuncRegister() = default; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ diff --git a/metadef/inc/external/graph/operator_reg.h b/metadef/inc/external/graph/operator_reg.h new file mode 100644 index 00000000..9887f8dc --- /dev/null +++ b/metadef/inc/external/graph/operator_reg.h @@ -0,0 +1,561 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ +#define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ + +#include +#include +#include +#include + +#include "graph/operator.h" +#include "graph/operator_factory.h" +#include "graph/tensor.h" +#include "graph/types.h" +#include "graph/graph.h" + +namespace ge { +using std::function; +using std::string; +using std::vector; + +#define ATTR_String(x, ...) \ + graphStatus get_attr_##x(AscendString &ret) const { \ + string ret_str = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + ret = AscendString(ret_str.c_str()); \ + } \ + return GRAPH_SUCCESS; \ + } \ + _THIS_TYPE &set_attr_##x(const char *v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function &v) { return *this; } + +#define ATTR_ListString(x, ...) \ + graphStatus get_attr_##x(vector &ret) const { \ + vector ret_strs = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + for (auto &ret_str : ret_strs) { \ + ret.emplace_back(ret_str.c_str()); \ + } \ + } \ + return GRAPH_SUCCESS; \ + } \ + _THIS_TYPE &set_attr_##x(const vector &v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function()> &v) { \ + return *this; } + +#define ATTR_AscendString(x, ...) \ + graphStatus get_attr_##x(AscendString &ret) const { \ + AscendString ret_str = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + ret = AscendString(ret_str.c_str()); \ + } \ + return GRAPH_SUCCESS; \ + } + +#define ATTR_ListAscendString(x, ...) \ + graphStatus get_attr_##x(vector &ret) const { \ + vector ret_strs = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + for (auto &ret_str : ret_strs) { \ + if (ret_str.GetString() != nullptr) { \ + ret.emplace_back(ret_str.GetString()); \ + } \ + } \ + } \ + return GRAPH_SUCCESS; \ + } + +#define ATTR_Int(x, ...) +#define ATTR_Float(x, ...) +#define ATTR_Bool(x, ...) +#define ATTR_Tensor(x, ...) +#define ATTR_Type(x, ...) +#define ATTR_NamedAttrs(x, ...) +#define ATTR_ListInt(x, ...) +#define ATTR_ListFloat(x, ...) +#define ATTR_ListBool(x, ...) +#define ATTR_ListTensor(x, ...) +#define ATTR_Bytes(x, ...) +#define ATTR_ListListInt(x, ...) +#define ATTR_ListType(x, ...) +#define ATTR_ListNamedAttrs(x, ...) + +#define REQUIRED_ATTR_String(x) \ + graphStatus get_attr_##x(AscendString &ret) const { \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return GRAPH_FAILED; \ + } \ + return GRAPH_SUCCESS; \ + } \ + _THIS_TYPE &set_attr_##x(const char *v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function &v) { return *this; } + +#define REQUIRED_ATTR_ListString(x) \ + graphStatus get_attr_##x(vector &ret) const { \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return GRAPH_FAILED; \ + } \ + return GRAPH_SUCCESS; \ + } \ + _THIS_TYPE &set_attr_##x(const vector &v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function()> &v) { \ + return *this; } + +#define REQUIRED_ATTR_AscendString(x) \ + graphStatus get_attr_##x(AscendString &ret) const { \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return GRAPH_FAILED \ + } \ + return GRAPH_SUCCESS; \ + } + +#define REQUIRED_ATTR_ListAscendString(x) \ + graphStatus get_attr_##x(vector &ret) const { \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return GRAPH_FAILED; \ + } \ + return GRAPH_SUCCESS; \ + } + +#define REQUIRED_ATTR_Int(x) +#define REQUIRED_ATTR_Float(x) +#define REQUIRED_ATTR_Bool(x) +#define REQUIRED_ATTR_Tensor(x) +#define REQUIRED_ATTR_Type(x) +#define REQUIRED_ATTR_NamedAttrs(x) +#define REQUIRED_ATTR_ListInt(x) +#define REQUIRED_ATTR_ListFloat(x) +#define REQUIRED_ATTR_ListBool(x) +#define REQUIRED_ATTR_ListTensor(x) +#define REQUIRED_ATTR_Bytes(x) +#define REQUIRED_ATTR_ListListInt(x) +#define REQUIRED_ATTR_ListType(x) +#define REQUIRED_ATTR_ListNamedAttrs(x) + +class OpReg { + public: + OpReg &N() { return *this; } + + OpReg &ATTR() { return *this; } + + OpReg &REQUIRED_ATTR() { return *this; } + + OpReg &INPUT() { return *this; } + + OpReg &OPTIONAL_INPUT() { return *this; } + + OpReg &OUTPUT() { return *this; } + + OpReg &GRAPH() { return *this; } + + OpReg &DYNAMIC_GRAPH() { return *this; } + + OpReg &INFER_SHAPE_AND_TYPE() { return *this; } +}; + +#define REG_OP(x) \ + namespace op { \ + class x : public Operator { \ + typedef x _THIS_TYPE; \ + \ + public: \ + ATTRIBUTED_DEPRECATED(x(const char *)) \ + explicit x(const string &name) : Operator(name.c_str(), #x) { __##x(); } \ + explicit x(const char *name) : Operator(name, #x) { __##x(); } \ + explicit x(const AscendString &name) : Operator(name, #x) { \ + __##x(); } \ + x() : Operator(#x) { __##x(); } \ + \ + private: \ + void __##x() { \ + OpReg() + +#define ATTR(x, Type, ...) \ + N(); \ + __attr_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_attr_##x(AscendString &)) \ + static const string name_attr_##x() { return #x; } \ + static const void name_attr_##x(AscendString &attr) { \ + attr = AscendString(#x); \ + } \ + ATTR_##Type(x, __VA_ARGS__) \ + Op##Type get_attr_##x() const { \ + Op##Type ret = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return ret; \ + } \ + return ret; \ + } \ + _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function &v) { return *this; } \ + \ + private: \ + void __attr_##x() { \ + Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ + string attr_name(#x); \ + (void)OpReg() + +#define REQUIRED_ATTR(x, Type) \ + N(); \ + __required_attr_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_attr_##x(AscendString &)) \ + static const string name_attr_##x() { return #x; } \ + static const void name_attr_##x(AscendString &attr_name) { \ + attr_name = AscendString(#x); \ + } \ + REQUIRED_ATTR_##Type(x) \ + Op##Type get_attr_##x() const { \ + Op##Type ret; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return ret; \ + } \ + return ret; \ + } \ + _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function &v) { return *this; } \ + \ + private: \ + void __required_attr_##x() { \ + Operator::RequiredAttrRegister(#x); \ + string attr_name(#x); \ + (void)OpReg() + +#define INPUT(x, t) \ + N(); \ + __input_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_in_##x(AscendString &)) \ + static const string name_in_##x() { return #x; } \ + static const void name_in_##x(AscendString &name) { \ + name = AscendString(#x); \ + } \ + ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_input_##x##_by_name(Operator &, const char *)) \ + _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \ + Operator::SetInput(#x, v, srcName.c_str()); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x##_by_name(Operator &v, const char *srcName) { \ + Operator::SetInput(#x, v, srcName); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ + Operator::SetInput(#x, v, index); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x(Operator &v) { \ + Operator::SetInput(#x, v); \ + return *this; \ + } \ + TensorDesc get_input_desc_##x() const { return Operator::GetInputDescByName(#x); } \ + graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ + return Operator::UpdateInputDesc(#x, tensorDesc); \ + } \ + \ + private: \ + void __input_##x() { \ + Operator::InputRegister(#x); \ + (void)OpReg() + +#define OPTIONAL_INPUT(x, t) \ + N(); \ + __optional_input_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_in_##x(AscendString &)) \ + static const string name_in_##x() { return #x; } \ + static const void name_in_##x(AscendString &name) { \ + name = AscendString(#x); \ + } \ + _THIS_TYPE &set_input_##x(Operator &v) { \ + Operator::SetInput(#x, v); \ + return *this; \ + } \ + ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_input_##x##_by_name(Operator &, const char *)) \ + _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \ + Operator::SetInput(#x, v, srcName.c_str()); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x##_by_name(Operator &v, const char *srcName) { \ + Operator::SetInput(#x, v, srcName); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ + Operator::SetInput(#x, v, index); \ + return *this; \ + } \ + TensorDesc get_input_desc_##x() const { return Operator::GetInputDescByName(#x); } \ + graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ + return Operator::UpdateInputDesc(#x, tensorDesc); \ + } \ + \ + private: \ + void __optional_input_##x() { \ + Operator::OptionalInputRegister(#x); \ + (void)OpReg() + +#define OUTPUT(x, t) \ + N(); \ + __out_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_out_##x(AscendString &)) \ + static const string name_out_##x() { return #x; } \ + static const void name_out_##x(AscendString &name) { \ + name = AscendString(#x); \ + } \ + TensorDesc get_output_desc_##x() const { return Operator::GetOutputDescByName(#x); } \ + graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \ + return Operator::UpdateOutputDesc(#x, tensorDesc); \ + } \ + \ + private: \ + void __out_##x() { \ + Operator::OutputRegister(#x); \ + (void)OpReg() + +#define DYNAMIC_INPUT(x, t) \ + N(); \ + __dy_input_##x(); \ + } \ + \ + public: \ + _THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \ + Operator::DynamicInputRegister(#x, num, isPushBack); \ + return *this; \ + } \ + _THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \ + Operator::DynamicInputRegisterByIndex(#x, num, index); \ + return *this; \ + } \ + TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { \ + return Operator::GetDynamicInputDesc(#x, index); \ + } \ + graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ + return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ + } \ + _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \ + Operator::SetInput(#x, dstIndex, v); \ + return *this; \ + } \ + ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_dynamic_input_##x(uint32_t, Operator &, const char *))\ + _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \ + Operator::SetInput(#x, dstIndex, v, srcName.c_str()); \ + return *this; \ + } \ + _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const char *srcName) { \ + Operator::SetInput(#x, dstIndex, v, srcName); \ + return *this; \ + } \ + \ + private: \ + void __dy_input_##x() { \ + Operator::DynamicInputRegister(#x, 0, true); \ + (void)OpReg() + +#define DYNAMIC_OUTPUT(x, t) \ + N(); \ + __dy_output_##x(); \ + } \ + \ + public: \ + _THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \ + Operator::DynamicOutputRegister(#x, num, isPushBack); \ + return *this; \ + } \ + TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { \ + return Operator::GetDynamicOutputDesc(#x, index); \ + } \ + graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ + return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ + } \ + \ + private: \ + void __dy_output_##x() { \ + Operator::DynamicOutputRegister(#x, 0, true); \ + (void)OpReg() + +#define GRAPH(x) \ + N(); \ + __graph_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_graph_##x(AscendString &)) \ + static const string name_graph_##x() { return #x; } \ + static const void name_graph_##x(AscendString &name) { \ + name = AscendString(#x); \ + } \ + SubgraphBuilder get_subgraph_builder_##x() const { \ + return Operator::GetSubgraphBuilder(#x); \ + } \ + _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ + Operator::SetSubgraphBuilder(#x, 0, v); \ + return *this; \ + } \ + Graph get_subgraph_##x() const { \ + return Operator::GetSubgraph(#x); \ + } \ + \ + private: \ + void __graph_##x() { \ + Operator::SubgraphRegister(#x, false); \ + Operator::SubgraphCountRegister(#x, 1); \ + (void)OpReg() + +#define DYNAMIC_GRAPH(x) \ + N(); \ + __graph_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_graph_##x(AscendString &)) \ + static const string name_graph_##x() { return #x; } \ + static const void name_graph_##x(AscendString &name) { \ + name = AscendString(#x); \ + } \ + _THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \ + Operator::SubgraphCountRegister(#x, num); \ + return *this; \ + } \ + SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \ + return Operator::GetDynamicSubgraphBuilder(#x, index); \ + } \ + Graph get_dynamic_subgraph_##x(uint32_t index) const { \ + return Operator::GetDynamicSubgraph(#x, index); \ + } \ + _THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index,const SubgraphBuilder &v) { \ + Operator::SetSubgraphBuilder(#x, index, v); \ + return *this; \ + } \ + \ + private: \ + void __graph_##x() { \ + Operator::SubgraphRegister(#x, true); \ + (void)OpReg() + + +#define PASTE(g_register, y) g_register##y +#define __OP_END_IMPL__(x, y) \ + N(); \ + } \ + static_assert( \ + std::is_same::value, \ + "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ + } \ + ; \ + static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const AscendString &name) { return x(name); }); \ + } +#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) + +// Specialized shape inferencer macro + +#define IMPLEMT_INFERFUNC(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) + +#define IMPLEMT_COMMON_INFERFUNC(func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(Operator &op) + +#define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) + +// Specialized verifier macro + +#define IMPLEMT_VERIFIER(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op) + +#define INFER_VERIFY_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } + +#define COMMON_INFER_VERIFY_FUNC(x) [&](Operator &v) { return x(v); } + +#define INFER_FORMAT_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } + +#define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x) + +#define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x) +// Infer format func register +#define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \ + static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x) + +// Shape inferencer & verifier register macro + +#define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) + +#define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC(x), __COUNTER__) + +#define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) + +// Infer format func reg +#define INFER_FORMAT_FUNC_REG(op_name, x) \ + __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__) + +// Common shape inferencer + +#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ + [](Operator op)->graphStatus { \ + auto x_shape = op.GetInputDescByName(in_name).GetShape().GetDims(); \ + auto x_type = op.GetInputDescByName(in_name).GetDataType(); \ + TensorDesc op_output_desc = op.GetOutputDescByName(out_name); \ + op_output_desc.SetShape(ge::Shape(x_shape)); \ + op_output_desc.SetOriginShape(ge::Shape(x_shape)); \ + op_output_desc.SetDataType(x_type); \ + return op.UpdateOutputDesc(out_name, op_output_desc); \ + } + +graphStatus BroadCastInfer(const function()> &get_in1_shape, + const function()> &get_in2_shape, + const function &y_shape)> &set_out_shape); + +#define BROADCAST_INFER(in1_name, in2_name, out_name) \ + [](Operator op) -> graphStatus { \ + return BroadCastInfer([&]() { return op.GetInputDescByName(in1_name).GetShape().GetDims(); }, \ + [&]() { return op.GetInputDescByName(in2_name).GetShape().GetDims(); }, \ + [&](const vector &y_shape) { \ + TensorDesc op_output_desc = op.GetOutputDescByName(out_name); \ + op_output_desc.SetShape(ge::Shape(y_shape)); \ + (void)op.UpdateOutputDesc(out_name, op_output_desc);}); \ + } +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ diff --git a/metadef/inc/external/graph/tensor.h b/metadef/inc/external/graph/tensor.h new file mode 100644 index 00000000..e199fb6a --- /dev/null +++ b/metadef/inc/external/graph/tensor.h @@ -0,0 +1,140 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_TENSOR_H_ +#define INC_EXTERNAL_GRAPH_TENSOR_H_ + +#include +#include +#include +#include +#include + +#include "./ge_error_codes.h" +#include "./types.h" +#include "ascend_string.h" + +namespace ge { +class ShapeImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape { + public: + Shape(); + ~Shape() = default; + explicit Shape(const std::vector &dims); + + size_t GetDimNum() const; + // If the idx is invalid, return 0 + int64_t GetDim(size_t idx) const; + graphStatus SetDim(size_t idx, int64_t value); + std::vector GetDims() const; + int64_t GetShapeSize() const; + + private: + std::shared_ptr impl_; +}; + +class TensorDescImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { + public: + TensorDesc(); + ~TensorDesc() = default; + explicit TensorDesc(Shape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); + // Copy + TensorDesc(const TensorDesc &desc); + // Move + TensorDesc(TensorDesc &&desc); + // Copy + TensorDesc &operator=(const TensorDesc &desc); + // Move + TensorDesc &operator=(TensorDesc &&desc); + + void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); + Shape GetShape() const; + void SetShape(const Shape &shape); + // set shape with -2, it stand for unknown shape + graphStatus SetUnknownDimNumShape(); + // for unknown shape + graphStatus SetShapeRange(const std::vector> &range); + graphStatus GetShapeRange(std::vector> &range) const; + + Format GetFormat() const; + void SetFormat(Format format); + + Shape GetOriginShape() const; + void SetOriginShape(const Shape &originShape); + + Format GetOriginFormat() const; + void SetOriginFormat(Format originFormat); + + DataType GetDataType() const; + void SetDataType(DataType dt); + + ATTRIBUTED_DEPRECATED(graphStatus GetName(AscendString &)) + std::string GetName() const; + graphStatus GetName(AscendString &name); + + ATTRIBUTED_DEPRECATED(void SetName(const char *)) + void SetName(const std::string &name); + void SetName(const char *name); + + // Attr acess + void SetSize(int64_t size); + int64_t GetSize() const; + + int64_t GetRealDimCnt() const; + void SetRealDimCnt(const int64_t realDimCnt); + + private: + std::shared_ptr impl; +}; + +class TensorImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor { + public: + Tensor(); + ~Tensor() = default; + explicit Tensor(const TensorDesc &tensorDesc); + Tensor(const TensorDesc &tensorDesc, const std::vector &data); + Tensor(const TensorDesc &tensorDesc, const uint8_t *data, size_t size); + Tensor(TensorDesc &&tensorDesc, std::vector &&data); + + TensorDesc GetTensorDesc() const; + graphStatus SetTensorDesc(const TensorDesc &tensorDesc); + + const uint8_t *GetData() const; + uint8_t *GetData(); + size_t GetSize() const; + + graphStatus SetData(std::vector &&data); + graphStatus SetData(const std::vector &data); + graphStatus SetData(const uint8_t *data, size_t size); + ATTRIBUTED_DEPRECATED(graphStatus SetData(const char *data)) + graphStatus SetData(const std::string &data); + graphStatus SetData(const char *data); + ATTRIBUTED_DEPRECATED(graphStatus SetData(const std::vector &)) + graphStatus SetData(const std::vector &data); + graphStatus SetData(const std::vector &datas); + graphStatus IsValid(); + + Tensor Clone() const; + + private: + std::shared_ptr impl; + friend class TensorAdapter; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_TENSOR_H_ diff --git a/metadef/inc/external/graph/types.h b/metadef/inc/external/graph/types.h new file mode 100644 index 00000000..7bfe19d4 --- /dev/null +++ b/metadef/inc/external/graph/types.h @@ -0,0 +1,245 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GRAPH_TYPES_H_ +#define INC_EXTERNAL_GRAPH_TYPES_H_ + +#include +#include +#include + +namespace ge { +static const int64_t UNKNOWN_DIM = -1; +static const int64_t UNKNOWN_DIM_NUM = -2; +static const std::vector UNKNOWN_SHAPE = {-1}; +static const std::vector UNKNOWN_RANK = {-2}; + +#if(defined(HOST_VISIBILITY)) && (defined(__GNUC__)) +#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_HOST_VISIBILITY +#endif +#if(defined(DEV_VISIBILITY)) && (defined(__GNUC__)) +#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_DEV_VISIBILITY +#endif + +enum DataType { + DT_FLOAT = 0, // float type + DT_FLOAT16 = 1, // fp16 type + DT_INT8 = 2, // int8 type + DT_INT16 = 6, // int16 type + DT_UINT16 = 7, // uint16 type + DT_UINT8 = 4, // uint8 type + DT_INT32 = 3, // + DT_INT64 = 9, // int64 type + DT_UINT32 = 8, // unsigned int32 + DT_UINT64 = 10, // unsigned int64 + DT_BOOL = 12, // bool type + DT_DOUBLE = 11, // double type + DT_STRING = 13, // string type + DT_DUAL_SUB_INT8 = 14, // dual output int8 type + DT_DUAL_SUB_UINT8 = 15, // dual output uint8 type + DT_COMPLEX64 = 16, // complex64 type + DT_COMPLEX128 = 17, // complex128 type + DT_QINT8 = 18, // qint8 type + DT_QINT16 = 19, // qint16 type + DT_QINT32 = 20, // qint32 type + DT_QUINT8 = 21, // quint8 type + DT_QUINT16 = 22, // quint16 type + DT_RESOURCE = 23, // resource type + DT_STRING_REF = 24, // string ref type + DT_DUAL = 25, // dual output type + DT_UNDEFINED // Used to indicate a DataType field has not been set. +}; + +inline int GetSizeByDataType(DataType data_type) { + static int data_type_size[DT_UNDEFINED] = { + 4, // DT_FLOAT = 0, float type + 2, // DT_FLOAT16 = 1, fp16 type + 1, // DT_INT8 = 2, int8 type + 4, // DT_INT32 = 3, + 1, // DT_UINT8 = 4, uint8 type + -1, + 2, // DT_INT16 = 6, int16 type + 2, // DT_UINT16 = 7, uint16 type + 4, // DT_UINT32 = 8, unsigned int32 + 8, // DT_INT64 = 9, int64 type + 8, // DT_UINT64 = 10, unsigned int64 + 8, // DT_DOUBLE = 11, double type + 1, // DT_BOOL = 12, bool type + -1, // DT_STRING = 13, string type + 1, // DT_DUAL_SUB_INT8 = 14, dual output int8 type + 1, // DT_DUAL_SUB_UINT8 = 15, dual output uint8 type + 8, // DT_COMPLEX64 = 16, complex64 type + 16, // DT_COMPLEX128 = 17, complex128 type + 1, // DT_QINT8 = 18, qint8 type + 2, // DT_QINT16 = 19, qint16 type + 4, // DT_QINT32 = 20, qint32 type + 1, // DT_QUINT8 = 21, quint8 type + 2, // DT_QUINT16 = 22, quint16 type + -1, // DT_RESOURCE = 23, resource type + -1, // DT_STRING_REF = 24, string ref type + 5, // DT_DUAL = 25, dual output type (float + int8) + // DT_UNDEFINED Used to indicate a DataType field has not been set. + }; + if (data_type >= DT_UNDEFINED) { + return -1; + } + return data_type_size[data_type]; +} + +enum Format { + FORMAT_NCHW = 0, // NCHW + FORMAT_NHWC, // NHWC + FORMAT_ND, // Nd Tensor + FORMAT_NC1HWC0, // NC1HWC0 + FORMAT_FRACTAL_Z, // FRACTAL_Z + FORMAT_NC1C0HWPAD, + FORMAT_NHWC1C0, + FORMAT_FSR_NCHW, + FORMAT_FRACTAL_DECONV, + FORMAT_C1HWNC0, + FORMAT_FRACTAL_DECONV_TRANSPOSE, + FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, + FORMAT_NC1HWC0_C04, // NC1HWC0, C0 is 4 + FORMAT_FRACTAL_Z_C04, // FRACZ, C0 is 4 + FORMAT_CHWN, + FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, + FORMAT_HWCN, + FORMAT_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format + FORMAT_BN_WEIGHT, + FORMAT_FILTER_HWCK, // filter input tensor format + FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20, + FORMAT_HASHTABLE_LOOKUP_KEYS, + FORMAT_HASHTABLE_LOOKUP_VALUE, + FORMAT_HASHTABLE_LOOKUP_OUTPUT, + FORMAT_HASHTABLE_LOOKUP_HITS = 24, + FORMAT_C1HWNCoC0, + FORMAT_MD, + FORMAT_NDHWC, + FORMAT_FRACTAL_ZZ, + FORMAT_FRACTAL_NZ, + FORMAT_NCDHW, + FORMAT_DHWCN, // 3D filter input tensor format + FORMAT_NDC1HWC0, + FORMAT_FRACTAL_Z_3D, + FORMAT_CN, + FORMAT_NC, + FORMAT_DHWNC, + FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format + FORMAT_FRACTAL_ZN_LSTM, + FORMAT_FRACTAL_Z_G, + FORMAT_RESERVED, + FORMAT_ALL, + FORMAT_NULL +}; + +// for unknown shape op type +enum UnknowShapeOpType { + DEPEND_IN_SHAPE = 1, // op out shape get by input shape + DEPEND_CONST_VALUE = 2, // op out shape get by const op value + DEPEND_SHAPE_RANGE = 3, // op out shape get by range + DEPEND_COMPUTE = 4 // op out shape get by totally computing +}; + +struct TensorDescInfo { + Format format_ = FORMAT_RESERVED; // tbe op register support format + DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype +}; + +enum DeviceType { + NPU = 0, + CPU = 1, +}; + +class TensorTypeImpl; +struct TensorType { + explicit TensorType(DataType dt); + + TensorType(const std::initializer_list &types); + + static TensorType ALL() { + return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, + DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, + DT_QUINT8, DT_RESOURCE, DT_STRING, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType QuantifiedType() { return TensorType{DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, DT_QUINT8}; } + + static TensorType OrdinaryType() { + return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, + DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType BasicType() { + return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, + DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, + DT_QUINT16, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType NumberType() { + return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, + DT_INT8, DT_QINT32, DT_QINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType RealNumberType() { + return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, + DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType ComplexDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64}; } + + static TensorType IntegerDataType() { + return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType SignedDataType() { return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8}; } + + static TensorType UnsignedDataType() { return TensorType{DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; } + + static TensorType FloatingDataType() { return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } + + static TensorType IndexNumberType() { return TensorType{DT_INT32, DT_INT64}; } + + static TensorType UnaryDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } + + static TensorType FLOAT() { return TensorType{DT_FLOAT, DT_FLOAT16}; } + + std::shared_ptr tensor_type_impl_; +}; + +enum MEMORY_SIZE_CALC_TYPE { + MEMORY_SIZE_NORMAL = 0, + ALWAYS_EMPTY +}; +} // namespace ge + +namespace domi { +enum class ImplyType : unsigned int { + BUILDIN = 0, // Built in operator, normally executed by OME + TVM, // Compile to TVM bin file for execution + CUSTOM, // User defined calculation logic, executed by CPU + AI_CPU, // AICPU + CCE, // Cce + GELOCAL, // GE local, do node need execute by device + HCCL, // Hccl + INVALID = 0xFFFFFFFF, +}; +} // namespace domi + +#endif // INC_EXTERNAL_GRAPH_TYPES_H_ diff --git a/metadef/inc/external/register/op_tiling_registry.h b/metadef/inc/external/register/op_tiling_registry.h new file mode 100644 index 00000000..f3e6f8f7 --- /dev/null +++ b/metadef/inc/external/register/op_tiling_registry.h @@ -0,0 +1,111 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_OP_TILING_REGISTRY_H_ +#define INC_REGISTER_OP_TILING_REGISTRY_H_ + +#include +#include +#include +#include +#include +#include "external/register/register_types.h" +#include "external/graph/tensor.h" + +#define REGISTER_OP_TILING(optype, opfunc) REGISTER_OP_TILING_UNIQ_HELPER(optype, opfunc, __COUNTER__) + +#define REGISTER_OP_TILING_UNIQ_HELPER(optype, opfunc, counter) REGISTER_OP_TILING_UNIQ(optype, opfunc, counter) + +#define REGISTER_OP_TILING_UNIQ(optype, opfunc, counter) \ + static OpTilingRegistryInterf g_##optype##TilingRegistryInterf##counter(#optype, opfunc) + +namespace optiling { + +extern thread_local int64_t last_op_tiling_perf; + +enum TensorArgType { + TA_NONE, + TA_SINGLE, + TA_LIST, +}; + +using ByteBuffer = std::stringstream; + +struct TeOpTensor { + std::vector shape; + std::vector ori_shape; + std::string format; + std::string ori_format; + std::string dtype; + std::map attrs; +}; + +struct TeOpTensorArg { + TensorArgType arg_type; + std::vector tensor; +}; + +struct OpRunInfo { + uint32_t block_dim; + std::vector workspaces; + ByteBuffer tiling_data; + bool clear_atomic; +}; + +using TeOpAttrArgs = std::vector; +using TeConstTensorData = std::tuple; + +struct TeOpParas { + std::vector inputs; + std::vector outputs; + std::map const_inputs; + TeOpAttrArgs attrs; + std::string op_type; +}; + +struct OpCompileInfo { + std::string str; + std::string key; +}; + +using OpTilingFunc = std::function; + +using OpTilingFuncPtr = bool (*)(const TeOpParas &, const OpCompileInfo &, OpRunInfo &); + +class FMK_FUNC_HOST_VISIBILITY OpTilingRegistryInterf { + public: + OpTilingRegistryInterf(std::string op_type, OpTilingFunc func); + ~OpTilingRegistryInterf() = default; + static std::map &RegisteredOpInterf(); +}; + +template +ByteBuffer &ByteBufferPut(ByteBuffer &buf, const T &value) { + buf.write(reinterpret_cast(&value), sizeof(value)); + buf.flush(); + return buf; +} + +template +ByteBuffer &ByteBufferGet(ByteBuffer &buf, T &value) { + buf.read(reinterpret_cast(&value), sizeof(value)); + return buf; +} + +size_t ByteBufferGetAll(ByteBuffer &buf, char *dest, size_t dest_len); +} // namespace optiling + +#endif // INC_REGISTER_OP_TILING_REGISTRY_H_ diff --git a/metadef/inc/external/register/register.h b/metadef/inc/external/register/register.h new file mode 100644 index 00000000..0de7e051 --- /dev/null +++ b/metadef/inc/external/register/register.h @@ -0,0 +1,201 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "graph/operator.h" +#include "register/register_error_codes.h" +#include "register/register_fmk_types.h" +#include "register/register_types.h" + +using std::unique_ptr; +using std::map; +using std::make_shared; +using std::to_string; +using std::string; +using std::pair; +using std::vector; + +/*lint -e148*/ +namespace ge { +class Operator; +class TensorDesc; +class Tensor; +class TBEPluginManager; +} + +namespace google { +namespace protobuf { +class Message; +} +} + +namespace domi { +const int64_t kMaxNameLength = 1048576; // 1M + +enum DynamicType { + kInvalid = 0, + kInput = 1, + kOutput = 2 +}; +struct DynamicInputOutputInfo { + DynamicType type; // input/output + const char *port_name; + int64_t port_name_len; + const char *attr_name; + int64_t attr_name_len; + DynamicInputOutputInfo() + : type(kInvalid), port_name(nullptr), port_name_len(0), attr_name(nullptr), attr_name_len(0) {} + DynamicInputOutputInfo(DynamicType type, const char *port_name, int64_t port_name_len, const char *attr_name, + int64_t attr_name_len) + : type(type), + port_name(port_name), + port_name_len(port_name_len), + attr_name(attr_name), + attr_name_len(attr_name_len) {} +}; +Status AutoMappingByOpFn(const ge::Operator &op_src, ge::Operator &op); +Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op, + const vector &dynamic_name_attr_value); +ATTRIBUTED_DEPRECATED(Status AutoMappingByOpFn(const ge::Operator &, ge::Operator &)) +Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); +ATTRIBUTED_DEPRECATED(Status AutoMappingByOpFnDynamic(const ge::Operator &, ge::Operator &, + const vector &)) +Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, + std::map> dynamic_name_attr_value, + int in_pos = -1, int out_pos = -1); +Status AutoMappingSubgraphIndex(const ge::Graph &graph, + const std::function &input, + const std::function &output); +Status AutoMappingSubgraphIndex(const ge::Graph &graph, + const std::function &input, + const std::function &output); +using google::protobuf::Message; +class OpRegistrationDataImpl; + +using ParseParamFunc = std::function; +using ParseParamByOpFunc = std::function; +using FusionParseParamFunc = std::function, + ge::Operator &)>; +using FusionParseParamByOpFunc = std::function &, ge::Operator &)>; +using ParseSubgraphFunc = std::function; +using ParseOpToGraphFunc = std::function; +using ParseSubgraphFuncV2 = std::function; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { + public: + ATTRIBUTED_DEPRECATED(OpRegistrationData(const char *)) + OpRegistrationData(const std::string &om_optype); + + OpRegistrationData(const char *om_optype); + + ~OpRegistrationData(); + + OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type); + + ATTRIBUTED_DEPRECATED(OpRegistrationData &OriginOpType(const std::vector &)) + OpRegistrationData &OriginOpType(const std::initializer_list &ori_optype_list); + + OpRegistrationData &OriginOpType(const std::vector &ori_op_type_list); + + ATTRIBUTED_DEPRECATED(OpRegistrationData &OriginOpType(const char *)) + OpRegistrationData &OriginOpType(const std::string &ori_optype); + + OpRegistrationData &OriginOpType(const char *ori_op_type); + + OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); + + OpRegistrationData &ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn); + + OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); + + OpRegistrationData &FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn); + + ATTRIBUTED_DEPRECATED(OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFuncV2 &)) + OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); + + OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFuncV2 &subgraph_post_fn); + + OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); + + ATTRIBUTED_DEPRECATED(OpRegistrationData &DelInputWithCond(int, const char *, bool)) + OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); + + OpRegistrationData &DelInputWithCond(int input_idx, const char *attr_name, bool attr_value); + + ATTRIBUTED_DEPRECATED(OpRegistrationData &DelInputWithOriginalType(int, const char *)) + OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); + + OpRegistrationData &DelInputWithOriginalType(int input_idx, const char *ori_type); + + OpRegistrationData &InputReorderVector(const vector &input_order); + + OpRegistrationData &ParseOpToGraphFn(const ParseOpToGraphFunc &parse_op_to_graph_fn); + + domi::ImplyType GetImplyType () const; + ATTRIBUTED_DEPRECATED(Status GetOmOptype(ge::AscendString &) const) + std::string GetOmOptype () const; + Status GetOmOptype(ge::AscendString &om_op_type) const; + ATTRIBUTED_DEPRECATED(GetOriginOpTypeSet(std::set &) const) + std::set GetOriginOpTypeSet () const; + Status GetOriginOpTypeSet(std::set &ori_op_type) const; + domi::FrameworkType GetFrameworkType() const; + ParseParamFunc GetParseParamFn() const; + ParseParamByOpFunc GetParseParamByOperatorFn() const; + FusionParseParamFunc GetFusionParseParamFn() const; + FusionParseParamByOpFunc GetFusionParseParamByOpFn() const; + ParseSubgraphFunc GetParseSubgraphPostFn() const; + ParseOpToGraphFunc GetParseOpToGraphFn() const; + Status GetParseSubgraphPostFn(ParseSubgraphFuncV2 &func) const; + + private: + std::shared_ptr impl_; + friend class OpRegistry; + friend class OpRegistrationTbe; + friend class ge::TBEPluginManager; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { + public: + OpReceiver(OpRegistrationData ®_data); + ~OpReceiver() {} +}; + +#define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name) +#define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name) +#define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \ + static OpReceiver register_op##ctr \ + __attribute__((unused)) = \ + OpRegistrationData(name) +} // namespace domi + +namespace ge { +using OpRegistrationData = domi::OpRegistrationData; +using OpReceiver = domi::OpReceiver; +} // namespace ge +/*lint +e148*/ +#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ diff --git a/metadef/inc/external/register/register_error_codes.h b/metadef/inc/external/register/register_error_codes.h new file mode 100644 index 00000000..a71bb72c --- /dev/null +++ b/metadef/inc/external/register/register_error_codes.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ + +#define SYSID_FWK 3 // Subsystem ID +#define MODID_COMMON 0 // Common module ID + +#define DECLARE_ERRORNO(sysid, modid, name, value) \ + const domi::Status name = \ + ((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); + +#define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) + +namespace domi { +using Status = uint32_t; + +// General error code +DECLARE_ERRORNO(0, 0, SUCCESS, 0); +DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); +DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 +DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201); +} // namespace domi + +#endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ diff --git a/metadef/inc/external/register/register_fmk_types.h b/metadef/inc/external/register/register_fmk_types.h new file mode 100644 index 00000000..97616060 --- /dev/null +++ b/metadef/inc/external/register/register_fmk_types.h @@ -0,0 +1,37 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ + +#include + +namespace domi { +/// +/// @ingroup domi_omg +/// @brief AI framework types +/// +enum FrameworkType { + CAFFE = 0, + MINDSPORE = 1, + TENSORFLOW = 3, + ANDROID_NN, + ONNX, + FRAMEWORK_RESERVED, +}; +} // namespace domi + +#endif // INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ diff --git a/metadef/inc/external/register/register_types.h b/metadef/inc/external/register/register_types.h new file mode 100644 index 00000000..54382672 --- /dev/null +++ b/metadef/inc/external/register/register_types.h @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ + +namespace domi { +#if(defined(HOST_VISIBILITY)) && (defined(__GNUC__)) +#define FMK_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define FMK_FUNC_HOST_VISIBILITY +#endif +#if(defined(DEV_VISIBILITY)) && (defined(__GNUC__)) +#define FMK_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define FMK_FUNC_DEV_VISIBILITY +#endif +#ifdef __GNUC__ +#define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead."))) +#else +#define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead.")) +#endif + +/// +/// @ingroup domi +/// @brief original tensor type +/// +typedef enum tagDomiTensorFormat { + DOMI_TENSOR_NCHW = 0, // < NCHW + DOMI_TENSOR_NHWC, // < NHWC + DOMI_TENSOR_ND, // < Nd Tensor + DOMI_TENSOR_NC1HWC0, // < NC1HWC0 + DOMI_TENSOR_FRACTAL_Z, // < FRACTAL_Z + DOMI_TENSOR_NC1C0HWPAD, + DOMI_TENSOR_NHWC1C0, + DOMI_TENSOR_FSR_NCHW, + DOMI_TENSOR_FRACTAL_DECONV, + DOMI_TENSOR_BN_WEIGHT, + DOMI_TENSOR_CHWN, // Android NN Depth CONV + DOMI_TENSOR_FILTER_HWCK, // filter input tensor format + DOMI_TENSOR_NDHWC, + DOMI_TENSOR_NCDHW, + DOMI_TENSOR_DHWCN, // 3D filter input tensor format + DOMI_TENSOR_DHWNC, + DOMI_TENSOR_RESERVED +} domiTensorFormat_t; +} // namespace domi + +#endif // INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ diff --git a/metadef/inc/external/register/scope/scope_fusion_pass_register.h b/metadef/inc/external/register/scope/scope_fusion_pass_register.h new file mode 100644 index 00000000..9df4dd84 --- /dev/null +++ b/metadef/inc/external/register/scope/scope_fusion_pass_register.h @@ -0,0 +1,404 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ +#define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ + +#include +#include +#include +#include +#include +#include "ge/ge_api_error_codes.h" +#include "register/register_error_codes.h" +#include "register/register_types.h" +#include "graph/operator.h" + +#define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \ + do { \ + if (!(cond)) { \ + if ((fusion_rlt) != nullptr) { \ + (fusion_rlt)->SetType(ge::kScopeInvalidType); \ + } \ + return; \ + } \ + } while (0) + +namespace domi { +class TensorFlowModelParser; +} // namespace domi +namespace ge { +const int32_t kFusionDisableIndex = 99999; +const char *const kScopeToMultiNodes = "ScopeToMultiNodes"; +const char *const kScopeInvalidType = "ScopeInvalidType"; +const char *const kInputFromFusionScope = "InputFromFusionScope"; +const char *const kOutputToFusionScope = "OutputToFusionScope"; +class ScopePattern; +using ScopeFusionPatterns = std::vector>; + +class ScopePassManager; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { + public: + Scope(); + ATTRIBUTED_DEPRECATED(Status Init(const char *, const char *, Scope *)) + Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); + Status Init(const char *name, const char *sub_type, Scope *father_scope = nullptr); + ~Scope(); + ATTRIBUTED_DEPRECATED(Status Name(AscendString &) const) + const std::string &Name() const; + Status Name(AscendString &name) const; + ATTRIBUTED_DEPRECATED(Status SubType(AscendString &) const) + const std::string &SubType() const; + Status SubType(AscendString &sub_type) const; + ATTRIBUTED_DEPRECATED(Status AllNodesMap(std::unordered_map &) const) + const std::unordered_map &AllNodesMap() const; + Status AllNodesMap(std::unordered_map &node_map) const; + ATTRIBUTED_DEPRECATED(Scope *GetSubScope(const char *scope_name) const) + Scope *GetSubScope(const std::string &scope_name) const; + Scope *GetSubScope(const char *scope_name) const; + ATTRIBUTED_DEPRECATED(Status LastName(AscendString &) const) + const std::string LastName() const; + Status LastName(AscendString &name) const; + const std::vector &GetAllSubScopes() const; + const Scope *GetFatherScope() const; + + private: + class ScopeImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; + friend class ScopeTree; + friend class NodeOpTypeFeature; + friend class NodeAttrFeature; + friend class ScopeFeature; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult { + public: + FusionScopesResult(); + Status Init(); + ~FusionScopesResult(); + ATTRIBUTED_DEPRECATED(void SetName(const char *)) + void SetName(const std::string &name); + void SetName(const char *name); + ATTRIBUTED_DEPRECATED(void SetType(const char *)) + void SetType(const std::string &type); + void SetType(const char *type); + ATTRIBUTED_DEPRECATED(void SetDescription(const char *)) + void SetDescription(const std::string &description); + void SetDescription(const char *description); + ATTRIBUTED_DEPRECATED(const Status Name(AscendString &) const) + const std::string &Name() const; + const Status Name(AscendString &name) const; + const std::vector &Nodes() const; + ATTRIBUTED_DEPRECATED(void InsertInputs(const char *, const std::vector &)) + void InsertInputs(const std::string &inner_op_name, const std::vector &index_map); + void InsertInputs(const char *inner_op_name, const std::vector &index_map); + ATTRIBUTED_DEPRECATED(void InsertOutputs(const char *, const std::vector &)) + void InsertOutputs(const std::string &inner_op_name, const std::vector &index_map); + void InsertOutputs(const char *inner_op_name, const std::vector &index_map); + + class InnerNodeInfo { + public: + ATTRIBUTED_DEPRECATED(InnerNodeInfo(const char *)) + explicit InnerNodeInfo(const std::string &fusion_node_name); + explicit InnerNodeInfo(const char *fusion_node_name); + ATTRIBUTED_DEPRECATED(InnerNodeInfo(const char *, const char *, const char *)) + InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type); + InnerNodeInfo(const char *fusion_node_name, const char *name, const char *type); + InnerNodeInfo(InnerNodeInfo &&other) noexcept; + InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept; + InnerNodeInfo(const InnerNodeInfo &) = delete; + InnerNodeInfo &operator=(const InnerNodeInfo &) = delete; + ~InnerNodeInfo(); + ATTRIBUTED_DEPRECATED(InnerNodeInfo &SetName(const char *)) + InnerNodeInfo &SetName(const std::string &name); + InnerNodeInfo &SetName(const char *name); + ATTRIBUTED_DEPRECATED(InnerNodeInfo &SetType(const char *)) + InnerNodeInfo &SetType(const std::string &type); + InnerNodeInfo &SetType(const char *type); + ATTRIBUTED_DEPRECATED(InnerNodeInfo &InsertInput(const char *, int32_t)) + InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx); + InnerNodeInfo &InsertInput(const char *input_node, int32_t peer_out_idx); + ATTRIBUTED_DEPRECATED(InnerNodeInfo &InsertOutput(const char *, int32_t)) + InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx); + InnerNodeInfo &InsertOutput(const char *output_node, int32_t peer_in_idx); + ge::graphStatus BuildInnerNode(); + ATTRIBUTED_DEPRECATED(ge::graphStatus SetInputFormat(const char *, const char *)) + ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format); + ge::graphStatus SetInputFormat(const char *input_name, const char *format); + ATTRIBUTED_DEPRECATED(ge::graphStatus SetOutputFormat(const char *, const char *)) + ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format); + ge::graphStatus SetOutputFormat(const char *output_name, const char *format); + ATTRIBUTED_DEPRECATED(ge::graphStatus SetDynamicInputFormat(const char *, uint32_t index, const char *)) + ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format); + ge::graphStatus SetDynamicInputFormat(const char *input_name, uint32_t index, const char *format); + ATTRIBUTED_DEPRECATED(ge::graphStatus SetDynamicOutputFormat(const char *, uint32_t, const char *)) + ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format); + ge::graphStatus SetDynamicOutputFormat(const char *output_name, uint32_t index, const char *format); + ge::Operator *MutableOperator(); + ATTRIBUTED_DEPRECATED(ge::graphStatus GetName(AscendString &) const) + std::string GetName() const; + ge::graphStatus GetName(AscendString &name) const; + ATTRIBUTED_DEPRECATED(ge::graphStatus GetType(AscendString &) const) + std::string GetType() const; + ge::graphStatus GetType(AscendString &type) const; + ATTRIBUTED_DEPRECATED(ge::graphStatus GetInputs(std::vector> &) const) + std::vector> GetInputs() const; + ge::graphStatus GetInputs(std::vector> &inputs) const; + ATTRIBUTED_DEPRECATED(ge::graphStatus GetOutputs(std::vector> &) const) + std::vector> GetOutputs() const; + ge::graphStatus GetOutputs(std::vector> &outputs) const; + private: + class InnerNodeInfoImpl; + std::unique_ptr impl_; + }; + ATTRIBUTED_DEPRECATED(InnerNodeInfo *AddInnerNode(const char *, const char *)) + InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type); + InnerNodeInfo *AddInnerNode(const char *name, const char *type); + InnerNodeInfo *MutableRecentInnerNode(); + InnerNodeInfo *MutableInnerNode(uint32_t index); + ge::graphStatus CheckInnerNodesInfo(); + + private: + class FusionScopesResultImpl; + std::unique_ptr impl_; + friend class ScopeGraph; + friend class ScopeBasePass; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree { + public: + ScopeTree(); + Status Init(); + ScopeTree(const ScopeTree &scopetree) = delete; + ScopeTree &operator=(const ScopeTree &scopetree) = delete; + ~ScopeTree(); + + const std::vector &GetAllScopes() const; + + private: + class ScopeTreeImpl; + std::unique_ptr impl_; + friend class ScopeGraph; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph { + public: + ScopeGraph(); + Status Init(); + ScopeGraph(const ScopeGraph &scope_graph) = delete; + ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete; + ~ScopeGraph(); + + const ScopeTree *GetScopeTree() const; + ATTRIBUTED_DEPRECATED(Status GetNodesMap(std::unordered_map &) const) + const std::unordered_map &GetNodesMap() const; + Status GetNodesMap(std::unordered_map &nodes_map) const; + + private: + class ScopeGraphImpl; + std::unique_ptr impl_; + friend class ScopePassManager; + friend class ScopeBasePass; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue { + public: + ScopeAttrValue(); + ScopeAttrValue(ScopeAttrValue const &attr_value); + ScopeAttrValue &operator=(ScopeAttrValue const &attr_value); + ~ScopeAttrValue(); + + void SetIntValue(int64_t value); + void SetFloatValue(float value); + ATTRIBUTED_DEPRECATED(void SetStringValue(const char *)) + void SetStringValue(std::string value); + void SetStringValue(const char *value); + void SetBoolValue(bool value); + + private: + class ScopeAttrValueImpl; + std::unique_ptr impl_; + friend class NodeAttrFeature; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature { + public: + virtual bool Match(const Scope *scope) = 0; + virtual ~ScopeBaseFeature(){}; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature { + public: + ATTRIBUTED_DEPRECATED(NodeOpTypeFeature(const char *, int, int)) + NodeOpTypeFeature(std::string nodeType, int num, int step = 0); + NodeOpTypeFeature(const char *node_type, int num, int step = 0); + NodeOpTypeFeature(NodeOpTypeFeature const &feature); + NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature); + ~NodeOpTypeFeature(); + bool Match(const Scope *scope) override; + + private: + class NodeOpTypeFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature { + public: + ATTRIBUTED_DEPRECATED(NodeAttrFeature(const char *, const char *, ge::DataType, ScopeAttrValue &)) + NodeAttrFeature(std::string nodeType, std::string attr_name, + ge::DataType datatype, ScopeAttrValue &attr_value); + NodeAttrFeature(const char *node_type, const char *attr_name, + ge::DataType datatype, ScopeAttrValue &attr_value); + NodeAttrFeature(NodeAttrFeature const &feature); + NodeAttrFeature &operator=(NodeAttrFeature const &feature); + ~NodeAttrFeature(); + bool Match(const Scope *scope) override; + + private: + class NodeAttrFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature { + public: + ATTRIBUTED_DEPRECATED(ScopeFeature(const char *, int32_t, const char *, const char *, int)) + ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", + std::string sub_scope_mask = "", int step = 0); + ScopeFeature(const char *sub_type, int32_t num, const char *suffix, + const char *sub_scope_mask, int step = 0); + ScopeFeature(ScopeFeature const &feature); + ScopeFeature &operator=(ScopeFeature const &feature); + ~ScopeFeature(); + bool Match(const Scope *scope) override; + + private: + class ScopeFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern { + public: + ScopePattern(); + ~ScopePattern(); + ATTRIBUTED_DEPRECATED(ScopePattern &SetSubType(const char *)) + ScopePattern &SetSubType(const std::string &sub_type); + ScopePattern &SetSubType(const char *sub_type); + ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature); + ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature); + ScopePattern &AddScopeFeature(ScopeFeature feature); + + private: + class ScopePatternImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult { + public: + ScopesResult(); + ScopesResult(ScopesResult const &result); + ScopesResult &operator=(ScopesResult const &result); + ~ScopesResult(); + + void SetScopes(std::vector &scopes); + void SetNodes(std::vector &nodes); + + private: + class ScopesResultImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass { + public: + ScopeBasePass(); + virtual ~ScopeBasePass(); + + protected: + // Subclasses implement respective fusion strategies and build the Patterns + virtual std::vector DefinePatterns() = 0; + // Define the name of the scope pass + virtual std::string PassName() = 0; + // Subclasses implement respective multi-scope or operator fusion methods across scopes + virtual Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, + std::vector &results) = 0; + // Subclasses implement their own results and set the input and output of the final fusion operator + virtual void GenerateFusionResult(const std::vector &scopes, FusionScopesResult *fusion_rlt) = 0; + + private: + class ScopeBasePassImpl; + std::unique_ptr impl_; + friend class ge::ScopePassManager; + friend class ScopeBasePassImpl; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry { + public: + using CreateFn = ScopeBasePass *(*)(); + ~ScopeFusionPassRegistry(); + + static ScopeFusionPassRegistry &GetInstance() { + static ScopeFusionPassRegistry instance; + return instance; + } + + ATTRIBUTED_DEPRECATED(void RegisterScopeFusionPass(const char *, CreateFn, bool)) + void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general); + + void RegisterScopeFusionPass(const char *pass_name, CreateFn create_fn, bool is_general); + + private: + ScopeFusionPassRegistry(); + class ScopeFusionPassRegistryImpl; + /*lint -e148*/ + std::unique_ptr impl_; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil { + public: + ATTRIBUTED_DEPRECATED(static AscendString StringReplaceAll(const char *, const char *, const char *)) + static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value); + static AscendString StringReplaceAll(const char *str, const char *old_value, const char *new_value); + static void FreeScopePatterns(ScopeFusionPatterns &patterns); + static void FreeOneBatchPattern(std::vector &one_batch_pattern); +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar { + public: + ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), bool is_general); + ~ScopeFusionPassRegistrar() {} +}; + +#define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \ + REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, scope_pass, is_general) + +#define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \ + REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) + +#define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \ + static ::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \ + ::ge::ScopeFusionPassRegistrar(pass_name, \ + []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, \ + is_general) +} // namespace ge + +#endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ diff --git a/metadef/inc/graph/anchor.h b/metadef/inc/graph/anchor.h new file mode 100644 index 00000000..99fdeeb3 --- /dev/null +++ b/metadef/inc/graph/anchor.h @@ -0,0 +1,286 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_ANCHOR_H_ +#define INC_GRAPH_ANCHOR_H_ + +#include "graph/compiler_options.h" + +#include +#include +#include +#include "graph/ge_error_codes.h" +#include "graph/range_vistor.h" +#include "graph/types.h" + +namespace ge { +enum AnchorStatus { + ANCHOR_SUSPEND = 0, // dat null + ANCHOR_CONST = 1, + ANCHOR_DATA = 2, // Effective + ANCHOR_RESERVED = 3 +}; +using std::string; +using std::vector; + +class Node; + +using NodePtr = std::shared_ptr; + +class Edge; + +using EdgePtr = std::shared_ptr; + +class Anchor; + +using AnchorPtr = std::shared_ptr; + +class DataAnchor; + +using DataAnchorPtr = std::shared_ptr; + +class InDataAnchor; + +using InDataAnchorPtr = std::shared_ptr; + +class OutDataAnchor; + +using OutDataAnchorPtr = std::shared_ptr; + +class ControlAnchor; + +using ControlAnchorPtr = std::shared_ptr; + +class InControlAnchor; + +using InControlAnchorPtr = std::shared_ptr; + +class OutControlAnchor; + +using OutControlAnchorPtr = std::shared_ptr; + +using ConstAnchor = const Anchor; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable_shared_from_this { + friend class AnchorUtils; + + public: + using TYPE = const char *; + template + using Vistor = RangeVistor>; + + Anchor(const NodePtr& ownerNode, int idx); + + virtual ~Anchor() = default; + + protected: + // Whether the two anchor is equal + virtual bool Equal(AnchorPtr anchor) const = 0; + virtual bool IsTypeOf(TYPE type) const; + + public: + // Get all peer anchors connected to current anchor + Vistor GetPeerAnchors() const; + // Get peer anchor size + size_t GetPeerAnchorsSize() const; + // Get first peer anchor + AnchorPtr GetFirstPeerAnchor() const; + + // Get the anchor belong to which node + NodePtr GetOwnerNode() const; + + // Remove all links with the anchor + void UnlinkAll() noexcept; + + // Remove link with the given anchor + graphStatus Unlink(const AnchorPtr &peer); + + // Replace peer with new peers + graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); + + // Judge if the anchor is linked with the given anchor + bool IsLinkedWith(const AnchorPtr &peer); + + // Get anchor index of the node + int GetIdx() const; + + // set anchor index of the node + void SetIdx(int index); + + protected: + // All peer anchors connected to current anchor + vector> peer_anchors_; + // The owner node of anchor + std::weak_ptr owner_node_; + // The index of current anchor + int idx_; + template + static Anchor::TYPE TypeOf() { + static_assert(std::is_base_of::value, "T must be a Anchor!"); + return METADEF_FUNCTION_IDENTIFIER; + } + + public: + template + static std::shared_ptr DynamicAnchorCast(AnchorPtr anchorPtr) { + static_assert(std::is_base_of::value, "T must be a Anchor!"); + if (anchorPtr == nullptr || !anchorPtr->IsTypeOf()) { + return nullptr; + } + return std::static_pointer_cast(anchorPtr); + } + + template + bool IsTypeOf() { + return IsTypeOf(TypeOf()); + } +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY DataAnchor : public Anchor { + friend class AnchorUtils; + + public: + explicit DataAnchor(const NodePtr &ownerNode, int idx); + + virtual ~DataAnchor() = default; + + protected: + bool IsTypeOf(TYPE type) const override; + + private: + Format format_{FORMAT_ND}; + AnchorStatus status_{ANCHOR_SUSPEND}; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataAnchor { + friend class OutDataAnchor; + + friend class OutControlAnchor; + + public: + explicit InDataAnchor(const NodePtr &ownerNode, int idx); + + virtual ~InDataAnchor() = default; + + // Get source out data anchor + OutDataAnchorPtr GetPeerOutAnchor() const; + + // Build connection from OutDataAnchor to InDataAnchor + graphStatus LinkFrom(const OutDataAnchorPtr &src); + + protected: + bool Equal(AnchorPtr anchor) const override; + bool IsTypeOf(TYPE type) const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchor : public DataAnchor { + friend class InDataAnchor; + + friend class AnchorUtils; + + public: + template + using Vistor = RangeVistor>; + + explicit OutDataAnchor(const NodePtr &ownerNode, int idx); + + virtual ~OutDataAnchor() = default; + // Get dst in data anchor(one or more) + Vistor GetPeerInDataAnchors() const; + uint32_t GetPeerInDataNodesSize() const; + + // Get dst in control anchor(one or more) + Vistor GetPeerInControlAnchors() const; + + // Build connection from OutDataAnchor to InDataAnchor + graphStatus LinkTo(const InDataAnchorPtr &dest); + + // Build connection from OutDataAnchor to InControlAnchor + graphStatus LinkTo(const InControlAnchorPtr &dest); + + protected: + bool Equal(AnchorPtr anchor) const override; + bool IsTypeOf(TYPE type) const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ControlAnchor : public Anchor { + public: + explicit ControlAnchor(const NodePtr &ownerNode); + + explicit ControlAnchor(const NodePtr &ownerNode, int idx); + + virtual ~ControlAnchor() = default; + + protected: + bool IsTypeOf(TYPE type) const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchor : public ControlAnchor { + friend class OutControlAnchor; + + friend class OutDataAnchor; + + public: + explicit InControlAnchor(const NodePtr &ownerNode); + + explicit InControlAnchor(const NodePtr &ownerNode, int idx); + + virtual ~InControlAnchor() = default; + + // Get source out control anchors + Vistor GetPeerOutControlAnchors() const; + bool IsPeerOutAnchorsEmpty() const { return peer_anchors_.empty(); } + + // Get source out data anchors + Vistor GetPeerOutDataAnchors() const; + + // Build connection from OutControlAnchor to InControlAnchor + graphStatus LinkFrom(const OutControlAnchorPtr &src); + + protected: + bool Equal(AnchorPtr anchor) const override; + bool IsTypeOf(TYPE type) const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchor : public ControlAnchor { + friend class InControlAnchor; + + public: + template + using Vistor = RangeVistor>; + + explicit OutControlAnchor(const NodePtr &ownerNode); + + explicit OutControlAnchor(const NodePtr &ownerNode, int idx); + + virtual ~OutControlAnchor() = default; + + // Get dst in control anchor(one or more) + Vistor GetPeerInControlAnchors() const; + // Get dst data anchor in control anchor(one or more) + Vistor GetPeerInDataAnchors() const; + + // Build connection from OutControlAnchor to InControlAnchor + graphStatus LinkTo(const InControlAnchorPtr &dest); + // Build connection from OutDataAnchor to InDataAnchor + graphStatus LinkTo(const InDataAnchorPtr &dest); + + protected: + bool Equal(AnchorPtr anchor) const override; + bool IsTypeOf(TYPE type) const override; +}; +} // namespace ge +#endif // INC_GRAPH_ANCHOR_H_ diff --git a/metadef/inc/graph/attr_value_serializable.h b/metadef/inc/graph/attr_value_serializable.h new file mode 100644 index 00000000..3f291157 --- /dev/null +++ b/metadef/inc/graph/attr_value_serializable.h @@ -0,0 +1,190 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ +#define INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ + +#include +#include +#include "graph/ge_attr_value.h" +#include "graph/compiler_options.h" + +namespace ge { + +class GeAttrValue; +class _GeSerializable { + public: + template + struct ge_serializable_int64_t_support_type { + using DT = typename std::remove_cv::type; + static const bool value = std::is_same::value // by cast + || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; + }; + + template + static GeAttrValue SaveItemAsAttrValue(const T &t) { + return GeAttrValue::CreateFrom(t); + } + + template + static GeAttrValue SaveItemAsAttrValue(const vector &t) { + return GeAttrValue::CreateFrom(t); + } + + template = 0, typename DT = typename std::remove_cv::type> + static GeAttrValue SaveItemAsAttrValue(const T &t) { + return GeAttrValue::CreateFrom
(t); + } + // int64_t support type + template ::value, int>::type = 0> + static GeAttrValue SaveItemAsAttrValue(const T &t) { + return GeAttrValue::CreateFrom(t); + } + // vector int64_t support type + template ::value, int>::type = 0> + static GeAttrValue SaveItemAsAttrValue(const vector &t) { + return GeAttrValue::CreateFrom(t); + } + + template + static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { + return attrVal.GetValue(t); + } + + template + static graphStatus LoadItemFromAttrValue(vector &t, GeAttrValue &attrVal) { + return attrVal.GetValue(t); + } + + template = 0, typename DT = typename std::remove_cv::type> + static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { + return attrVal.GetValue
(t); + } + + template ::value, int>::type = 0> + static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { + return attrVal.GetValue(t); + } + + template ::value, int>::type = 0> + static graphStatus LoadItemFromAttrValue(vector &t, GeAttrValue &attrVal) { + return attrVal.GetValue(t); + } + + template + static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { + GeAttrValue itemVal = SaveItemAsAttrValue(item); + (void)namedAttrs.SetAttr(itemName, itemVal); + SaveItem(namedAttrs, args...); + } + + static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs METADEF_ATTRIBUTE_UNUSED) {} + + template + static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { + auto itemVal = namedAttrs.GetItem(itemName); + auto status = LoadItemFromAttrValue(item, itemVal); + if (status != GRAPH_SUCCESS) { + return status; + } + return LoadItem(namedAttrs, args...); + } + + static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs METADEF_ATTRIBUTE_UNUSED) { return GRAPH_SUCCESS; } +}; + +#define _GE_FI(a) #a, a +#define _GE_MAP_FIELDS1(a1) _GE_FI(a1) +#define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) +#define _GE_MAP_FIELDS3(a1, a2, a3) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3) +#define _GE_MAP_FIELDS4(a1, a2, a3, a4) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4) +#define _GE_MAP_FIELDS5(a1, a2, a3, a4, a5) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5) +#define _GE_MAP_FIELDS6(a1, a2, a3, a4, a5, a6) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6) +#define _GE_MAP_FIELDS7(a1, a2, a3, a4, a5, a6, a7) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7) +#define _GE_MAP_FIELDS8(a1, a2, a3, a4, a5, a6, a7, a8) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8) +#define _GE_MAP_FIELDS9(a1, a2, a3, a4, a5, a6, a7, a8, a9) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9) +#define _GE_MAP_FIELDS10(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10) +#define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ + _GE_FI(a11) +#define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ + _GE_FI(a11), _GE_FI(a12) +#define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ + _GE_FI(a11), _GE_FI(a12), _GE_FI(a13) +#define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ + _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) +#define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ + _GE_FI(a1) \ + , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ + _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) + +#define _GE_PRIVATE_ARGS_GLUE(x, y) x y + +#define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, \ + ...) \ + N +#define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL(args) _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT args +#define _GE_COUNT_MACRO_VAR_ARGS(...) \ + _GE_PRIVATE_MACRO_VAR_ARGS_IMPL((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)) + +#define _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) M##count +#define _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) +#define _GE_PRIVATE_MACRO_CHOOSE_HELPER(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) + +#define _GE_INVOKE_VAR_MACRO(...) \ + _GE_PRIVATE_ARGS_GLUE(_GE_PRIVATE_MACRO_CHOOSE_HELPER(_GE_MAP_FIELDS, _GE_COUNT_MACRO_VAR_ARGS(__VA_ARGS__)), \ + (__VA_ARGS__)) + +#define GE_SERIALIZABLE(...) \ + public: \ + friend class ge::GeAttrValue; \ + using __ge_serializable = int; \ + \ + private: \ + ge::graphStatus Save(GeAttrValue &ar) const { \ + GeAttrValue::NAMED_ATTRS named_attrs; \ + _GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ + return ar.SetValue(named_attrs); \ + } \ + ge::graphStatus Load(const GeAttrValue &ar) { \ + GeAttrValue::NAMED_ATTRS named_attrs; \ + ge::graphStatus status = ar.GetValue(named_attrs); \ + if (status != GRAPH_SUCCESS) { \ + return status; \ + } \ + return _GeSerializable::LoadItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ + } + +// end NamedAttrs Helper: GE_SERIALIZABLE +} // namespace ge +#endif // INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ diff --git a/metadef/inc/graph/buffer.h b/metadef/inc/graph/buffer.h new file mode 100644 index 00000000..2ee67cfe --- /dev/null +++ b/metadef/inc/graph/buffer.h @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_BUFFER_H_ +#define INC_GRAPH_BUFFER_H_ + +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/compiler_options.h" + +namespace ge { + +using std::shared_ptr; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { + public: + Buffer(); + Buffer(const Buffer &other); + + explicit Buffer(std::size_t bufferSize, std::uint8_t defualtVal = 0); + + ~Buffer() = default; + + Buffer &operator=(const Buffer &other); + + static Buffer CopyFrom(const std::uint8_t *data, std::size_t bufferSize); + + const std::uint8_t *GetData() const; + std::uint8_t *GetData(); + std::size_t GetSize() const; + void ClearBuffer(); + + // For compatibility + inline const std::uint8_t *data() const { return GetData(); } + inline std::uint8_t *data() { return GetData(); } // lint !e659 + inline std::size_t size() const { return GetSize(); } + inline void clear() { return ClearBuffer(); } + uint8_t operator[](size_t index) const { // lint !e1022 !e1042 + if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574 + return (uint8_t)(*buffer_)[index]; + } + return 0xff; + } + + private: + GeIrProtoHelper data_; + std::string *buffer_ = nullptr; + + // Create from protobuf obj + Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); + Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); + + friend class GeAttrValueImp; + friend class GeTensor; +}; +} // namespace ge +#endif // INC_GRAPH_BUFFER_H_ diff --git a/metadef/inc/graph/common_error_codes.h b/metadef/inc/graph/common_error_codes.h new file mode 100644 index 00000000..cdf9086f --- /dev/null +++ b/metadef/inc/graph/common_error_codes.h @@ -0,0 +1,28 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_COMMON_ERROR_CODES_H_ +#define INC_GRAPH_COMMON_ERROR_CODES_H_ + +#include "external/graph/ge_error_codes.h" + +namespace ge { +const graphStatus NO_DEPENDENCE_FUNC = 50331647; +const graphStatus NO_OVERLAP_DIM = 50331646; +const graphStatus NOT_SUPPORT_SLICE = 50331645; +} // namespace ge + +#endif // INC_GRAPH_COMMON_ERROR_CODES_H_ diff --git a/metadef/inc/graph/compiler_options.h b/metadef/inc/graph/compiler_options.h new file mode 100644 index 00000000..f31ad75c --- /dev/null +++ b/metadef/inc/graph/compiler_options.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_COMPILER_OPTIONS_H_ +#define INC_GRAPH_COMPILER_OPTIONS_H_ + +namespace ge { +#ifdef __GNUC__ +#define METADEF_ATTRIBUTE_UNUSED __attribute__((unused)) +#define METADEF_FUNCTION_IDENTIFIER __PRETTY_FUNCTION__ +#define METADEF_BUILTIN_PREFETCH(args_addr) __builtin_prefetch(args_addr) + +#ifdef HOST_VISIBILITY +#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_HOST_VISIBILITY +#endif + +#ifdef DEV_VISIBILITY +#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_DEV_VISIBILITY +#endif + +#else // WINDOWS +#define METADEF_ATTRIBUTE_UNUSED +#define METADEF_FUNCTION_IDENTIFIER __FUNCSIG__ +#define METADEF_BUILTIN_PREFETCH(args_addr) +#define GE_FUNC_HOST_VISIBILITY +#define GE_FUNC_DEV_VISIBILITY +#endif +} // namespace ge + +#endif // INC_GRAPH_COMPILER_OPTIONS_H_ \ No newline at end of file diff --git a/metadef/inc/graph/compute_graph.h b/metadef/inc/graph/compute_graph.h new file mode 100644 index 00000000..d5c2d068 --- /dev/null +++ b/metadef/inc/graph/compute_graph.h @@ -0,0 +1,309 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_COMPUTE_GRAPH_H_ +#define INC_GRAPH_COMPUTE_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/anchor.h" +#include "graph/node.h" +#include "graph/op_desc.h" +#include "graph/range_vistor.h" + +namespace ge { +class Node; +using NodePtr = std::shared_ptr; +class Edge; +using EdgePtr = std::shared_ptr; + +class InDataAnchor; +using InDataAnchorPtr = std::shared_ptr; + +class OutDataAnchor; +using OutDataAnchorPtr = std::shared_ptr; + +class ControlAnchor; +using ControlAnchorPtr = std::shared_ptr; +class InControlAnchor; +using InControlAnchorPtr = std::shared_ptr; +class OutControlAnchor; +using OutControlAnchorPtr = std::shared_ptr; +class GeAttrValue; +using AttrValuePtr = std::shared_ptr; +using ConstComputeGraph = const ComputeGraph; + +class OperatorImpl; +using OperatorImplPtr = std::shared_ptr; + +class ComputeGraph : public std::enable_shared_from_this, public AttrHolder { + friend class GraphUtils; + + public: + template + using Vistor = RangeVistor>; + + explicit ComputeGraph(const std::string &name); + ~ComputeGraph() override; + + std::string GetName() const; + void SetName(const std::string &name); + + using AttrHolder::DelAttr; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + size_t GetAllNodesSize() const; + Vistor GetAllNodes() const; + // is_unknown_shape: false, same with GetAllNodes func + // is_unknown_shape: true, same with GetDirectNodes func + Vistor GetNodes(bool is_unknown_shape) const; + size_t GetDirectNodesSize() const; + Vistor GetDirectNode() const; + Vistor GetInputNodes() const; + Vistor GetOutputNodes() const; + + NodePtr FindNode(const std::string &name) const; + NodePtr FindFirstNodeMatchType(const std::string &name) const; + /*lint -e504*/ + // AddNode with NodePtr + NodePtr AddNode(NodePtr node); + NodePtr AddNode(OpDescPtr op); + NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize + NodePtr AddNodeFront(NodePtr node); + NodePtr AddNodeFront(const OpDescPtr &op); + NodePtr AddInputNode(NodePtr node); + NodePtr AddOutputNode(NodePtr node); + NodePtr AddOutputNodeByIndex(NodePtr node, int32_t index); + + graphStatus RemoveNode(const NodePtr &node); + graphStatus RemoveInputNode(const NodePtr &node); + graphStatus RemoveOutputNode(const NodePtr &node); + graphStatus RemoveConstInput(const NodePtr &node); + + /// Add a subgraph to this graph. The subgraph must has a parent graph and parent node, + /// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph + /// must be called before add it to the root graph. and subgraph->GetParentNode()->GetOwnerGraph() + /// must equal to subgraph->GetOwnerGraph(). + /// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph. + /// The subgraph's name SHOULD(not must) be the same as the parameter `name` + graphStatus AddSubgraph(const std::string &name, const std::shared_ptr &subgraph); + graphStatus AddSubgraph(const std::shared_ptr &subgraph); + + void RemoveSubgraph(const std::string &name); + void RemoveSubgraph(const std::shared_ptr &subgraph); + + std::shared_ptr GetSubgraph(const std::string &name) const; + std::vector> GetAllSubgraphs() const; + + // obsolete + std::shared_ptr AddSubGraph(std::shared_ptr sub_graph); + // obsolete + graphStatus RemoveSubGraph(const std::shared_ptr &sub_graph); + + /// + /// @brief Update input-mapping + /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input + /// @return graphStatus + /// + graphStatus UpdateInputMapping(const std::map &input_mapping); + + /// + /// @brief Update output-mapping + /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output + /// @return graphStatus + /// + graphStatus UpdateOutputMapping(const std::map &output_mapping); + + graphStatus TopologicalSorting(); + bool IsValid() const; + void InValid() { is_valid_flag_ = false; } + void Dump() const; + + void Swap(ComputeGraph &graph); + + graphStatus IsolateNode(const NodePtr &node); + graphStatus Verify(); + graphStatus InferShape(); + graphStatus InferOriginFormat(); + graphStatus InferShapeInNeed(); + graphStatus InsertEventNodes(); + bool operator==(const ComputeGraph &r_compute_graph) const; + + /*lint +e504*/ + const std::map, std::vector> &GetShareParamLayer() const { + return params_share_map_; + } + + void SetShareParamLayer(const std::map, std::vector> params_share_map) { + params_share_map_ = params_share_map; + } + + void SetInputsOrder(const std::vector &inputs_order) { inputs_order_ = inputs_order; } + + void SetGraphOutNodes(std::map> out_nodes_map) { out_nodes_map_ = out_nodes_map; } + + void AppendGraphOutNodes(std::map> out_nodes_map) { + for (auto &item : out_nodes_map) { + (void)out_nodes_map_.emplace(item.first, item.second); + } + } + + shared_ptr GetParentGraph(); + void SetParentGraph(const shared_ptr &parent); + shared_ptr GetParentNode(); + void SetParentNode(const shared_ptr &parent); + + const std::map> &GetGraphOutNodes() const { return out_nodes_map_; } + + void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } + + ComputeGraphPtr GetOrigGraph(void) { return origGraph_; } + void SetOutputSize(uint32_t size) { output_size_ = size; } + uint32_t GetOutputSize() const { return output_size_; } + void SetInputSize(uint32_t size) { input_size_ = size; } + uint32_t GetInputSize() const { return input_size_; } + + // false: known shape true: unknow shape + bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; } + void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; } + + /// + /// Set is need train iteration. + /// If set true, it means this graph need to be run iteration some + /// times(according variant "npu_runconfig/iterations_per_loop"). + /// @param need_iteration is need iteration + /// + void SetNeedIteration(bool need_iteration) { need_iteration_ = need_iteration; } + + void SetUserDefOutput(const std::string &output_name); + + const std::string GetOutput(); + + /// + /// Get is need train iteration. + /// @return is need iteration + /// + bool GetNeedIteration() const { return need_iteration_; } + + void SetGraphOpName(const std::map &op_name_map) { op_name_map_ = op_name_map; } + const std::map &GetGraphOpName() const { return op_name_map_; } + + const std::map &GetAllNodesInfo() const; + + void SetAllNodesInfo(const std::map &nodes) { all_nodes_infos_ = nodes; } + + void SetGraphOutNodesInfo(std::vector> &out_nodes_info) { + output_nodes_info_ = out_nodes_info; + } + + void AppendGraphOutNodesInfo(std::vector> &out_nodes_info) { + output_nodes_info_.insert(output_nodes_info_.end(), out_nodes_info.begin(), out_nodes_info.end()); + } + + const std::vector> &GetGraphOutNodesInfo() const { return output_nodes_info_; } + + void SetGraphTargetNodesInfo(const std::vector &target_nodes_info) { + target_nodes_info_ = target_nodes_info; + } + const std::vector &GetGraphTargetNodesInfo() const { return target_nodes_info_; } + + void SetSessionID(uint64_t session_id) { session_id_ = session_id; } + uint64_t GetSessionID() const { return session_id_; } + + void SetGraphID(uint32_t graph_id) { graph_id_ = graph_id; } + uint32_t GetGraphID() const { return graph_id_; } + + void SaveDataFormat(ge::Format data_format) { data_format_ = data_format; } + ge::Format GetDataFormat() const { return data_format_; } + bool IsSummaryGraph() const { return is_summary_graph_; } + void SetSummaryFlag(bool is_summary_graph) { is_summary_graph_ = is_summary_graph; } + // Graph Before BFE + ComputeGraphPtr origGraph_; + + protected: + ProtoAttrMapHelper MutableAttrMap() override; + ConstProtoAttrMapHelper GetAttrMap() const override; + + private: + graphStatus DFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, + std::vector &stack, bool reverse); + graphStatus BFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, + std::deque &stack); + graphStatus CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, + std::map &breadth_node_map); + /// nodes like : (a) <--- (c) ---> (b) + /// node a and b have only one parent node c, and a is connected to c firstly + /// topo order of DFS is `c, b, a` with `dfs_reverse=false` as default + /// in same case, user could get `c, a, b` with `dfs_reverse=true` + graphStatus TopologicalSortingGraph(bool dfs_reverse = false); + graphStatus SortNodes(std::vector &stack, std::map &mapInEdgeNum); + Vistor AllGraphNodes(std::vector> &subgraphs) const; + size_t GetInEdgeSize(const NodePtr &node); + size_t GetOutEdgeSize(const NodePtr &node); + graphStatus RemoveExtraOutEdge(const NodePtr &node); + bool GraphMembersAreEqual(const ComputeGraph &r_graph) const; + bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const; + bool VectorInputNodePtrIsEqual(const std::vector &r_node_ptr_vector, + const std::vector &l_node_ptr_vector) const; + + void SetNodesOwner(); + + friend class ModelSerializeImp; + friend class GraphDebugImp; + friend class OnnxUtils; + friend class TuningUtils; + + std::string name_; + uint32_t graph_id_ = 0; + ProtoAttrMapHelper attrs_; + std::vector nodes_; + std::map all_nodes_infos_; + std::vector target_nodes_info_; + + std::vector input_nodes_; + std::vector inputs_order_; + uint32_t input_size_ = 1; + std::map> out_nodes_map_; + uint32_t output_size_ = 1; + std::vector> output_nodes_info_; + + std::vector> sub_graph_; + std::map> names_to_subgraph_; + std::weak_ptr parent_graph_; + std::weak_ptr parent_node_; + + // the members followed should not in the ComputeGraph class + bool is_valid_flag_; + bool is_summary_graph_ = false; + // Indicates whether it is need iteration + bool need_iteration_ = false; + std::map, std::vector> params_share_map_; + // TaskIdx -> op_name Map + std::map op_name_map_; + uint64_t session_id_ = 0; + ge::Format data_format_ = ge::FORMAT_ND; + // unknown graph indicator, default is false, mean known shape + bool is_unknown_shape_graph_ = false; +}; +} // namespace ge +#endif // INC_GRAPH_COMPUTE_GRAPH_H_ diff --git a/metadef/inc/graph/debug/ge_attr_define.h b/metadef/inc/graph/debug/ge_attr_define.h new file mode 100644 index 00000000..47c8c93b --- /dev/null +++ b/metadef/inc/graph/debug/ge_attr_define.h @@ -0,0 +1,1160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ +#define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ + +/*lint -e618*/ +#include +#include "graph/types.h" +#include "graph/compiler_options.h" + +namespace ge { +// Public attribute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORCE_UNKNOWN_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WORKSPACE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHT_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_QUANTIZE_FACTOR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALPHA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BETA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADMODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADMODES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SCALE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WINDOWS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GLOBAL_POOLING; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELU_FLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALGO; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_K; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_NORM_REGION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_LOCAL_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_ALPHA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_BETA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROADCAST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TIDX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TPADDINGS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IMG_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NET_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NET_W; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TMULTIPLES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTIPLES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_T; + +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string ATTR_NAME_N; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TSHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_INPUTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_OUTPUTS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_AIPP_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_RELATED_AIPP_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_AIPP_DATA_NAME_MAP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_HAS_BEEN_ADDED; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_NODE_DEF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_OP_DEF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_TENSOR_DESC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INFERRED_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; + + + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTS_LABEL_NODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_ORIGIN_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ROOT_GRAPH_ID; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_INPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_OUTPUT; + +// to be deleted +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_LOC_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_CONF_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_OCR_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; + + + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; + +// _Arg +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INDEX; +// _RetVal +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETVAL_ATTR_NAME_INDEX; +// Data +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DATA_ATTR_NAME_DATA_TYPE; + +// Send +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SEND_ATTR_EVENT_ID; + +// Recv +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RECV_ATTR_EVENT_ID; + +// Convolution +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COEF; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATIONS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_ALGO; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_GROUP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_PAD_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_PAD; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_STRIDE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_DILATION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_NUM_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_KERNEL; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_FILTER; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_BIAS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_RELU_FLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_ADJ; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_TARGET_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_BEFORE_PAD; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_HAS_BIAS; + +// Pooling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_NAN_OPT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_PAD_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_GLOBAL_POOLING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_WINDOW; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_PAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_STRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_CEIL_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_DATA_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_BEFORE_PAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_NAME_ALGO; + +// Eltwise +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_COEFF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_WEIGHT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_RELU_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_ALPHA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_BETA; + +// BatchNorm +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_EPSILON; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_MEAN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; + +// Huberloss +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; + +// SSDRealDivTileMul +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; + +// SSDSumMulRealDivMean +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; +/// ConcatFive2Four +/// ConcatFour2Five +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; +// Scale +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; + +// FullConnection +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_FILTER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_BIAS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_RELU_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_ATTR_NAME_ALGO; + +// SoftmaxOpParams +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_ALGO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_MODE; + +// SparseSoftmaxCrossEntropy +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE; +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD; +// Attr labelSmoothing +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING; + +// ApplyMomentum +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string APPLYMENTUM_ATTR_IS_GRAPH_FUSION; + +// Activation +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ACTIVATION_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ACTIVATION_ATTR_COEF; + +// Concat +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_ATTR_NAME_AXIS; + +// Const +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_DATA_TRANSTYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; + +// Roipooling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLED_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLED_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO; + +// DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IMG_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE; +// Ssd DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_ETA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K; + +// Refinedet DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE; + +// Yolo DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_ClASSES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BIASES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_RELATIVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION; + +// DetectionPostprocess +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_CLS_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_POST_NMS_TOPN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT; + +// Spatialtransfrom +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_OUTPUT_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_OUTPUT_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM; + +// Proposal +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_BASE_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_MIN_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_RATIO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_NMS_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_TOP_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_IMG_W; +// Softmax +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_AXIS; + +// Permute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; + +// SSD Normalize +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_EPS; + +// Flatten +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_ATTR_END_AXIS; + +// SsdPRIORBOX +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_FLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_CLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_IMG_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_STEP_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_STEP_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_VARIANCE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM; + +// PRelu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PRELU_ATTR_CHANNEL_SHARED; + +// Psroi pooling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_OUTPUT_DIM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_GROUP_SIZE; + +// Power +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_POWER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; + +// Log +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; +// Pack +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; + +// Dynamic stitch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; +// Unpack +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; +// Gathernd +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND_ATTR_NAME_TINDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND_ATTR_NAME_TPARAMS; + +// Argmax +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; + +// Upsample +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; +// Relu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; + +// FreeSpaceExtract +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT; + +// Split +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_SLICE_POINT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_SIZE_SPLIT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_NUM_SPLIT; + +// Tvm +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_WORKSPACE_TYPE; + +// Squeeze +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_OP_NAME; + +// Stride slice +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_BEGIN_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_END_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; + +// Slice +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SLICE_ATTR_NAME_BEGINS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SLICE_ATTR_NAME_SIZES; + +// Roialign +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SPATIAL_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; + +// Generate_rpn_proposal +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string + GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string + GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; +// Decode_bbox +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DECODE_BBOX_ATTR_DECODECLIP; + +// Cast +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_DSTT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_SRCT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_DST_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_TRUNCATE; + +// Fastrcnnn predications +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES; + +// REORG +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_ATTR_STRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_ATTR_REVERSE; + +// MERGE +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; +static const std::string NOT_NET_OUTPUT = "not_net_output"; + +// ENTER +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_CONSTANT_FLAG; + +// Concatv2 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_V2_ATTR_TIDX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_V2_ATTR_N; +// SUM +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_TIDX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_KEEP_DIMS; + +// ResizeBilinear +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_HEIGHT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_WIDTH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_PAD_END; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; + +// RetinaNet +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; +// MatMul +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_HAS_BIAS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_ATTR_IS_TRAINING; + +// Flatten +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_START_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_END_AXIS; + +// Reshape +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NUM_AXES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_ALPHA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_BETA; + +// Frameoworkop +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string T_IN_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string T_OUT_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_N; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_C; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_PAD_DEPTH_CONV; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_PAD_CONV; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BEFORE_PAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ANN_MEAN_KEEPDIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_ATTR_PADDINGDS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_ATTR_CONSTANT_VALUE; + +// ConvGradFilter +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE; +// ConvGradInput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE; + +// Rnn +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_MODE_STATIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MUTI_RNN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CELL_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CNN_RNN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; + +// Upsample +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; + +// PadV2 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; + +// MirrorPad +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; +// Filler +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; + +// Shufflechannel +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHUFFLE_CHANNEL_GROUP; + +// TopKV2 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TOPKV2_ATTR_K; + +// Calibaration +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_H_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_W_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_TOP_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_BOTTOM_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_RIGHT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATION_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_EPSILON; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_POOLING_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CLASS_NUM; +// Model +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TARGET_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_STREAM_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_HUGE_STREAM_LIST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OUT_NODES_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; + +// Public attribute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BYTE_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_INFERENCE_ID; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_OPDEF; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IO_OP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_SCOPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPATTR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUFLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SEQLEN_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_X_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONT_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_XSTATIC_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_MINI; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_TINY; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_LITE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; + +// Used for operators that do not generate task +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOTASK; + +// Used for operators that output reuse input +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_REUSE_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; + + + +// L2_normalize +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; +// HCOM +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; + +// Log time stamp +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; +// SpaceToDepth/DepthToSpace +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; + +// SparseSoftmaxCrossEntropyWithLogits +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; + +// MaxPoolGradWithArgmax +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; + +// AvgPoolGrad +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; + +// Varible +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; + +// Assign +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VAR_NAME; + +// ShapeN +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; + +// Space2bacth batch2space +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; +// Depth_to_space space_to_depth +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; +// FakeQuantWithMinMaxVars +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; +// Mobilenet_ssd_conv_fusion +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; + +// Lsh project +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; + +// Control flow +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; + +// GatherV2 attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; + +// Reshape attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; + +// Axis attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; +// The node link with SparseSoftmaxCrossEntropyWithLogits +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; +// For constant folding +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; + +// Used for mark the active label list to find stream of activated node +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE; + +// Multi batch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_BATCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER; + +// Control flow +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIG_NODE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; + +// Function Op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; + +// Used for mark the active node is for loop, type:bool +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_RANGE; + +// Atomic addr clean attrs +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_INPUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_OUTPUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_IS_FUSION_NODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_IS_ATOMIC_NODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_WORKSPACE_INFO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_WORKSPACE_OFFSET; +// Used for find variable session_id +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MODEL_ATTR_SESSION_ID; + +// Source/dst format for Op FormatTransfer +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_SRC_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_DST_FORMAT; + +// For compile op by ge call +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEED_COMPILE; + +// For mutil-batch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_TYPE; + +// For inserted op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; + +// For compress weight +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPRESS_WEIGHT; + +// For data dump +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_SUB_SPLITER_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_GROUP_OP_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; + +// used for lX fusion +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ENGINE_NAME_FOR_LX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_LX_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPTIMIZE_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_COMPILE_STRATEGY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_BUFFER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_SLICE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_RECOVER_ATTR; + +// used for memory allocate +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WORKSPACE_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TENSOR_MEM_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_P2P_MEMORY_SIZE; + +// for unregistered op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST; + +// op overflow dump +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE; + +// op dynamic input +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_INPUT_START; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_INPUT_END; + +// functional ops attr +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_ELSE_BRANCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; + +// used for label switch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_END_NODE; + +// Variable +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; + +// HCOM +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; + + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; +// used for LX tiling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_L1_SPACE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST; + +// Dynamic stitch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; + +// Used for support Horovod +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INTER_EVENT_IDENTIFY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE; +// for gradient group +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG; + +// dynamic shape attrs +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_SINGLE_AICPU; + +// atc user def dtype&format +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT; + +// atc user def dtype&format +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_OUTPUT_NODES; + +// for fusion op plugin +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; + +// graph partition for aicpu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME; + +// input and output memory type +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_VARIABLE_PLACEMENT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SPECIAL_OUTPUT_SIZE; + +// stage +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_STAGE_LEVEL; + +// input_output_offset +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET; + +// The processing mode of INF and NAN during floating-point number calculation. +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_FP_CEILING_MODE; +// count of data from getnext_sink +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_GETNEXT_SINK_DATA_COUNT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_GETNEXT_SINK_SHAPE_INFO; + +// getnext_sink marked on NetOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_GETNEXT_SINK_DYNMAIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ALL_GEARS_INFO; + +// Calculate the operator output memory +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_SIZE_CALC_TYPE; +} // namespace ge + +/*lint +e618*/ +#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ diff --git a/metadef/inc/graph/def_types.h b/metadef/inc/graph/def_types.h new file mode 100644 index 00000000..cd5e19f4 --- /dev/null +++ b/metadef/inc/graph/def_types.h @@ -0,0 +1,196 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_DEF_TYPES_H_ +#define INC_GRAPH_DEF_TYPES_H_ + +#include +#include +#include +#include "graph/attr_value_serializable.h" +#include "graph/buffer.h" +namespace ge { +#define DEF_TYPE_DEC(type, name) \ + inline void set_##name(const type &value) { name = value; } \ + type *mutable_##name() { return &name; } + +#define DEF_TYPE_HAS_DEC(type, name) \ + inline void set_##name(const type &value) { name = value; } \ + \ + private: \ + bool has_mutable_##name{false}; \ + \ + public: \ + bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \ + type *mutable_##name() { \ + has_mutable_##name = true; \ + return &name; \ + } + +#define DEF_TYPE_VEC_DEC(type, name) \ + inline int name##_size() const { return name.size(); } \ + inline void clear_##name() { name.clear(); } \ + inline void set_##name(int index, type value) { name[index] = value; } \ + inline void add_##name(type value) { name.push_back(value); } \ + inline std::vector *mutable_##name() { return &name; } + +#define DEF_TYPE_BYTES_DEC(name) \ + inline void clear_##name() { name.ClearBuffer(); } \ + inline void set_##name(const void *value, size_t size) { \ + name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ + inline Buffer *mutable_##name() { return &name; } + +struct CompressInfo { + public: + CompressInfo() {} + CompressInfo(int32_t blockRow, int32_t blockCol, int32_t fractalK, int32_t fractalN, int32_t lastFractalK, + int32_t lastFractalN, int32_t cubeSize, int32_t loadDir) { + blockrow = blockRow; + blockcol = blockCol; + fractalk = fractalK; + fractaln = fractalN; + lastfractalk = lastFractalK; + lastfractaln = lastFractalN; + cubesize = cubeSize; + loaddir = loadDir; + } + + int32_t blockrow{0}; // Block row + int32_t blockcol{0}; // Block col + int32_t fractalk{0}; // Fractal K + int32_t fractaln{0}; // Fractal N + int32_t lastfractalk{0}; // K of last fractal + int32_t lastfractaln{0}; // N of last fractal + int32_t cubesize{0}; // Cube's length + int32_t loaddir{0}; // Data load directtiono 0:col load 1:row load + DEF_TYPE_DEC(int32_t, blockrow); + DEF_TYPE_DEC(int32_t, blockcol); + DEF_TYPE_DEC(int32_t, fractalk); + DEF_TYPE_DEC(int32_t, fractaln); + DEF_TYPE_DEC(int32_t, lastfractalk); + DEF_TYPE_DEC(int32_t, lastfractaln); + DEF_TYPE_DEC(int32_t, cubesize); + DEF_TYPE_DEC(int32_t, loaddir); + + GE_SERIALIZABLE(blockrow, blockcol, fractalk, fractaln, lastfractalk, lastfractaln, cubesize, loaddir); +}; + +enum QuantizeScaleType { VECTOR_SCALE = 0, SCALAR_SCALE = 1 }; +enum QuantizeScaleMode { NORMAL_MODE = 0, SQRT_MODE = 1 }; +enum QuantizeAlgorithm { + NON_OFFSET_ALGO = 0, + HALF_OFFSET_ALGO = 1, + ALL_OFFSET_ALGO = 2, +}; +struct QuantizeFactor { + public: + // QuantizeScaleMode scale_mode; + uint32_t scale_mode{0}; + Buffer scale_value; + int64_t scale_offset{0}; + Buffer offset_data_value; + int64_t offset_data_offset{0}; + Buffer offset_weight_value; + int64_t offset_weight_offset{0}; + Buffer offset_pad_value; + int64_t offset_pad_offset{0}; + + DEF_TYPE_DEC(uint32_t, scale_mode); + DEF_TYPE_BYTES_DEC(scale_value); + + DEF_TYPE_DEC(int64_t, scale_offset); + DEF_TYPE_BYTES_DEC(offset_data_value); + DEF_TYPE_DEC(int64_t, offset_data_offset); + + DEF_TYPE_BYTES_DEC(offset_weight_value); + DEF_TYPE_DEC(int64_t, offset_weight_offset); + DEF_TYPE_BYTES_DEC(offset_pad_value); + DEF_TYPE_DEC(int64_t, offset_pad_offset); + + GE_SERIALIZABLE(scale_mode, scale_value, scale_offset, offset_data_value, offset_data_offset, offset_weight_value, + offset_weight_offset, offset_pad_value, offset_pad_offset) +}; + +static inline bool QuantizeFactorHasData(const QuantizeFactor &factor) { + return factor.scale_value.GetSize() > 0 || factor.offset_data_value.GetSize() > 0 || + factor.offset_weight_value.GetSize() > 0 || factor.offset_pad_value.GetSize() > 0; +} + +struct AllOffsetQuantizeInfo { + public: + AllOffsetQuantizeInfo() {} + AllOffsetQuantizeInfo(float s, int32_t o) : scale(s), offset(o) {} + float scale{0}; + int32_t offset{0}; + + DEF_TYPE_DEC(float, scale); + DEF_TYPE_DEC(int32_t, offset); + + GE_SERIALIZABLE(scale, offset) +}; + +struct QuantizeCalcFactor { + public: + Buffer offsetw; + int64_t offsetw_offset{0}; + Buffer offsetd; + int64_t offsetd_offset{0}; + Buffer scalereq; + int64_t scaledreq_offset{0}; + Buffer offsetdnext; + int64_t offsetdnext_offset{0}; + + DEF_TYPE_BYTES_DEC(offsetw); + DEF_TYPE_DEC(int64_t, offsetw_offset); + DEF_TYPE_BYTES_DEC(offsetd); + DEF_TYPE_DEC(int64_t, offsetd_offset); + DEF_TYPE_BYTES_DEC(scalereq); + DEF_TYPE_DEC(int64_t, scaledreq_offset); + DEF_TYPE_BYTES_DEC(offsetdnext); + DEF_TYPE_DEC(int64_t, offsetdnext_offset); + + GE_SERIALIZABLE(offsetw, offsetw_offset, offsetd, offsetd_offset, scalereq, scaledreq_offset, offsetdnext, + offsetdnext_offset); +}; + +static inline bool QuantizeFactorHasData(const QuantizeCalcFactor &factor) { + return factor.offsetw.GetSize() > 0 || factor.offsetd.GetSize() > 0 || factor.scalereq.GetSize() > 0 || + factor.offsetdnext.GetSize() > 0; +} + +struct QuantizeFactorParams { + uint32_t quantize_algo{0}; + uint32_t scale_type{0}; + QuantizeFactor quantize_param; + QuantizeFactor dequantize_param; + QuantizeFactor requantize_param; + QuantizeCalcFactor quantizecalc_param; + DEF_TYPE_DEC(uint32_t, quantize_algo); + DEF_TYPE_DEC(uint32_t, scale_type); + DEF_TYPE_HAS_DEC(QuantizeFactor, quantize_param); + DEF_TYPE_HAS_DEC(QuantizeFactor, dequantize_param); + DEF_TYPE_HAS_DEC(QuantizeFactor, requantize_param); + DEF_TYPE_HAS_DEC(QuantizeCalcFactor, quantizecalc_param); + + GE_SERIALIZABLE(quantize_algo, scale_type, quantize_param, dequantize_param, requantize_param, quantizecalc_param, + has_mutable_quantize_param, has_mutable_dequantize_param, has_mutable_requantize_param, + has_mutable_quantizecalc_param); +}; + +#undef DEF_TYPE_DEC +} // namespace ge + +#endif // INC_GRAPH_DEF_TYPES_H_ diff --git a/metadef/inc/graph/detail/any_map.h b/metadef/inc/graph/detail/any_map.h new file mode 100644 index 00000000..8c839af4 --- /dev/null +++ b/metadef/inc/graph/detail/any_map.h @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_DETAIL_ANY_MAP_H_ +#define INC_GRAPH_DETAIL_ANY_MAP_H_ + +#include +#include +#include +#include + +#include "graph/compiler_options.h" + +namespace ge { +using std::shared_ptr; +using std::string; + +class TypeID { + public: + template + static TypeID Of() { + return TypeID(METADEF_FUNCTION_IDENTIFIER); + } + + ~TypeID() = default; + + bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; } + + private: + explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32 + + string type_; +}; + +class AnyMap { + public: + template + bool Set(const string &name, const DT &val); + + template + bool Get(const string &name, T &retValue) const; + + bool Has(const string &name) const { return anyValues_.find(name) != anyValues_.end(); } + + void Swap(AnyMap &other) { + anyValues_.swap(other.anyValues_); + } + + private: + class Placeholder { + public: + virtual ~Placeholder() = default; + + virtual const TypeID &GetTypeInfo() const = 0; + }; + + template + class Holder : public Placeholder { + public: + explicit Holder(const VT &value) : value_(value) {} + + ~Holder() override = default; + + const TypeID &GetTypeInfo() const override { + static const TypeID typeId = TypeID::Of(); + return typeId; + } + + const VT value_; + }; + + std::map> anyValues_; +}; + +template +bool AnyMap::Set(const string &name, const DT &val) { + auto it = anyValues_.find(name); + + std::shared_ptr> tmp; + try { + tmp = std::make_shared>(val); + } catch (std::bad_alloc &e) { + tmp = nullptr; + } catch (...) { + tmp = nullptr; + } + + if (it == anyValues_.end()) { + (void)anyValues_.emplace(name, tmp); + } else { + if (it->second && it->second->GetTypeInfo() == TypeID::Of
()) { + it->second = tmp; + } else { + return false; + } + } + return true; +} + +template +bool AnyMap::Get(const string &name, T &retValue) const { + auto it = anyValues_.find(name); + if (it != anyValues_.end() && it->second && it->second->GetTypeInfo() == TypeID::Of()) { + auto retPtr = std::static_pointer_cast>(it->second); + retValue = retPtr->value_; + return true; + } + return false; +} +} // namespace ge +#endif // INC_GRAPH_DETAIL_ANY_MAP_H_ diff --git a/metadef/inc/graph/detail/attributes_holder.h b/metadef/inc/graph/detail/attributes_holder.h new file mode 100644 index 00000000..0273ce99 --- /dev/null +++ b/metadef/inc/graph/detail/attributes_holder.h @@ -0,0 +1,165 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ +#define INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ + +#include +#include +#include +#include +#include +#include +#include "graph/detail/any_map.h" +#include "graph/ge_error_codes.h" +#include "graph/types.h" + +namespace google { +namespace protobuf { +class Message; +template +class Map; +} // namespace protobuf +} // namespace google + +namespace ge { +using std::string; +class GeAttrValue; + +namespace proto { +class AttrDef; +class TensorDef; +class TensorDescriptor; +class ShapeDef; +class NamedAttrs; +class ModelDef; +class OpDef; +class GraphDef; +} // namespace proto + +using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073 +using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>; + +template +class GeIrProtoHelper { + public: + GeIrProtoHelper(const ProtoMsgOwner &protoOwner, ProtoType *protoMsg) + : protoOwner_(protoOwner), protoMsg_(protoMsg) {} + + GeIrProtoHelper() { + protoOwner_ = std::shared_ptr<::google::protobuf::Message>(nullptr); + protoMsg_ = nullptr; + } + virtual ~GeIrProtoHelper() = default; + + template + GeIrProtoHelper(const GeIrProtoHelper &other) { + protoOwner_ = other.protoOwner_; + protoMsg_ = other.protoMsg_; + } + template + GeIrProtoHelper &operator=(const GeIrProtoHelper &other) { + protoOwner_ = other.protoOnwer_; + protoMsg_ = other.protoMsg_; + return *this; + } + void InitDefault(); + template + bool operator==(const GeIrProtoHelper &other) const { + return protoOwner_ == other.protoOwner_ && protoMsg_ == other.protoMsg_; + } + + inline const ProtoMsgOwner &GetProtoOwner() const { return protoOwner_; } + inline ProtoType *GetProtoMsg() const { return protoMsg_; } + void CopyValueFrom(const GeIrProtoHelper &other) { + if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) { + *protoMsg_ = *other.protoMsg_; + } + } + void MoveValueFrom(GeIrProtoHelper &&other) { + if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) { + *protoMsg_ = std::move(*other.protoMsg_); + } + } + + void Swap(GeIrProtoHelper &other) { + protoOwner_.swap(other.protoOwner_); + + ProtoType *temp = protoMsg_; + protoMsg_ = other.protoMsg_; + other.protoMsg_ = temp; + } + + // protoMsg_ is part of protoOwner_, they have the same runtime + ProtoMsgOwner protoOwner_ = nullptr; + ProtoType *protoMsg_ = nullptr; + friend class GeIrProtoHelper::value, typename std::remove_const::type, const ProtoType>::type>; +}; + +using ProtoAttrMapHelper = GeIrProtoHelper; +using ConstProtoAttrMapHelper = GeIrProtoHelper; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { + public: + AttrHolder() = default; + virtual ~AttrHolder() = default; + + graphStatus SetAttr(const string &name, const GeAttrValue &value); + + graphStatus GetAttr(const string &name, GeAttrValue &value) const; + + bool HasAttr(const string &name) const; + + graphStatus DelAttr(const string &name); + + void CopyAttrsFrom(const AttrHolder &holder); + + void Swap(AttrHolder &holder) { + requiredAttrs_.swap(holder.requiredAttrs_); + extAttrs_.Swap(holder.extAttrs_); + } + + template + bool SetExtAttr(const string &name, const T &value) { + return extAttrs_.Set(name, value); + } + template + T TryGetExtAttr(const string &name, T defaultValue) const { + T ret(defaultValue); + (void)extAttrs_.Get(name, ret); + return ret; + } + + protected: + graphStatus AddRequiredAttr(const std::string &name); + const std::unordered_set GetAllAttrNames() const; + const std::map GetAllAttrs() const; // lint !e1073 + + virtual ProtoAttrMapHelper MutableAttrMap() = 0; + virtual ConstProtoAttrMapHelper GetAttrMap() const = 0; + + friend class ModelSerializeImp; + friend class AttrUtils; + friend class AttrUtilsHelper; + + std::vector requiredAttrs_; + + private: + AnyMap extAttrs_; +}; +} // namespace ge +#endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ diff --git a/metadef/inc/graph/detail/model_serialize_imp.h b/metadef/inc/graph/detail/model_serialize_imp.h new file mode 100644 index 00000000..bc79c4c2 --- /dev/null +++ b/metadef/inc/graph/detail/model_serialize_imp.h @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ +#define INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ + +#include +#include +#include +#include +#include "graph/anchor.h" +#include "graph/detail/attributes_holder.h" +#include "graph/ge_tensor.h" +#include "graph/graph.h" +#include "graph/node.h" + +namespace ge { +using ComputeGraphPtr = std::shared_ptr; + +struct NodeNameGraphReq { + string node_name; + int32_t index; + ComputeGraphPtr graph; +}; + +struct NodeNameNodeReq { + string src_node_name; + int32_t src_out_index; + NodePtr dst_node; + int32_t dst_in_index; + string dst_node_name; +}; + +class ModelSerializeImp { + public: + bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false); + + bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false); + + bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); + + bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false); + + bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false); + + bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); + + bool UnserializeModel(Model &model, proto::ModelDef &modeProto); + + bool UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graphProto); + + bool UnserializeGraph(ComputeGraphPtr &graph, proto::GraphDef &graphProto); + + bool HandleNodeNameRef(); + + bool UnserializeOpDesc(OpDescPtr &opDesc, proto::OpDef &opDefProto); + void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_in, std::vector &key_out, + std::vector &value_in, std::vector &value_out, std::vector &opt); + void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto); + + bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &opDefProto); + + bool UnserializeTensor(GeTensorPtr &tensor, proto::TensorDef &tensorProto); + + bool ParseNodeIndex(const string &node_index, string &nodeName, int32_t &index); + + void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; } + + private: + bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map &subgraphs); + + std::vector graph_input_node_names_; + std::vector graph_output_node_names_; + std::vector node_input_node_names_; + std::map node_map_; + ProtoMsgOwner protobuf_owner_; +}; +} // namespace ge + +#endif // INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ diff --git a/metadef/inc/graph/ge_attr_value.h b/metadef/inc/graph/ge_attr_value.h new file mode 100644 index 00000000..c96cf591 --- /dev/null +++ b/metadef/inc/graph/ge_attr_value.h @@ -0,0 +1,344 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_GE_ATTR_VALUE_H_ +#define INC_GRAPH_GE_ATTR_VALUE_H_ + +#include +#include +#include +#include +#include +#include +#include "graph/buffer.h" +#include "detail/attributes_holder.h" +#include "graph/ge_error_codes.h" +#include "graph/ge_tensor.h" + +using std::map; +using std::string; +using std::vector; + +namespace ge { +class GeTensor; + +using GeTensorPtr = std::shared_ptr; +using ConstGeTensorPtr = std::shared_ptr; + +class ComputeGraph; +using ComputeGraphPtr = std::shared_ptr; +using ConstComputeGraphPtr = std::shared_ptr; + +class GeTensorDesc; +class GeAttrValue; +class GeAttrValueImp; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder { + public: + NamedAttrs(); + virtual ~NamedAttrs() = default; + void SetName(const std::string &name); + string GetName() const; + GeAttrValue GetItem(const string &key) const; + + protected: + ProtoAttrMapHelper MutableAttrMap() override; + ConstProtoAttrMapHelper GetAttrMap() const override; + + private: + // Create namedAttrs from protobuf obj + NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); + GeIrProtoHelper named_attrs_; + friend class GeAttrValueImp; + friend class GeAttrValue; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { + public: + using INT = int64_t; + using FLOAT = float; + using BOOL = bool; + using STR = std::string; + using TENSOR = GeTensorPtr; + using TENSOR_DESC = GeTensorDesc; + using GRAPH = ComputeGraphPtr; + using BYTES = Buffer; + using NAMED_ATTRS = ge::NamedAttrs; + using DATA_TYPE = ge::DataType; + + using LIST_INT = vector; + using LIST_FLOAT = vector; + using LIST_BOOL = vector; + using LIST_STR = vector; + using LIST_TENSOR = vector; + using LIST_TENSOR_DESC = vector; + using LIST_GRAPH = vector; + using LIST_BYTES = vector; + using LIST_NAMED_ATTRS = vector; + using LIST_LIST_INT = vector>; + using LIST_DATA_TYPE = vector; + + using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs). + + enum ValueType { + VT_NONE = 0, + VT_STRING, + VT_FLOAT, + VT_BOOL, + VT_INT, + VT_TENSOR_DESC, + VT_TENSOR, + VT_BYTES, + VT_GRAPH, + VT_NAMED_ATTRS, + VT_LIST_LIST_INT, + VT_DATA_TYPE, + + VT_LIST_BASE = 1000, + VT_LIST_STRING = VT_LIST_BASE + VT_STRING, + VT_LIST_FLOAT = VT_LIST_BASE + VT_FLOAT, + VT_LIST_BOOL = VT_LIST_BASE + VT_BOOL, + VT_LIST_INT = VT_LIST_BASE + VT_INT, + VT_LIST_TENSOR_DESC = VT_LIST_BASE + VT_TENSOR_DESC, + VT_LIST_TENSOR = VT_LIST_BASE + VT_TENSOR, + VT_LIST_BYTES = VT_LIST_BASE + VT_BYTES, + VT_LIST_GRAPH = VT_LIST_BASE + VT_GRAPH, + VT_LIST_NAMED_ATTRS = VT_LIST_BASE + VT_NAMED_ATTRS, + VT_LIST_DATA_TYPE = VT_LIST_BASE + VT_DATA_TYPE, + }; + + template + struct IsAttrTypeEnable { + using DT = typename std::remove_cv::type; + + static bool const VALUE = std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; + + // Not has list type of NamedAttrs + static bool const LIST_VALUE = std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value; + }; + + template + // To cols + using enable_if_vector_type_valid_t = typename std::enable_if::LIST_VALUE, + int>::type; + + template + using enable_if_one_type_valid_t = typename std::enable_if::VALUE, int>::type; + + template + using enable_if_type_valid_t = + typename std::enable_if::VALUE || IsAttrTypeEnable::LIST_VALUE, int>::type; + + template + using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; + + GeAttrValue(); + ~GeAttrValue() = default; + // SetValue, Set initializer_list + template = 0> + graphStatus SetValue(std::initializer_list
&&val) { + T vectorVal; + for (auto &item : val) { + vectorVal.push_back(item); + } + return SetValue(vectorVal); + } + + // SetValue, Set vector + template = 0> + graphStatus SetValue(const std::vector
&val) { + T vectorVal; + for (auto item : val) { + vectorVal.push_back(item); + } + return SetValue(vectorVal); + } + + // SetValue, not list type + template = 0> + graphStatus SetValue(DT &&val) { + return SetValue(T(std::forward
(val))); + } + + // GE_SERIALIZABLE + template = 0> + graphStatus SetValue(const T &t) { + return t.Save(*this); + } + + template = 0> + graphStatus SetValue(const vector &t) { + vector attrs; + for (auto &item : t) { + GeAttrValue val; + item.Save(val); + NamedAttrs attrsItem; + (void)val.GetValue(attrsItem); + attrs.push_back(attrsItem); + } + return SetValue(attrs); + } + + // GetValue, list value + template = 0, + typename std::enable_if::value, int>::type = 0> + graphStatus GetValue(std::vector
&val) const { + T valGet; + val.clear(); + auto status = GetValue(valGet); + if (status != GRAPH_SUCCESS) { + return status; + } + for (auto item : valGet) { + val.push_back(item); + } + return GRAPH_SUCCESS; + } + + // GetValue, not list type + template = 0, + typename std::enable_if::value, int>::type = 0> + graphStatus GetValue(DT &val) const { + T valGet; + auto status = GetValue(valGet); + if (status != GRAPH_SUCCESS) { + return status; + } + val = DT(valGet); + return GRAPH_SUCCESS; + } + + // GE_SERIALIZABLE + template = 0> + graphStatus GetValue(T &t) { + return t.Load(*this); + } + + template = 0> + graphStatus GetValue(vector &t) { + graphStatus status; + t.clear(); + vector attrs; + status = this->GetValue(attrs); + if (status != GRAPH_SUCCESS) { + return status; + } + for (auto &attr : attrs) { + T item; + GeAttrValue val; + (void)val.SetValue(attr); + status = item.Load(val); + if (status != GRAPH_SUCCESS) { + return status; + } + t.push_back(item); + } + return GRAPH_SUCCESS; + } + + template = 0> + static GeAttrValue CreateFrom(DT &&val) { + GeAttrValue valRet; + (void)valRet.SetValue(std::forward
(val)); + return valRet; + } + + template = 0> + static GeAttrValue CreateFrom(std::initializer_list
&&val) { + GeAttrValue valRet; + (void)valRet.SetValue(std::move(val)); + return valRet; + } + + template = 0> + static GeAttrValue CreateFrom(const T &val) { + GeAttrValue valRet; + (void)valRet.SetValue(val); + return valRet; + } + + template = 0> + static GeAttrValue CreateFrom(const vector &val) { + GeAttrValue valRet; + (void)valRet.SetValue(val); + return valRet; + } + + ValueType GetValueType() const; + + bool IsEmpty() const; + + GeAttrValue Copy() const; + + // For map key + bool operator==(const GeAttrValue &other) const { return value_ == other.value_; } + + graphStatus MutableTensor(GeTensorPtr &tensor); + graphStatus MutableListTensor(vector &list_tensor); + + private: +#define VALUE_SET_GET_DEC(DT) \ + graphStatus SetValue(const DT &val); \ + graphStatus GetValue(DT &val) const; + VALUE_SET_GET_DEC(GeAttrValue::STR) + VALUE_SET_GET_DEC(GeAttrValue::INT) + VALUE_SET_GET_DEC(GeAttrValue::FLOAT) + VALUE_SET_GET_DEC(GeAttrValue::BOOL) + VALUE_SET_GET_DEC(GeTensorDesc) + VALUE_SET_GET_DEC(GeAttrValue::TENSOR) + VALUE_SET_GET_DEC(GeAttrValue::GRAPH) + VALUE_SET_GET_DEC(BYTES) + VALUE_SET_GET_DEC(NamedAttrs) + VALUE_SET_GET_DEC(ge::DataType) // lint !e665 + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector>) //lint !e665 + VALUE_SET_GET_DEC(vector) //lint !e665 +#undef VALUE_SET_GET_DEC + + GeIrProtoHelper value_; + GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val); + + friend class AttrHolder; + friend class ModelSerializeImp; + friend class OnnxUtils; +}; + +class AttrValueImpl { + public: + AttrValueImpl() = default; + ~AttrValueImpl() = default; + + GeAttrValue geAttrValue_; +}; +} // namespace ge +#endif // INC_GRAPH_GE_ATTR_VALUE_H_ diff --git a/metadef/inc/graph/ge_context.h b/metadef/inc/graph/ge_context.h new file mode 100644 index 00000000..3f3d0a8e --- /dev/null +++ b/metadef/inc/graph/ge_context.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INC_GRAPH_GE_CONTEXT_H_ +#define INC_GRAPH_GE_CONTEXT_H_ + +#include +#include "graph/ge_error_codes.h" + +namespace ge { +class GEContext { + public: + graphStatus GetOption(const std::string &key, std::string &option); + bool GetHostExecFlag(); + uint64_t SessionId(); + uint32_t DeviceId(); + uint64_t TraceId(); + void Init(); + void SetSessionId(uint64_t session_id); + void SetCtxDeviceId(uint32_t device_id); + private: + thread_local static uint64_t session_id_; + uint32_t device_id_ = 0; + uint64_t trace_id_ = 0; +}; // class GEContext + +/// Get context +/// @return +GEContext &GetContext(); +} // namespace ge + +#endif // INC_GRAPH_GE_CONTEXT_H_ diff --git a/metadef/inc/graph/ge_global_options.h b/metadef/inc/graph/ge_global_options.h new file mode 100644 index 00000000..0abf391e --- /dev/null +++ b/metadef/inc/graph/ge_global_options.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INC_GRAPH_GE_GLOBAL_OPTIONS_H_ +#define INC_GRAPH_GE_GLOBAL_OPTIONS_H_ + +#include +#include + +namespace ge { +std::map &GetMutableGlobalOptions(); +} +#endif // INC_GRAPH_GE_GLOBAL_OPTIONS_H_ diff --git a/metadef/inc/graph/ge_local_context.h b/metadef/inc/graph/ge_local_context.h new file mode 100644 index 00000000..a691ebde --- /dev/null +++ b/metadef/inc/graph/ge_local_context.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef INC_GRAPH_GE_LOCAL_CONTEXT_H_ +#define INC_GRAPH_GE_LOCAL_CONTEXT_H_ + +#include +#include +#include +#include "graph/ge_error_codes.h" + +using std::string; +using std::map; + +namespace ge { +class GEThreadLocalContext { + public: + graphStatus GetOption(const string &key, string &option); + void SetGraphOption(map options_map); + void SetSessionOption(map options_map); + void SetGlobalOption(map options_map); + + map GetAllGraphOptions() const; + map GetAllSessionOptions() const; + map GetAllGlobalOptions() const; + map GetAllOptions() const; + + private: + map graph_options_; + map session_options_; + map global_options_; +}; // class GEThreadLocalContext + +GEThreadLocalContext &GetThreadLocalContext(); +} // namespace ge +#endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ diff --git a/metadef/inc/graph/ge_tensor.h b/metadef/inc/graph/ge_tensor.h new file mode 100644 index 00000000..a7688a47 --- /dev/null +++ b/metadef/inc/graph/ge_tensor.h @@ -0,0 +1,193 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_GE_TENSOR_H_ +#define INC_GRAPH_GE_TENSOR_H_ + +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/buffer.h" +#include "graph/ge_error_codes.h" +#include "graph/types.h" + +namespace ge { +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { + public: + GeShape(); + ~GeShape() = default; + explicit GeShape(std::vector s); + + size_t GetDimNum() const; + // If the idx is invalid, return 0 + int64_t GetDim(size_t idx) const; + graphStatus SetDim(size_t idx, int64_t value); + std::vector GetDims() const; + + int64_t GetShapeSize() const; + std::string ToString() const; + + /// + /// @brief Check is unknown shape + /// @return bool + /// + bool IsUnknownShape() const; + + /// + /// @brief Check is a scalar + /// @return bool + /// + bool IsScalar() const; + + GeShape(const GeShape &other); + GeShape(GeShape &&other); + GeShape &operator=(const GeShape &other); + GeShape &operator=(GeShape &&other); + + private: + GeIrProtoHelper shape_def_; + friend class GeTensorDesc; + // Create from proto obj + GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); + + void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder { + friend class TensorUtils; + friend class GeAttrValue; + friend class ModelSerialize; + + public: + GeTensorDesc(); + explicit GeTensorDesc(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); + GeTensorDesc(const GeTensorDesc &desc); + GeTensorDesc(GeTensorDesc &&desc); + + ~GeTensorDesc() = default; + bool operator==(const GeTensorDesc &r_ge_tensor_desc) const; + + void Update(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); + + GeShape GetShape() const; + GeShape &MutableShape(); + void SetShape(GeShape shape); + + // set shape with -2, it stand for unknown shape + void SetUnknownDimNumShape(); + // for unknown shape + graphStatus SetShapeRange(const std::vector> &range); + graphStatus GetShapeRange(std::vector> &range) const; + + GeShape GetOriginShape() const; + void SetOriginShape(const GeShape &originShape); + + Format GetFormat() const; + void SetFormat(Format format); + + Format GetOriginFormat() const; + void SetOriginFormat(Format originFormat); + + void SetName(const std::string &name); + const std::string GetName() const; + + DataType GetDataType() const; + void SetDataType(DataType dt); + + DataType GetOriginDataType() const; + void SetOriginDataType(DataType originDataType); + + std::vector GetRefPortIndex() const; + void SetRefPortByIndex(const std::vector &index); + + GeTensorDesc Clone() const; + GeTensorDesc &operator=(const GeTensorDesc &desc); + GeTensorDesc &operator=(GeTensorDesc &&desc); + + graphStatus IsValid() const; + + protected: + ProtoAttrMapHelper MutableAttrMap() override; + ConstProtoAttrMapHelper GetAttrMap() const override; + + private: + bool GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const; + using AttrHolder::DelAttr; + using AttrHolder::GetAllAttrs; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + void Init(); + + // Create from proto obj + GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); + friend class GeTensor; + friend class GeAttrValueImp; + friend class ModelSerializeImp; + friend class OnnxUtils; + + GeIrProtoHelper tensor_descriptor_; + // Reference from tensorDescriptor_, do not direct use + mutable GeShape __shape_; + + void RefTo(const GeTensorDesc &tensorDesc) { tensor_descriptor_ = tensorDesc.tensor_descriptor_; } + GeShape &ShapeReference() const; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { + public: + GeTensor(); + explicit GeTensor(const GeTensorDesc &tensorDesc); + explicit GeTensor(const GeTensorDesc &tensorDesc, const std::vector &data); + explicit GeTensor(const GeTensorDesc &tensorDesc, const Buffer &data); + explicit GeTensor(const GeTensorDesc &tensorDesc, const uint8_t *data, size_t size); + explicit GeTensor(GeTensorDesc &&tensorDesc, std::vector &&data); + ~GeTensor() = default; + + GeTensorDesc GetTensorDesc() const; + GeTensorDesc &MutableTensorDesc(); + void SetTensorDesc(const GeTensorDesc &tensorDesc); + + const Buffer GetData() const; + Buffer MutableData(); + graphStatus SetData(std::vector &&data); + graphStatus SetData(const std::vector &data); + graphStatus SetData(const Buffer &data); + graphStatus SetData(const uint8_t *data, size_t size); + + GeTensor Clone() const; + + // Share value + GeTensor(const GeTensor &other); + // Share value + GeTensor &operator=(const GeTensor &other); + + private: + friend class GeAttrValueImp; + friend class ModelSerializeImp; + friend class OnnxUtils; + // Create from proto obj + GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); + GeIrProtoHelper tensor_def_; + // Reference from tensorDef_, do not direct use + mutable GeTensorDesc __desc_; + GeTensorDesc &DescReference() const; +}; +} // namespace ge +#endif // INC_GRAPH_GE_TENSOR_H_ diff --git a/metadef/inc/graph/graph_util.h b/metadef/inc/graph/graph_util.h new file mode 100644 index 00000000..c39ecbc1 --- /dev/null +++ b/metadef/inc/graph/graph_util.h @@ -0,0 +1,134 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_GRAPH_UTIL_H_ +#define INC_GRAPH_GRAPH_UTIL_H_ + +#include + +#include "proto/om.pb.h" + +namespace ge { +using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; +bool HasOpAttr(const OpDef *opdef, std::string attr_name); +bool GetOpAttr(const std::string &key, int32_t *value, const OpDef *opdef); + +static const char OP_TYPE_DATA[] = "Data"; +static const char OP_TYPE_INPUT[] = "Input"; +static const char ATTR_KEY_INPUT_FORMAT[] = "input_format"; +static const char ATTR_KEY_OUTPUT_FORMAT[] = "output_format"; +static const char OP_TYPE_ANN_DATA[] = "AnnData"; +} // namespace ge + +#if !defined(__ANDROID__) && !defined(ANDROID) +#include "toolchain/slog.h" +const char levelStr[4][8] = {"ERROR", "WARN", "INFO", "DEBUG"}; +#else +#include +#include +const char levelStr[8][8] = {"EMERG", "ALERT", "CRIT", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"}; +#endif + +#ifdef _MSC_VER +#define FUNC_NAME __FUNCTION__ +#else +#define FUNC_NAME __PRETTY_FUNCTION__ +#endif + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) \ + dlog_info(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) +#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) \ + dlog_warn(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) +#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) \ + dlog_error(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) +#else +#define D_GRAPH_LOG(level, format, ...) \ + do { \ + { \ + fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ + __FILE__, __LINE__, ##__VA_ARGS__); \ + syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ + ##__VA_ARGS__); \ + } \ + } while (0) +#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#endif + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define GRAPH_LOGI(...) D_GRAPH_LOGI(GRAPH_MOD_NAME, __VA_ARGS__) +#define GRAPH_LOGW(...) D_GRAPH_LOGW(GRAPH_MOD_NAME, __VA_ARGS__) +#define GRAPH_LOGE(...) D_GRAPH_LOGE(GRAPH_MOD_NAME, __VA_ARGS__) +#else + +#define GRAPH_LOG(level, format, ...) \ + do { \ + { \ + fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ + __FILE__, __LINE__, ##__VA_ARGS__); \ + syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ + ##__VA_ARGS__); \ + } \ + } while (0) +#define GRAPH_LOGI(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define GRAPH_LOGW(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define GRAPH_LOGE(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#endif + +#define GRAPH_CHK_STATUS_RET_NOLOG(expr) \ + do { \ + const domi::graphStatus _status = (expr); \ + if (_status != domi::GRAPH_SUCCESS) { \ + return _status; \ + } \ + } while (0) + +#define GRAPH_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GRAPH_LOGE(__VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +#define GRAPH_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ + { \ + bool b = (expr); \ + if (!b) { \ + exec_expr; \ + } \ + }; + +#define GRAPH_IF_BOOL_EXEC(expr, exec_expr) \ + { \ + if (expr) { \ + exec_expr; \ + } \ + } + +#define GRAPH_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ + do { \ + const ::domi::graphStatus _status = (expr); \ + if (_status) { \ + GRAPH_LOGE(__VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +#endif // INC_GRAPH_GRAPH_UTIL_H_ diff --git a/metadef/inc/graph/model.h b/metadef/inc/graph/model.h new file mode 100644 index 00000000..9beb5578 --- /dev/null +++ b/metadef/inc/graph/model.h @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_MODEL_H_ +#define INC_GRAPH_MODEL_H_ + +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/ge_attr_value.h" +#include "graph/graph.h" + +namespace ge { +using std::map; +using std::string; +using std::vector; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { + public: + Model(); + + ~Model() = default; + + Model(const string &name, const string &custom_version); + + string GetName() const; + void SetName(const string &name); + + uint32_t GetVersion() const; + + void SetVersion(uint32_t version) { version_ = version; } + + std::string GetPlatformVersion() const; + + void SetPlatformVersion(string version) { platform_version_ = version; } + + Graph GetGraph() const; + + void SetGraph(const Graph &graph); + + void SetAttr(const ProtoAttrMapHelper &attrs); + + using AttrHolder::GetAllAttrNames; + using AttrHolder::GetAllAttrs; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + graphStatus Save(Buffer &buffer, bool is_dump = false) const; + + graphStatus SaveToFile(const string& file_name) const; + // Model will be rewrite + static graphStatus Load(const uint8_t *data, size_t len, Model &model); + graphStatus Load(ge::proto::ModelDef &model_def); + graphStatus LoadFromFile(const string& file_name); + + bool IsValid() const; + + protected: + ConstProtoAttrMapHelper GetAttrMap() const override; + ProtoAttrMapHelper MutableAttrMap() override; + + private: + void Init(); + ProtoAttrMapHelper attrs_; + friend class ModelSerializeImp; + friend class GraphDebugImp; + friend class OnnxUtils; + friend class ModelHelper; + friend class ModelBuilder; + string name_; + uint32_t version_; + std::string platform_version_{""}; + Graph graph_; +}; +} // namespace ge +using ModelPtr = std::shared_ptr; + +#endif // INC_GRAPH_MODEL_H_ diff --git a/metadef/inc/graph/model_serialize.h b/metadef/inc/graph/model_serialize.h new file mode 100644 index 00000000..a23039c9 --- /dev/null +++ b/metadef/inc/graph/model_serialize.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_MODEL_SERIALIZE_H_ +#define INC_GRAPH_MODEL_SERIALIZE_H_ + +#include +#include +#include "graph/buffer.h" +#include "graph/compute_graph.h" +#include "graph/model.h" + +namespace ge { +class ModelSerialize { + public: + Buffer SerializeModel(const Model &model, bool is_dump = false); + + Model UnserializeModel(const uint8_t *data, size_t len); + Model UnserializeModel(ge::proto::ModelDef &model_def); + + Buffer SerializeGraph(const ComputeGraphPtr &graph); + + ComputeGraphPtr UnserializeGraph(const uint8_t *data, size_t len); + + Buffer SerializeOpDesc(const ConstOpDescPtr &opDesc); + OpDescPtr UnserializeOpDesc(const uint8_t *data, size_t len); + + size_t GetSerializeModelSize(const Model &model); + + private: + static std::map &MutableTensorDescAttrMap(GeTensorDesc &tensorDesc); + + static const std::map &GetTensorDescAttrMap(const GeTensorDesc &tensorDesc); + + friend class ModelSerializeImp; + friend class GraphDebugImp; +}; +} // namespace ge +#endif // INC_GRAPH_MODEL_SERIALIZE_H_ diff --git a/metadef/inc/graph/node.h b/metadef/inc/graph/node.h new file mode 100644 index 00000000..467e79d7 --- /dev/null +++ b/metadef/inc/graph/node.h @@ -0,0 +1,214 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_NODE_H_ +#define INC_GRAPH_NODE_H_ + +#include +#include +#include +#include +#include +#include +#include "graph/ge_attr_value.h" +#include "utils/attr_utils.h" + +#include "graph/op_desc.h" +#include "graph/range_vistor.h" + +namespace ge { +class ComputeGraph; + +using ComputeGraphPtr = std::shared_ptr; + +class Node; + +using NodePtr = std::shared_ptr; +using ConstNodePtr = std::shared_ptr; +using NodeRef = std::weak_ptr; + +class Anchor; + +using AnchorPtr = std::shared_ptr; + +class InDataAnchor; + +using InDataAnchorPtr = std::shared_ptr; + +class OutDataAnchor; + +using OutDataAnchorPtr = std::shared_ptr; + +class ControlAnchor; + +using ControlAnchorPtr = std::shared_ptr; + +class InControlAnchor; + +using InControlAnchorPtr = std::shared_ptr; + +class OutControlAnchor; + +using OutControlAnchorPtr = std::shared_ptr; + +using OpDescPtr = std::shared_ptr; + +using ConstNode = const Node; + +typedef std::vector> kFusionDataFlowVec_t; + +// Node is a component of ComputeGraph +class Node : public std::enable_shared_from_this { + friend class ComputeGraph; + friend class ModelSerializeImp; + + public: + template + using Vistor = RangeVistor>; + ~Node(); + Node(const Node &) = delete; + Node &operator=(const Node &) = delete; + bool operator==(const Node &r_node) const; + + protected: + Node() = default; + Node(const OpDescPtr &op, const ComputeGraphPtr &ownerGraph); + + public: + graphStatus Init(); + + std::string GetName() const; + std::string GetType() const; + + ComputeGraphPtr GetOwnerComputeGraph() const; + graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph); + graphStatus SetAnyOwnerComputeGraph(const ComputeGraphPtr &graph); + + Vistor GetAllInDataAnchors() const; + Vistor GetAllOutDataAnchors() const; + uint32_t GetAllInDataAnchorsSize() const; + uint32_t GetAllOutDataAnchorsSize() const; + Vistor GetAllOutAnchors() const; + Vistor GetAllInAnchors() const; + InDataAnchorPtr GetInDataAnchor(int idx) const; + OutDataAnchorPtr GetOutDataAnchor(int idx) const; + InControlAnchorPtr GetInControlAnchor() const; + OutControlAnchorPtr GetOutControlAnchor() const; + Vistor GetInNodes() const; + Vistor GetOutNodes() const; + AnchorPtr GetInAnchor(int idx) const; + AnchorPtr GetOutAnchor(int idx) const; + + bool IsAllInNodesSeen(std::unordered_set &nodes_seen) const; + + // All in Data nodes + Vistor GetInDataNodes() const; + // All in Control nodes + Vistor GetInControlNodes() const; + // All in Data nodes and Control nodes + Vistor GetInAllNodes() const; + + // All out Data nodes + Vistor GetOutDataNodes() const; + uint32_t GetOutDataNodesSize() const; + // All out Control nodes + Vistor GetOutControlNodes() const; + // All out Data nodes and Control nodes + Vistor GetOutAllNodes() const; + + // Get all in data nodes and its out-anchor + Vistor> GetInDataNodesAndAnchors() const; + + // Get all out data nodes and its in-anchor + Vistor> GetOutDataNodesAndAnchors() const; + + graphStatus InferShapeAndType() const; + graphStatus Verify() const; + + graphStatus InferOriginFormat() const; + + OpDescPtr GetOpDesc() const; + + graphStatus UpdateOpDesc(const OpDescPtr &op); + + graphStatus AddLinkFrom(const NodePtr &input_node); + + graphStatus AddLinkFrom(const uint32_t &index, NodePtr input_node); + + graphStatus AddLinkFrom(const string &name, NodePtr input_node); + + graphStatus AddLinkFromForParse(const NodePtr &input_node); + + void AddSendEventId(uint32_t event_id) { send_event_id_list_.push_back(event_id); } + + void AddRecvEventId(uint32_t event_id) { recv_event_id_list_.push_back(event_id); } + + const std::vector &GetSendEventIdList() const { return send_event_id_list_; } + + const std::vector &GetRecvEventIdList() const { return recv_event_id_list_; } + void GetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { + fusion_input_list = fusion_input_dataflow_list_; + } + + void GetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { + fusion_output_list = fusion_output_dataflow_list_; + } + + void SetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { + fusion_input_dataflow_list_ = fusion_input_list; + } + + void SetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { + fusion_output_dataflow_list_ = fusion_output_list; + } + + bool GetHostNode() const { return host_node_; } + void SetHostNode(bool is_host) { host_node_ = is_host; } + + void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } + + NodePtr GetOrigNode() { return orig_node_; } + + private: + bool NodeMembersAreEqual(const Node &r_node) const; + bool NodeAttrsAreEqual(const Node &r_node) const; + bool NodeInConnectsAreEqual(const Node &r_node) const; + bool NodeOutConnectsAreEqual(const Node &r_node) const; + bool NodeAnchorIsEqual(const AnchorPtr &l_anchor, const AnchorPtr &r_anchor, size_t i) const; + OpDescPtr op_; + std::weak_ptr owner_graph_; + vector in_data_anchors_; + vector out_data_anchors_; + InControlAnchorPtr in_control_anchor_; + OutControlAnchorPtr out_control_anchor_; + map attrs_; // lint !e1073 + bool has_init_{false}; + bool host_node_{false}; + bool anchor_status_updated_{false}; + std::vector send_event_id_list_; + std::vector recv_event_id_list_; + + kFusionDataFlowVec_t fusion_input_dataflow_list_; + kFusionDataFlowVec_t fusion_output_dataflow_list_; + + NodePtr orig_node_; + friend class NodeUtils; + friend class OnnxUtils; + friend class TuningUtils; +}; +} // namespace ge + +#endif // INC_GRAPH_NODE_H_ diff --git a/metadef/inc/graph/op_desc.h b/metadef/inc/graph/op_desc.h new file mode 100644 index 00000000..a86adf43 --- /dev/null +++ b/metadef/inc/graph/op_desc.h @@ -0,0 +1,340 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_OP_DESC_H_ +#define INC_GRAPH_OP_DESC_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/range_vistor.h" + +#define DYNAMIN_INPUT_NAME(name, index) (((name)) + std::to_string((index))) +#define DYNAMIN_OUTPUT_NAME(name, index) (((name)) + std::to_string((index))) +namespace ge { +using std::map; +using std::pair; +using std::shared_ptr; +using std::string; +using std::vector; + +class Operator; +class GeTensorDesc; + +using GeTensorDescPtr = shared_ptr; +using ConstGeTensorDescPtr = shared_ptr; + +class OpDesc; + +using OpDescPtr = shared_ptr; +using ConstOpDescPtr = shared_ptr; + +class GeAttrValue; + +using ConstOpDesc = const OpDesc; + +enum SubgraphType { + kStatic, + kDynamic, + kSubgraphTypeEnd +}; + +class OpDesc : public std::enable_shared_from_this, public AttrHolder { + public: + template + using Vistor = RangeVistor>; + + friend class GraphBuilderImpl; + + friend class OperatorImpl; + + OpDesc(const string &name, const string &type); + + OpDesc(); + + ~OpDesc(); + + bool operator==(const OpDesc &r_op_desc) const; + + string GetName() const; + + void SetName(const string &name); + + string GetType() const; + + void SetType(const string &type); + + graphStatus AddInputDesc(const GeTensorDesc &input_desc); + + graphStatus AddInputDesc(const string &name, const GeTensorDesc &input_desc); + + graphStatus AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc); + + graphStatus AddInputDescForward(const string &name, const unsigned int num); + + graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); + + graphStatus AddOutputDescMiddle(const string &name, const unsigned int num, size_t index); + + graphStatus AddOutputDescForward(const string &name, const unsigned int num); + + graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); + + graphStatus UpdateInputDesc(uint32_t index, const GeTensorDesc &tensor_desc); + + graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc); + + bool InputIsSet(const string &name) const; + + GeTensorDesc GetInputDesc(uint32_t index) const; + + GeTensorDesc GetInputDesc(const string &name) const; + + Vistor GetAllInputNames() const; + + GeTensorDescPtr MutableInputDesc(uint32_t index) const; + + GeTensorDescPtr MutableInputDesc(const string &name) const; + + Vistor GetAllInputsDesc() const; + + Vistor GetAllInputsDescPtr() const; + + size_t GetInputsSize() const; + + size_t GetAllInputsSize() const; + + graphStatus AddOutputDesc(const GeTensorDesc &output_desc); + + graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); + + graphStatus UpdateOutputDesc(uint32_t index, const GeTensorDesc &tensor_desc); + + graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc); + + GeTensorDesc GetOutputDesc(uint32_t index) const; + + GeTensorDesc GetOutputDesc(const string &name) const; + + GeTensorDescPtr MutableOutputDesc(uint32_t index) const; + + GeTensorDescPtr MutableOutputDesc(const string &name) const; + + uint32_t GetAllOutputsDescSize() const; + + Vistor GetAllOutputsDesc() const; + + Vistor GetAllOutputsDescPtr() const; + + size_t GetOutputsSize() const; + + ConstGeTensorDescPtr GetOutputDescPtr(uint32_t index) const; + + ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; + + ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const; + + ConstGeTensorDescPtr GetInputDescPtr(const string &name) const; + + graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); + + graphStatus AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index); + + graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); + + bool IsOptionalInput(const string &name) const; + + bool IsOptionalInput(uint32_t index) const; + + std::map GetAllInputName() const; + + std::map GetAllOutputName(); + + std::map& MutableAllInputName(); + + std::map& MutableAllOutputName(); + + bool UpdateInputName(std::map inputNameIdx); + + bool UpdateOutputName(std::map outputNameIdx); + + void AddInferFunc(const std::function &func); + + std::function GetInferFunc() const; + + graphStatus InferShapeAndType(); + + void AddInferFormatFunc(const std::function &func); + + std::function GetInferFormatFunc() const; + + graphStatus DefaultInferFormat(); + + std::function GetVerifyFunc() const; + + void AddVerifierFunc(const std::function &func); + + graphStatus CallInferFormatFunc(Operator &op); + + graphStatus OpVerify(); + + graphStatus CommonVerify() const; + + graphStatus AddRegisterInputName(const string &name); + + graphStatus AddRegisterOutputName(const string &name); + + vector GetRegisterInputName() const; + + vector GetRegisterOutputName() const; + + using AttrHolder::AddRequiredAttr; + using AttrHolder::DelAttr; + using AttrHolder::GetAllAttrNames; + using AttrHolder::GetAllAttrs; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + void SetId(int64_t id); + int64_t GetId() const; + void SetStreamId(int64_t stream_id); + int64_t GetStreamId() const; + void SetInputName(const vector &input_name); + vector GetInputName() const; + void SetSrcName(const vector &src_name); + vector GetSrcName() const; + void SetSrcIndex(const vector &src_index); + vector GetSrcIndex() const; + void SetInputOffset(const vector &input); + vector GetInputOffset() const; + void SetOutputOffset(const vector &input); + vector GetOutputOffset() const; + void SetDstName(const vector &dst_name); + vector GetDstName() const; + void SetDstIndex(const vector &dst_index); + vector GetDstIndex() const; + void SetWorkspace(const vector &workspace); + vector GetWorkspace() const; + void SetWorkspaceBytes(const vector &workspace_bytes); + vector GetWorkspaceBytes() const; + void SetIsInputConst(const vector &is_input_const); + vector GetIsInputConst() const; + + void SetOpInferDepends(const vector &depend_names); + vector GetOpInferDepends() const; + + string GetInputNameByIndex(uint32_t index) const; + string GetValidInputNameByIndex(uint32_t index) const; + int GetValidInputIndexByName(const string &name) const; + int GetInputIndexByName(const string &name) const; + + string GetOutputNameByIndex(uint32_t index) const; + + int GetOutputIndexByName(const string &name) const; + + graphStatus RestoreInputNameIdx(const string &name, const int &index); + + graphStatus RestoreOutputNameIdx(const string &name, const int &index); + + graphStatus CallInferFunc(Operator &op); + + void SetOpKernelLibName(const std::string &name); + + std::string GetOpKernelLibName() const; + + void SetOpEngineName(const std::string &name); + + std::string GetOpEngineName() const; + + void RegisterSubgraphIrName(const std::string &name, SubgraphType type); + const std::map &GetSubgraphIrNames() const; + SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; + + graphStatus AddSubgraphName(const std::string &name); + const std::map &GetSubgraphNameIndexes() const; + + std::string GetSubgraphInstanceName(uint32_t index) const; + const std::vector &GetSubgraphInstanceNames() const; + /// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, + /// because this kind of functions will only append a new subgraph instance name + /// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. + /// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. + /// \param index + /// \param name + /// \return + graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); + void RemoveSubgraphInstanceName(const std::string &name); + + graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const; + + graphStatus InferDataSlice(); + + protected: + ProtoAttrMapHelper MutableAttrMap() override; + ConstProtoAttrMapHelper GetAttrMap() const override; + + private: + OpDesc(const ProtoMsgOwner &proto_msg_owner, ge::proto::OpDef *op_def); + bool OpDescMembersAreEqual(const OpDesc &r_op_desc) const; + bool OpDescAttrsAreEqual(const OpDesc &r_op_desc) const; + bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; + + GeIrProtoHelper op_def_; + std::vector subgraph_instance_names_; + + // subgraph names to index, for a `if` operator: + // then_branch: 0 + // else_branch: 1 + // or for a `case` node: + // branches0: 0 + // branches1: 1 + // branches2: 2 + std::map subgraph_names_to_index_; + + // subgraph ir names to type, for a `if` operator: + // then_branch: static + // else_branch: static + // or for a `case` op: + // branches: dynamic + std::map subgraph_ir_names_to_type_; + + vector inputs_desc_{}; + map input_name_idx_{}; + vector register_input_name_{}; + std::unordered_set optional_input_names_{}; + vector outputs_desc_{}; + map output_name_idx_{}; + vector register_output_name_{}; + std::function infer_func_ = nullptr; + std::function infer_format_func_ = nullptr; + std::function verifier_func_ = nullptr; + std::function infer_data_slice_func_ = nullptr; + string op_kernel_lib_name_; + string engine_name_; + friend class OpDescUtils; + friend class ModelSerializeImp; + friend class AttrUtils; + friend class GeAttrValueImp; + friend class OnnxUtils; +}; +} // namespace ge +#endif // INC_GRAPH_OP_DESC_H_ diff --git a/metadef/inc/graph/op_kernel_bin.h b/metadef/inc/graph/op_kernel_bin.h new file mode 100644 index 00000000..61a52730 --- /dev/null +++ b/metadef/inc/graph/op_kernel_bin.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_OP_KERNEL_BIN_H_ +#define INC_GRAPH_OP_KERNEL_BIN_H_ + +#include +#include +#include +#include + +namespace ge { +class OpKernelBin { + public: + OpKernelBin(std::string name, std::vector &&data) : name_(std::move(name)), data_(std::move(data)) {} + + ~OpKernelBin() = default; + + const std::string &GetName() const { return name_; } + const uint8_t *GetBinData() const { return (const uint8_t *)data_.data(); } + size_t GetBinDataSize() const { return data_.size(); } + OpKernelBin(const OpKernelBin &) = delete; + const OpKernelBin &operator=(const OpKernelBin &) = delete; + + private: + std::string name_; + std::vector data_; +}; + +using OpKernelBinPtr = std::shared_ptr; +const char *const OP_EXTATTR_NAME_TBE_KERNEL = "tbeKernel"; +const char *const OP_EXTATTR_CUSTAICPU_KERNEL = "cust_aicpu_kernel"; +} // namespace ge + +#endif // INC_GRAPH_OP_KERNEL_BIN_H_ diff --git a/metadef/inc/graph/operator_factory_impl.h b/metadef/inc/graph/operator_factory_impl.h new file mode 100644 index 00000000..7febb8b6 --- /dev/null +++ b/metadef/inc/graph/operator_factory_impl.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ +#define INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ + +#include +#include +#include +#include +#include "graph/operator_factory.h" +#include "register/infer_data_slice_registry.h" + +namespace ge { +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl { + public: + static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); + + static graphStatus GetOpsTypeList(std::vector &all_ops); + + static bool IsExistOp(const string &operator_type); + + static InferShapeFunc GetInferShapeFunc(const std::string &operator_type); + + static InferFormatFunc GetInferFormatFunc(const std::string &operator_type); + + static VerifyFunc GetVerifyFunc(const std::string &operator_type); + + static InferDataSliceFunc GetInferDataSliceFunc(const std::string &operator_type); + + static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreator const &op_creator); + + static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreatorV2 const &op_creator); + + static graphStatus RegisterInferShapeFunc(const std::string &operator_type, InferShapeFunc const infer_shape_func); + + static graphStatus RegisterInferFormatFunc(const std::string &operator_type, InferFormatFunc const infer_format_func); + + static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); + + static graphStatus RegisterInferDataSliceFunc(const std::string &operator_type, + InferDataSliceFunc const infer_data_slice_func); + + static shared_ptr> operator_creators_; + static shared_ptr> operator_creators_v2_; + static shared_ptr> operator_infershape_funcs_; + static shared_ptr> operator_inferformat_funcs_; + static shared_ptr> operator_verify_funcs_; + static shared_ptr> operator_infer_data_slice_funcs_; +}; +} // namespace ge + +#endif // INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ diff --git a/metadef/inc/graph/opsproto_manager.h b/metadef/inc/graph/opsproto_manager.h new file mode 100644 index 00000000..932de91a --- /dev/null +++ b/metadef/inc/graph/opsproto_manager.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_OPSPROTO_MANAGER_H_ +#define INC_GRAPH_OPSPROTO_MANAGER_H_ + +#include +#include +#include +#include +#include + +namespace ge { +class OpsProtoManager { + public: + static OpsProtoManager *Instance(); + + bool Initialize(const std::map &options); + void Finalize(); + + private: + void LoadOpsProtoPluginSo(std::string &path); + + std::string pluginPath_; + std::vector handles_; + bool is_init_ = false; + std::mutex mutex_; +}; +} // namespace ge + +#endif // INC_GRAPH_OPSPROTO_MANAGER_H_ diff --git a/metadef/inc/graph/range_vistor.h b/metadef/inc/graph/range_vistor.h new file mode 100644 index 00000000..50c02cfc --- /dev/null +++ b/metadef/inc/graph/range_vistor.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_RANGE_VISTOR_H_ +#define INC_GRAPH_RANGE_VISTOR_H_ + +#include + +template +class RangeVistor { + public: + using Iterator = typename std::vector::iterator; + using ConstIterator = typename std::vector::const_iterator; + + RangeVistor(O owner, const std::vector &vs) : owner_(owner), elements_(vs) {} + + ~RangeVistor() {} + + Iterator begin() { return elements_.begin(); } + + Iterator end() { return elements_.end(); } + + ConstIterator begin() const { return elements_.begin(); } + + ConstIterator end() const { return elements_.end(); } + + std::size_t size() const { return elements_.size(); } + + bool empty() const { return elements_.empty(); } + + E &at(std::size_t index) { return elements_.at(index); } + + const E &at(std::size_t index) const { return elements_.at(index); } + + private: + O owner_; + std::vector elements_; +}; + +#endif // INC_GRAPH_RANGE_VISTOR_H_ diff --git a/metadef/inc/graph/ref_relation.h b/metadef/inc/graph/ref_relation.h new file mode 100644 index 00000000..afd4044e --- /dev/null +++ b/metadef/inc/graph/ref_relation.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMON_GRAPH_REF_RELATION_H_ +#define COMMON_GRAPH_REF_RELATION_H_ + +#include +#include +#include +#include + +#include "graph/compute_graph.h" +#include "graph/types.h" +#include "graph/ge_error_codes.h" +#include "node.h" + +namespace ge { +enum InOutFlag { + NODE_IN = 0, // input flag + NODE_OUT = 1, // output flag +}; + +struct RefCell { + std::string node_name; + ge::NodePtr node = nullptr; + InOutFlag in_out = NODE_IN; + int in_out_idx = 0; + + bool operator == (const RefCell &c) const { + return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; + } + + RefCell() = default; + RefCell(std::string name, ge::NodePtr node_ptr, InOutFlag in_out_flag, int idx) { + node_name = name; + node = node_ptr; + in_out = in_out_flag; + in_out_idx = idx; + }; + ~RefCell() = default; +}; + +struct RefCellHash{ + size_t operator () (const RefCell &c) const { + unsigned long number = static_cast(reinterpret_cast(c.node.get())); + string tmp = c.node_name + std::to_string(c.in_out) + std::to_string(c.in_out_idx) + + std::to_string(number); + return std::hash()(tmp); + } +}; + +class RefRelations { +public: + graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set &result); + graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); + graphStatus Clear(); + + RefRelations(); + ~RefRelations() = default; +public: + class Impl; + std::shared_ptr impl_ = nullptr; +}; + +} // namespace ge +#endif // COMMON_GRAPH_REF_RELATION_H_ diff --git a/metadef/inc/graph/runtime_inference_context.h b/metadef/inc/graph/runtime_inference_context.h new file mode 100644 index 00000000..0e7b092e --- /dev/null +++ b/metadef/inc/graph/runtime_inference_context.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ +#define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ + +#include +#include +#include +#include +#include "external/graph/ge_error_codes.h" +#include "external/graph/tensor.h" +#include "ge_attr_value.h" + +namespace ge { +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext { + public: + static graphStatus GetContext(const std::string &context_id, RuntimeInferenceContext **ctx); + static graphStatus CreateContext(const std::string &context_id); + static void DestroyContext(const std::string &context_id); + + graphStatus SetTensor(int64_t node_id, int output_id, Tensor &&tensor); + graphStatus GetTensor(int64_t node_id, int output_id, GeTensorPtr &tensor); + graphStatus GetTensor(int64_t node_id, int output_id, Tensor &tensor); + + private: + std::map> tensors_; + std::map> ge_tensors_; + std::mutex mu_; + + static std::map> contexts_; + static std::mutex ctx_mu_; +}; +} // namespace ge + +#endif // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ diff --git a/metadef/inc/graph/shape_refiner.h b/metadef/inc/graph/shape_refiner.h new file mode 100644 index 00000000..bc6b437a --- /dev/null +++ b/metadef/inc/graph/shape_refiner.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_SHAPE_REFINER_H_ +#define INC_GRAPH_SHAPE_REFINER_H_ + +#include +#include "external/graph/inference_context.h" + +#include "external/graph/ge_error_codes.h" +#include "graph/node.h" + +namespace ge { +// ShapeRefiner performs shape inference for compute graphs +class ShapeRefiner { + public: + static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph); + static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); + static graphStatus InferShapeAndType(const NodePtr &node); + static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); + static graphStatus InferShapeAndTypeForRunning(const ConstNodePtr &node, Operator &op, bool before_subgraph); + static graphStatus InferShapeAndTypeForRunning(const NodePtr &node, bool before_subgraph); + static void ClearContextMap(); + + private: + static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); +}; +} // namespace ge +#endif // INC_GRAPH_SHAPE_REFINER_H_ diff --git a/metadef/inc/graph/tuning_utils.h b/metadef/inc/graph/tuning_utils.h new file mode 100644 index 00000000..fe07dde1 --- /dev/null +++ b/metadef/inc/graph/tuning_utils.h @@ -0,0 +1,133 @@ +#ifndef MAIN_TUNING_UTILS_H +#define MAIN_TUNING_UTILS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "utils/attr_utils.h" +#include "utils/node_utils.h" +#include "external/ge/ge_api_types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +namespace ge { +// Configure build mode, default value is "normal" +const char *const BUILD_MODE = "ge.buildMode"; +const char *const BUILD_STEP = "ge.buildStep"; +// Configure tuning path +const char *const TUNING_PATH = "ge.tuningPath"; +// for interface: aclgrphBuildModel +const std::set ir_builder_supported_options_for_lx_fusion = { + BUILD_MODE, + BUILD_STEP, + TUNING_PATH +}; + +// Build model +const char *const BUILD_MODE_NORMAL = "normal"; +const char *const BUILD_MODE_TUNING = "tuning"; +const char *const BUILD_MODE_BASELINE = "baseline"; +const std::set build_mode_options = { + BUILD_MODE_NORMAL, + BUILD_MODE_TUNING, + BUILD_MODE_BASELINE +}; + +// Build step +const char *const BUILD_STEP_BEFORE_UB_MATCH = "before_ub_match"; +const char *const BUILD_STEP_AFTER_UB_MATCH = "after_ub_match"; +const char *const BUILD_STEP_AFTER_BUILDER = "after_builder"; +const char *const BUILD_STEP_AFTER_BUILDER_SUB = "after_builder_sub"; +const char *const BUILD_STEP_AFTER_MERGE = "after_merge"; +const std::set build_step_options = { + BUILD_STEP_BEFORE_UB_MATCH, + BUILD_STEP_AFTER_UB_MATCH, + BUILD_STEP_AFTER_BUILDER, + BUILD_STEP_AFTER_BUILDER_SUB, + BUILD_STEP_AFTER_MERGE +}; + +using SubgraphCreateOutNode = std::unordered_map; +using NodetoNodeMap = std::unordered_map; +using NodeVec = std::vector; +using NodeNametoNodeNameMap = std::unordered_map; +using NodetoNodeNameMap = std::unordered_map; +class TuningUtils { + public: + TuningUtils() = default; + ~TuningUtils() = default; + // Dump all the subgraphs and modify + // the subgraphs in them to be executable subgraphs if exe_flag is true + // `tuning_path` means path to save the graphs + static graphStatus ConvertGraphToFile(std::vector tuning_subgraphs, + std::vector non_tuning_subgraphs = {}, + bool exe_flag = false, + const std::string &path = "", + const std::string &user_path = ""); + // Recovery `graph` from graph dump files configured in options + static graphStatus ConvertFileToGraph(const map &options, ge::Graph &graph); + + private: + // part 1 + struct HelpInfo { + int64_t index; + bool exe_flag; + bool is_tuning_graph; + const std::string &path; + const std::string &user_path; + }; + static graphStatus MakeExeGraph(ComputeGraphPtr &exe_graph, + const HelpInfo& help_info); + static graphStatus HandlePld(NodePtr &node); + static graphStatus HandleEnd(NodePtr &node); + static graphStatus ChangePld2Data(NodePtr &node, NodePtr &data_node); + static graphStatus ChangeEnd2NetOutput(NodePtr &node, NodePtr &out_node); + static graphStatus LinkEnd2NetOutput(NodePtr &node, NodePtr &out_node); + static graphStatus CreateDataNode(NodePtr &node, NodePtr &data_node); + static graphStatus CreateNetOutput(NodePtr &node, NodePtr &out_node); + static graphStatus AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node); + static graphStatus AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node); + static void DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, + bool is_tuning_graph, std::string path); + + static SubgraphCreateOutNode create_output_; + // part 2 + static graphStatus MergeAllSubGraph(std::vector &graphs, + ComputeGraphPtr &graph); + static graphStatus MergeSubGraph(ComputeGraphPtr &graph); + // Deletes new data and output nodes added by call `MakeExeGraph()` func in part 1 + static graphStatus RemoveDataNetoutputEdge(ComputeGraphPtr &graph); + static graphStatus GetInAndOutAnchorPair(NodePtr &data_node, + NodePtr &out_node, + AnchorPtr &dest_in_anchor, + AnchorPtr &src_out_anchor); + static graphStatus HandleContinuousInputNodeNextData(NodePtr &node); + static NodeNametoNodeNameMap data_2_netoutput_; + static NodetoNodeNameMap data_node_2_netoutput_; + static NodetoNodeMap data_node_2_netoutput_node_; + static NodeVec netoutput_nodes_; + static NodeVec merged_graph_nodes_; + static std::mutex mutex_; + // for debug + static std::string PrintCheckLog(); + static std::string GetNodeNameByAnchor(const Anchor *anchor); +}; +} +#endif //MAIN_TUNING_UTILS_H diff --git a/metadef/inc/graph/usr_types.h b/metadef/inc/graph/usr_types.h new file mode 100644 index 00000000..7da9d49b --- /dev/null +++ b/metadef/inc/graph/usr_types.h @@ -0,0 +1,134 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_USR_TYPES_H_ +#define INC_GRAPH_USR_TYPES_H_ + +#include +#include +#include +namespace ge { +#define USR_TYPE_DEC(type, name) \ + inline void set_##name(const type &value) { name = value; } \ + type *mutable_##name() { return &name; } + +#define USR_TYPE_HAS_DEC(type, name) \ + inline void set_##name(const type &value) { name = value; } \ + \ + private: \ + bool has_mutable_##name{false}; \ + \ + public: \ + bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \ + type *mutable_##name() { \ + has_mutable_##name = true; \ + return &name; \ + } + +#define USR_TYPE_BYTES_DEC(name) \ + inline void clear_##name() { name.clear(); } \ + inline void set_##name(const void *value, size_t size) { \ + name.assign(reinterpret_cast(const_cast(value)), \ + reinterpret_cast(const_cast(value)) + size); \ + } + +enum UsrQuantizeScaleType { USR_VECTOR_SCALE = 0, USR_SCALAR_SCALE = 1 }; +enum UsrQuantizeScaleMode { USR_NORMAL_MODE = 0, USR_SQRT_MODE = 1 }; +enum UsrQuantizeAlgorithm { + USR_NON_OFFSET_ALGO = 0, + USR_HALF_OFFSET_ALGO = 1, + USR_ALL_OFFSET_ALGO = 2, +}; + +struct UsrQuantizeFactor { + public: + // QuantizeScaleMode scale_mode; + UsrQuantizeScaleMode scale_mode{USR_NORMAL_MODE}; + std::vector scale_value; + int64_t scale_offset{0}; + std::vector offset_data_value; + int64_t offset_data_offset{0}; + std::vector offset_weight_value; + int64_t offset_weight_offset{0}; + std::vector offset_pad_value; + int64_t offset_pad_offset{0}; + + USR_TYPE_DEC(UsrQuantizeScaleMode, scale_mode); + USR_TYPE_BYTES_DEC(scale_value); + + USR_TYPE_DEC(int64_t, scale_offset); + USR_TYPE_BYTES_DEC(offset_data_value); + USR_TYPE_DEC(int64_t, offset_data_offset); + + USR_TYPE_BYTES_DEC(offset_weight_value); + USR_TYPE_DEC(int64_t, offset_weight_offset); + USR_TYPE_BYTES_DEC(offset_pad_value); + USR_TYPE_DEC(int64_t, offset_pad_offset); +}; + +static inline bool QuantizeFactorHasData(const UsrQuantizeFactor &factor) { + return factor.scale_value.size() > 0 || factor.offset_data_value.size() > 0 || + factor.offset_weight_value.size() > 0 || factor.offset_pad_value.size() > 0; +} + +struct UsrQuantizeCalcFactor { + public: + std::vector offsetw; + int64_t offsetw_offset{0}; + std::vector offsetd; + int64_t offsetd_offset{0}; + std::vector scalereq; + int64_t scaledreq_offset{0}; + std::vector offsetdnext; + int64_t offsetdnext_offset{0}; + + USR_TYPE_BYTES_DEC(offsetw); + USR_TYPE_DEC(int64_t, offsetw_offset); + USR_TYPE_BYTES_DEC(offsetd); + USR_TYPE_DEC(int64_t, offsetd_offset); + USR_TYPE_BYTES_DEC(scalereq); + USR_TYPE_DEC(int64_t, scaledreq_offset); + USR_TYPE_BYTES_DEC(offsetdnext); + USR_TYPE_DEC(int64_t, offsetdnext_offset); +}; + +static inline bool QuantizeFactorHasData(const UsrQuantizeCalcFactor &factor) { + return factor.offsetw.size() > 0 || factor.offsetd.size() > 0 || factor.scalereq.size() > 0 || + factor.offsetdnext.size() > 0; +} + +struct UsrQuantizeFactorParams { + UsrQuantizeAlgorithm quantize_algo{USR_NON_OFFSET_ALGO}; + UsrQuantizeScaleType scale_type{USR_VECTOR_SCALE}; + UsrQuantizeFactor quantize_param; + UsrQuantizeFactor dequantize_param; + UsrQuantizeFactor requantize_param; + UsrQuantizeCalcFactor quantizecalc_param; + USR_TYPE_DEC(UsrQuantizeAlgorithm, quantize_algo); + USR_TYPE_DEC(UsrQuantizeScaleType, scale_type); + USR_TYPE_HAS_DEC(UsrQuantizeFactor, quantize_param); + USR_TYPE_HAS_DEC(UsrQuantizeFactor, dequantize_param); + USR_TYPE_HAS_DEC(UsrQuantizeFactor, requantize_param); + USR_TYPE_HAS_DEC(UsrQuantizeCalcFactor, quantizecalc_param); +}; + +#undef USR_TYPE_DEC +#undef USR_TYPE_HAS_DEC +#undef USR_TYPE_BYTES_DEC +} // namespace ge + +#endif // INC_GRAPH_USR_TYPES_H_ + diff --git a/metadef/inc/graph/utils/anchor_utils.h b/metadef/inc/graph/utils/anchor_utils.h new file mode 100644 index 00000000..f3f71293 --- /dev/null +++ b/metadef/inc/graph/utils/anchor_utils.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_ANCHOR_UTILS_H_ +#define INC_GRAPH_UTILS_ANCHOR_UTILS_H_ + +#include "graph/anchor.h" +#include "graph/node.h" + +namespace ge { +class AnchorUtils { + public: + // Get anchor format + static Format GetFormat(const DataAnchorPtr &dataAnchor); + + // Set anchor format + static graphStatus SetFormat(const DataAnchorPtr &dataAnchor, Format dataFormat); + + // Get anchor status + static AnchorStatus GetStatus(const DataAnchorPtr &dataAnchor); + + // Set anchor status + static graphStatus SetStatus(const DataAnchorPtr &dataAnchor, AnchorStatus anchorStatus); + + static bool HasControlEdge(const AnchorPtr &anchor); + + static bool IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst); + + static int GetIdx(const AnchorPtr &anchor); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_ANCHOR_UTILS_H_ diff --git a/metadef/inc/graph/utils/attr_utils.h b/metadef/inc/graph/utils/attr_utils.h new file mode 100644 index 00000000..1e273f38 --- /dev/null +++ b/metadef/inc/graph/utils/attr_utils.h @@ -0,0 +1,151 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_ +#define INC_GRAPH_UTILS_ATTR_UTILS_H_ + +#include +#include +#include +#include "graph/detail/attributes_holder.h" +#include "graph/ge_attr_value.h" +#include "graph/types.h" + +namespace ge { +class OpDesc; +using OpDescPtr = std::shared_ptr; +using ConstOpDescPtr = std::shared_ptr; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { + public: + class ConstAttrHolderAdapter; + class AttrHolderAdapter; + // Set + static bool HasAttr(ConstAttrHolderAdapter &&obj, const string &name); + + static bool SetInt(AttrHolderAdapter &&obj, const string &name, const int64_t &value); + static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value); + + static bool SetFloat(AttrHolderAdapter &&obj, const string &name, const float &value); + static bool SetListFloat(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetBool(AttrHolderAdapter &&obj, const string &name, const bool &value); + static bool SetListBool(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetStr(AttrHolderAdapter &&obj, const string &name, const string &value); + static bool SetListStr(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetTensorDesc(AttrHolderAdapter &&obj, const string &name, const GeTensorDesc &value); + static bool SetListTensorDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value); + static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const ConstGeTensorPtr &value); + static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensor &value); + static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, + std::initializer_list &&value); + static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetGraph(AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value); + static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); + static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value); + static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, + const vector &value); + static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); + + // Get + static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int64_t &value); + static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value); + static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value); + static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetFloat(ConstAttrHolderAdapter &&obj, const string &name, float &value); + static bool GetListFloat(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetBool(ConstAttrHolderAdapter &&obj, const string &name, bool &value); + static bool GetListBool(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetStr(ConstAttrHolderAdapter &&obj, const string &name, string &value); + static bool GetListStr(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, GeTensorDesc &value); + static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value); + static bool MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value); + static bool GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetGraph(ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value); + static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); + static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value); + static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, + vector &value); + static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + // Value will be moved + static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); + static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); + // Value will be moved + static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, + vector &listBuffer); + static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &listBuffer); + + static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector> &value); + static bool GetListListInt(ConstAttrHolderAdapter &&obj, const string &name, vector> &value); + + static bool SetListDataType(AttrHolderAdapter &&obj, const string &name, const vector &value); + static bool GetListDataType(ConstAttrHolderAdapter &&obj, const string &name, vector &value); + + static bool SetDataType(AttrHolderAdapter &&obj, const string &name, const ge::DataType &value); + static bool GetDataType(ConstAttrHolderAdapter &&obj, const string &name, ge::DataType &value); + + static OpDescPtr CloneOpDesc(const ConstOpDescPtr &orgOpDesc); + + static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); + + static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); + + class AttrHolderAdapter { + public: + AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} + ~AttrHolderAdapter() {} + template + AttrHolderAdapter(const std::shared_ptr &obj) : obj_(obj.get()) {} + AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {} + operator bool() const { return obj_ != nullptr; } + AttrHolder *operator->() { return obj_; } + AttrHolder *get() { return obj_; } + + AttrHolder *obj_; + }; + + class ConstAttrHolderAdapter { + public: + ConstAttrHolderAdapter(const AttrHolder *obj) : obj_(obj) {} + ~ConstAttrHolderAdapter() {} + template + ConstAttrHolderAdapter(const std::shared_ptr obj) : obj_(obj.get()) {} + ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {} + operator bool() const { return obj_ != nullptr; } + const AttrHolder *operator->() const { return obj_; } + const AttrHolder *get() const { return obj_; } + + private: + const AttrHolder *obj_; + }; +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_ATTR_UTILS_H_ diff --git a/metadef/inc/graph/utils/graph_utils.h b/metadef/inc/graph/utils/graph_utils.h new file mode 100644 index 00000000..508dc07b --- /dev/null +++ b/metadef/inc/graph/utils/graph_utils.h @@ -0,0 +1,812 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_ +#define INC_GRAPH_UTILS_GRAPH_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "graph/anchor.h" +#include "graph/compute_graph.h" +#include "graph/graph.h" +#include "graph/model.h" +#include "graph/node.h" +#include "graph/utils/anchor_utils.h" + +#define GE_DUMP(compute_graph, name) \ + do { \ + GraphUtils::DumpGEGraph(compute_graph, name); \ + GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); \ + uint64_t i = 0; \ + for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { \ + auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \ + GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \ + GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \ + } \ + } while (0) + +#define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ + do { \ + DataType ret; \ + attr.GetValue(ret); \ + } while (0) + +#define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ + do { \ + if (value_type == VT_ENUM) { \ + REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ + stream << ret; \ + } \ + } while (0) + +#define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ + do { \ + if (value_type == VT_ENUM) { \ + REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ + stream << "["; \ + for (int i = 0; i < ret.size(); i++) { \ + stream << ret[i]; \ + if (i + 1 != ret.size()) stream << ", "; \ + } \ + stream << "]"; \ + } \ + } while (0) + +#define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ + else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) + +#define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ + else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) + +#define PRINT_SHAPE(i_o, n, idx, stream) \ + do { \ + auto op = n->GetOpDesc(); \ + GeTensorDesc td = i_o == "input" ? op->GetInputDesc(idx) : op->GetOutputDesc(idx); \ + auto shape = td.GetShape().GetDims(); \ + stream << "["; \ + for (int i = 0; i < shape.size(); i++) { \ + stream << shape[i]; \ + if (i + 1 < shape.size()) stream << ", "; \ + } \ + stream << "]"; \ + } while (0) + +#define PRINT_ATTR_FUNC(stream) \ + [&](GeAttrValue attr) { \ + auto type = attr.GetValueType(); \ + PRINT_ATTR_VALUE_IF(type, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR, attr, stream) \ + PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT, attr, stream) \ + PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL, attr, stream) \ + PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT, attr, stream) \ + PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_STRING, GeAttrValue::LIST_STR, attr, stream) \ + PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_FLOAT, GeAttrValue::LIST_FLOAT, attr, stream) \ + PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_BOOL, GeAttrValue::LIST_BOOL, attr, stream) \ + PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_INT, GeAttrValue::LIST_INT, attr, stream) \ + else if (type == GeAttrValue::ValueType::VT_TENSOR_DESC) stream << "TENSOR_DESC"; \ + else if (type == GeAttrValue::ValueType::VT_TENSOR) stream << "TENSOR"; \ + else if (type == GeAttrValue::ValueType::VT_BYTES) stream << "BYTES"; \ + else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) stream << "LIST_TENSOR_DESC"; \ + else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR) stream << "LIST_TENSOR"; \ + else if (type == GeAttrValue::ValueType::VT_LIST_BYTES) stream << "LIST_BYTES"; \ + }; + +namespace ge { +enum IOType { kIn, kOut }; + +struct NodeIndexIO { + NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type) + : node_(std::move(node)), index_(index), io_type_(io_type) { + if (node_ != nullptr) { + value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_); + } + } + NodeIndexIO(ge::NodePtr node, int index, IOType io_type) + : node_(std::move(node)), index_(static_cast(index)), io_type_(io_type) { + if (node_ != nullptr) { + value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_); + } + } + ~NodeIndexIO() {} + + NodePtr node_ = nullptr; + uint32_t index_ = 0; + IOType io_type_ = kOut; + std::string value_; + + const std::string &ToString() const { return value_; } +}; + +class GraphUtils { + public: + static ComputeGraphPtr GetComputeGraph(const Graph &graph); + + static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); + + static GraphPtr CreateGraphPtrFromComputeGraph(const ComputeGraphPtr compute_graph); + + static graphStatus RecoverGraphOperators(const Graph &graph); + + static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector &inputs); + + static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); + + static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst, + const Format &dst_format); + + static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst); + + static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); + + static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); + + // check whether src is link to dst and then remove + static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); + + static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst); + + static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); + + static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); + + static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, + const InDataAnchorPtr &new_dst); + + static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, + const InControlAnchorPtr &new_dst); + + static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, + const NodePtr &new_node); + + static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node); + + static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node); + + static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, + const std::vector &vec_op_desc); + + /// + /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst + /// @param [in] src + /// @param [in] dsts + /// @param [in] insert_node + /// @param [in] input_index + /// @param [in] output_index + /// @return graphStatus + /// + static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); + + static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); + + static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); + + static void RecordOriginalNames(std::vector original_nodes, const ge::NodePtr &node); + + static void RecordOriginalNames(std::vector names_tmp, const ge::NodePtr &node); + + static bool MatchDumpStr(const std::string &suffix); + + static void DumpGEGraph(const ge::ComputeGraphPtr &graph, + const std::string &suffix, + bool is_always_dump = false, + const std::string &user_graph_name = ""); + + static void DumpGEGrph(const ge::ComputeGraphPtr &graph, + const std::string &path, + const std::string &suffix); + + static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); + + static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph); + + static void BreakConnect(const std::map &all_nodes_infos); + + static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); + + static void DumpGrphToOnnx(const ge::ComputeGraph &compute_graph, + const std::string &path, const std::string &suffix); + + static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph); + + static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message); + + static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path); + + static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node); + + /// + /// Isolating `node`, relinking data links from the in-anchor peer nodes to + /// the out-anchor peer nodes according to `io_map`, relinking control links + /// to ensure that input nodes of `node` are before out nodes + /// + /// Link the `io_map[i]` input anchor peer node to `i` output anchor peer + /// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0, + /// unlink all links from `i` output anchor without any relinking. + /// + /// @param node + /// @param io_map + /// @return + /// + static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list &io_map); + static graphStatus IsolateNode(const NodePtr &node, const std::vector &io_map); + + /// + /// Isolate `node` which must be one input one output, equivalent to + /// `IsolateNode(node, {0})` + /// @param node + /// @return + /// + static graphStatus IsolateNodeOneIO(const NodePtr &node); + + /// + /// The data anchors replacing behavior is the same with + /// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control + /// anchors with `new_node`'s. + /// @param new_node + /// @param old_node + /// @param inputs_map + /// @param outputs_map + /// @return + /// + static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, + std::initializer_list inputs_map, std::initializer_list outputs_map); + + static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, + const std::vector &inputs_map, const std::vector &outputs_map); + + /// + /// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`. + /// Replace the `i` in/out data anchor on `old_node` with + /// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`. + /// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in + /// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain + /// on `old_node`. + /// @param new_node + /// @param old_node + /// @param inputs_map + /// @param outputs_map + /// @return + /// + static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, + std::initializer_list inputs_map, + std::initializer_list outputs_map); + + static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, + const std::vector &inputs_map, const std::vector &outputs_map); + + /// + /// Copy all in-control edges from `src_node` to `dst_node` + /// @param src_node + /// @param dst_node + /// @return + /// + static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); + + static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); + + /// + /// Copy all out-control edges from `src_node` to `dst_node` + /// @param src_node + /// @param dst_node + /// @return success: GRAPH_SUCESS + /// + static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); + + /// + /// Move all out-control edges from `src_node` to `dst_node` + /// @param src_node + /// @param dst_node + /// @return success: GRAPH_SUCESS + /// + static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); + + /// + /// Copy all in-data edges from `src_node` to `dst_node` + /// @param src_node + /// @param dst_node + /// @return + /// + static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node); + + static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); + + /// + /// Make a copy of ComputeGraph. + /// @param graph: original graph. + /// @param prefix: node name prefix of new graph. + /// @return ComputeGraphPtr + /// + static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const string &prefix, + std::vector &input_nodes, std::vector &output_nodes); + + /// + /// Copy tensor attribute to new node. + /// @param [in] dst_desc: cloned node. + /// @param [in] src_node: original node. + /// @return success: GRAPH_SUCESS + /// + static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node); + + static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec); + + /// + /// Get reference-mapping of all data_anchors in graph + /// @param [in] graph + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus GetRefMapping(const ComputeGraphPtr &graph, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs + /// of the graph have UNKNOWN_SHAPE operators or not. + /// Note: This function will only look 'down' from the graph, not 'up'. For example, the following + /// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE + /// ROOT graph: A -----> B -----> C + /// K subgraph U + /// | + /// V + /// SUB graph: D --> E --> F + /// K K K + /// @param [in] graph + /// @return bool + /// + static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph); + + static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name); + + private: + /// + /// Get reference-mapping for in_data_anchors of node + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleInAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Get reference-mapping for out_data_anchors of node + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleOutAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Handle input of subgraph + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleSubgraphInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Handle input of Merge op + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleMergeInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Handle output of subgraph + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleSubgraphOutput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Relink all edges for cloned ComputeGraph. + /// @param [in] node: original node. + /// @param [in] prefix: node name prefix of new node. + /// @param [in] all_nodes: all nodes in new graph. + /// @return success: GRAPH_SUCESS + /// + static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix, + const std::unordered_map &all_nodes); + + /// + /// Union ref-mapping + /// @param [in] exist_node_info1 + /// @param [in] exist_node_info2 + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @param [out] symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol, std::string &symbol); + + /// + /// Update symbol mapping with a new reference pair + /// @param [in] cur_node_info + /// @param [in] exist_node_info + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Check if out_data_anchor is reference of input + /// @param [in] out_data_anchor + /// @param [out] reuse_in_index + /// @return bool + /// + static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); +}; + +class ComputeGraphBuilder { + public: + ComputeGraphBuilder() : owner_graph_(nullptr) {} + ComputeGraphBuilder(const ComputeGraphBuilder &) = delete; + ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete; + ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete; + ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete; + ~ComputeGraphBuilder() = default; + + /// + /// @brief Add node to graph + /// @param [in] op_desc + /// @return ComputeGraphBuilder + /// + virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc); + + /// + /// @brief Add data-link among nodes in graph + /// @param [in] src_name + /// @param [in] out_anchor_ind + /// @param [in] dst_name + /// @param [in] in_anchor_ind + /// @return ComputeGraphBuilder + /// + virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, + const std::string &dst_name, uint32_t in_anchor_ind); + + /// + /// @brief Add ctrl-link among nodes in graph + /// @param [in] src_name + /// @param [in] dst_name + /// @return ComputeGraphBuilder + /// + virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name); + + /// + /// @brief Build graph + /// @param [out] error_code + /// @param [out] error_msg + /// @return ComputeGraphPtr + /// + virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0; + + /// @brief Get node with name + /// @param [in] name + /// @return NodePtr + /// + NodePtr GetNode(const std::string &name); + + /// @brief Get all nodes + /// @return std::vector + /// + std::vector GetAllNodes(); + + protected: + /// + /// @brief Build nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildNodes(graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Build data-links + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildDataLinks(graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Build ctrl-links + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg); + + ComputeGraphPtr owner_graph_; + + // node_name -> node + std::map node_names_; + std::vector nodes_; + + // -> + std::vector, std::pair>> data_links_; + // src_node_name -> dst_node_name + std::vector> ctrl_links_; +}; + +class CompleteGraphBuilder : public ComputeGraphBuilder { + public: + explicit CompleteGraphBuilder(std::string name, bool retval_flag = true) + : name_(std::move(name)), parent_node_(nullptr), retval_flag_(retval_flag) {} + CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; + CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; + CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; + CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete; + ~CompleteGraphBuilder() = default; + + /// + /// @brief Add node to graph + /// @param [in] op_desc + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override; + + /// + /// @brief Add data-link among nodes in graph + /// @param [in] src_name + /// @param [in] out_anchor_ind + /// @param [in] dst_name + /// @param [in] in_anchor_ind + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, + const std::string &dst_name, uint32_t in_anchor_ind) override; + + /// + /// @brief Add ctrl-link among nodes in graph + /// @param [in] src_name + /// @param [in] dst_name + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; + + /// + /// @brief Set index_th input anchor for graph + /// @param [in] index + /// @param [in] node_names + /// @param [in] anchor_inds + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &SetInput(uint32_t index, const std::vector &node_names, + const std::vector &anchor_inds); + + /// + /// @brief Set index_th input of graph as useless + /// @param [in] index + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &SetUselessInput(uint32_t index); + + /// + /// @brief Add output anchor for graph + /// @param [in] owner_node_name + /// @param [in] anchor_ind + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); + + /// + /// @brief Add target for graph + /// @param [in] target_name + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &AddTarget(const std::string &target_name); + + /// + /// @brief Set parent-node of graph + /// @param [in] parent_node + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node); + + /// + /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node + /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &SetInputMapping(const std::map &input_mapping); + + /// + /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind + /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node + /// @return CompleteGraphBuilder + /// + CompleteGraphBuilder &SetOutputMapping(const std::map &output_mapping); + + /// + /// @brief Build graph + /// @param [out] error_code + /// @param [out] error_msg + /// @return ComputeGraphPtr + /// + ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; + + private: + /// + /// @brief Add data nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void AddDataNodes(graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Add data node + /// @param [in] index + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Add RetVal nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void AddRetValNodes(graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Build target-nodes for graph + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildGraphTargets(graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Add NetOutput node + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void AddNetOutputNode(graphStatus &error_code, std::string &error_msg); + + /// + /// @brief Build NetOutput nodes with data & ctrl edges + /// @param [in] net_output_desc + /// @param [in] peer_out_anchors + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildNetOutputNodeWithLink(const OpDescPtr &net_output_desc, + const std::vector &peer_out_anchors, + graphStatus &error_code, std::string &error_msg); + + /// + /// @brief process after build + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void PostProcess(graphStatus &error_code, std::string &error_msg); + + std::string name_; + NodePtr parent_node_; + bool retval_flag_; + std::map, std::vector>> graph_inputs_; + std::vector> graph_outputs_; + std::vector graph_targets_; + + // index_of_graph_input -> in_anchor_index_of_parent_node + std::map input_mapping_; + // index_of_graph_output -> out_anchor_index_of_parent_node + std::map output_mapping_; +}; + +class PartialGraphBuilder : public ComputeGraphBuilder { + public: + PartialGraphBuilder() = default; + PartialGraphBuilder(const PartialGraphBuilder &) = delete; + PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete; + PartialGraphBuilder(const PartialGraphBuilder &&) = delete; + PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete; + ~PartialGraphBuilder() = default; + + /// + /// @brief Add node to graph + /// @param [in] op_desc + /// @return PartialGraphBuilder + /// + PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override; + + /// + /// @brief Add data-link among nodes in graph + /// @param [in] src_name + /// @param [in] out_anchor_ind + /// @param [in] dst_name + /// @param [in] in_anchor_ind + /// @return PartialGraphBuilder + /// + PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, + const std::string &dst_name, uint32_t in_anchor_ind) override; + + /// + /// @brief Add ctrl-link among nodes in graph + /// @param [in] src_name + /// @param [in] dst_name + /// @return PartialGraphBuilder + /// + PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; + + /// + /// @brief Set owner graph + /// @param [in] graph + /// @return PartialGraphBuilder + /// + PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph); + + /// + /// @brief Add exist node + /// @param [in] node + /// @return PartialGraphBuilder + /// + PartialGraphBuilder &AddExistNode(const NodePtr &node); + + /// + /// @brief Build multi nodes with links + /// @param [out] error_code + /// @param [out] error_msg + /// @return ComputeGraphPtr + /// + ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; + + private: + /// + /// @brief Build exist nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + /// + void BuildExistNodes(graphStatus &error_code, std::string &error_msg); + + std::vector exist_nodes_; +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_ diff --git a/metadef/inc/graph/utils/node_adapter.h b/metadef/inc/graph/utils/node_adapter.h new file mode 100644 index 00000000..19d76543 --- /dev/null +++ b/metadef/inc/graph/utils/node_adapter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_NODE_ADAPTER_H_ +#define INC_GRAPH_UTILS_NODE_ADAPTER_H_ + +#include "graph/gnode.h" +#include "graph/node.h" + +namespace ge { +using NodePtr = std::shared_ptr; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodeAdapter { + public: + static GNode Node2GNode(const NodePtr &node); + static NodePtr GNode2Node(const GNode &node); + static GNodePtr Node2GNodePtr(const NodePtr &node); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_NODE_ADAPTER_H_ diff --git a/metadef/inc/graph/utils/node_utils.h b/metadef/inc/graph/utils/node_utils.h new file mode 100644 index 00000000..77555629 --- /dev/null +++ b/metadef/inc/graph/utils/node_utils.h @@ -0,0 +1,178 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_NODE_UTILS_H_ +#define INC_GRAPH_UTILS_NODE_UTILS_H_ + +#include +#include +#include +#include "external/graph/operator.h" +#include "graph/node.h" + +namespace ge { +// Op types of Const like Opps. +extern const std::set kConstOpTypes; +// Op types of If like Opps. +extern const std::set kIfOpTypes; +// Op types of While like Opps. +extern const std::set kWhileOpTypes; +// Op types of Case like Opps. +extern const std::set kCaseOpTypes; +// Op types of For like Opps. +extern const std::set kForOpTypes; + +class NodeUtils { + public: + static graphStatus AddSendEventId(const NodePtr &node, const uint32_t &event_id); + static graphStatus AddRecvEventId(const NodePtr &node, const uint32_t &event_id); + static graphStatus GetSendEventIdList(const NodePtr &node, std::vector &vec_send); + static graphStatus GetRecvEventIdList(const NodePtr &node, std::vector &vec_recv); + + static graphStatus ClearSendInfo(); + static graphStatus ClearRecvInfo(); + + static graphStatus GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst); + + static graphStatus GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, + InControlAnchorPtr &in_control); + + static graphStatus ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor); + static graphStatus SetAllAnchorStatus(const NodePtr &nodePtr); + static graphStatus SetAllAnchorStatus(Node &node); + static bool IsAnchorStatusSet(const NodePtr &nodePtr); + static bool IsAnchorStatusSet(const Node &node); + + static graphStatus MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node); + + static void UpdateIsInputConst(const NodePtr &nodePtr); + static void UpdateIsInputConst(Node &node); + static bool IsConst(const Node &node); + static void UnlinkAll(const Node &node); + static graphStatus UpdatePeerNodeInputDesc(const NodePtr &node_ptr); + + static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t num); + static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t num); + + static graphStatus AppendOutputAnchor(const NodePtr &node, uint32_t num); + static graphStatus RemoveOutputAnchor(const NodePtr &node, uint32_t num); + + static bool IsInNodesEmpty(const Node &node); + static GeTensorDesc GetOutputDesc(const Node &node, uint32_t index); + static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); + static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); + static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); + // check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true; + // for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2, + // the out param "is_unknow" will be true too + static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); + + static std::string GetNodeType(const Node &node); + static std::string GetNodeType(const NodePtr &node); + + static std::vector GetAllSubgraphs(const Node &node); + static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); + static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); + + /// + /// Check if node is input of subgraph + /// @param [in] node + /// @return bool + /// + static bool IsSubgraphInput(const NodePtr &node); + + /// + /// Check if node is output of subgraph + /// @param [in] node + /// @return bool + /// + static bool IsSubgraphOutput(const NodePtr &node); + + /// + /// @brief Get subgraph original input node. + /// @param [in] node + /// @return Node + /// + static NodePtr GetParentInput(const Node &node); + static NodePtr GetParentInput(const NodePtr &node); + + /// + /// @brief Get is dynamic shape graph from node. + /// @param [in] node + /// @return bool + /// + static bool IsDynamicShape(const Node &node); + static bool IsDynamicShape(const NodePtr &node); + + /// + /// @brief Check is varying_input for while node + /// @param [in] node: Data node for subgraph + /// @return bool + /// + static bool IsWhileVaryingInput(const ge::NodePtr &node); + + /// + /// @brief Get subgraph input is constant. + /// @param [in] node + /// @param [out] string + /// @return bool + /// + static bool GetConstOpType(const NodePtr &node, std::string &type); + + /// + /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. + /// @param [in] node + /// @return return GRAPH_SUCCESS if remove successfully, other for failed. + /// + static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); + + /// + /// @brief Get subgraph input data node by index. + /// @param [in] node + /// @return Node + /// + static vector GetSubgraphDataNodesByIndex(const Node &node, int index); + + /// + /// @brief Get subgraph input data node by index. + /// @param [in] node + /// @return Node + /// + static vector GetSubgraphOutputNodes(const Node &node); + + static NodePtr GetInDataNodeByIndex(const Node &node, const int index); + + static vector> GetOutDataNodesWithAnchorByIndex(const Node &node, const int index); + + static ge::ConstNodePtr GetNodeFromOperator(const Operator &oprt); + + static graphStatus GetInputConstData(const ConstNodePtr& node_ptr, const string &dst_name, GeTensorPtr &ge_tensor); + + static graphStatus GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor); + + /// + /// @brief Get node type in cross subgragh. + /// @param [in] node + /// @return type + /// + static std::string GetInConstNodeTypeCrossSubgraph(const ge::NodePtr &node); + + private: + static std::map> map_send_info_; + static std::map> map_recv_info_; +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_NODE_UTILS_H_ diff --git a/metadef/inc/graph/utils/op_desc_utils.h b/metadef/inc/graph/utils/op_desc_utils.h new file mode 100644 index 00000000..4589180d --- /dev/null +++ b/metadef/inc/graph/utils/op_desc_utils.h @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_OP_DESC_UTILS_H_ +#define INC_GRAPH_UTILS_OP_DESC_UTILS_H_ + +#include +#include +#include +#include "graph/def_types.h" +#include "graph/node.h" +#include "graph/op_desc.h" +#include "graph/operator.h" +#include "graph/range_vistor.h" + +namespace ge { +class OpDesc; +using OpDescPtr = std::shared_ptr; + +class OpDescUtils { + public: + template + using Vistor = RangeVistor>; + + OpDescUtils() = default; + ~OpDescUtils() = default; + static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); + static bool HasQuantizeFactorParams(const OpDesc& op_desc); + static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant); + static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant); + static graphStatus SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams& quant); + static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant); + + static vector GetConstInputNode(const ge::Node& node); + static vector GetInputData(const vector& input_nodes); + + static vector GetWeights(const ge::Node& node); + static vector GetWeights(const ge::ConstNodePtr& node); + static vector MutableWeights(const ge::Node& node); + static vector MutableWeights(const ge::NodePtr node); + static graphStatus SetWeights(ge::Node& node, const vector& weights); + static graphStatus SetWeights(ge::NodePtr node, const vector &weights); + static graphStatus SetWeights(ge::Node &node, const map &weights_map); + static graphStatus ClearWeights(ge::NodePtr node); + + static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); + static bool ClearInputDesc(const ge::NodePtr& node); + static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index); + static bool ClearOutputDesc(const ge::NodePtr& node); + static vector GetConstInputs(const ge::Node& node); + static vector GetConstInputs(const ge::ConstNodePtr& node); + static size_t GetNonConstInputsSize(const ge::Node& node); + static size_t GetNonConstInputsSize(ge::ConstNodePtr node); + // Index: Indicates the index of all non const inputs + static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0); + static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0); + static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index); + static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index); + // Index: Indicates the index of all inputs + static bool IsNonConstInput(const ge::Node& node, size_t index = 0); + static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0); + + static vector GetNonConstTensorDesc(const ge::ConstNodePtr& node); + static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); + + static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); + static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); + static OpDescPtr GetOpDescFromOperator(const Operator& oprt); + + static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); + + static graphStatus SetSubgraphInstanceName(const std::string &subgraph_name, + const std::string &subgraph_instance_name, OpDescPtr &op_desc); + + private: + static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); + static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); + static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); + static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); +}; + +class OpDescBuilder { + public: + OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} + OpDescBuilder(const OpDescBuilder &) = delete; + OpDescBuilder &operator=(const OpDescBuilder &) = delete; + OpDescBuilder(const OpDescBuilder &&) = delete; + OpDescBuilder &operator=(const OpDescBuilder &&) = delete; + ~OpDescBuilder() = default; + + /// + /// @brief Add input + /// @param [in] name + /// @return OpDescBuilder + /// + OpDescBuilder& AddInput(const std::string &name); + + /// + /// @brief Add input + /// @param [in] name + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddInput(const std::string &name, const GeTensorDesc &tensor); + + /// + /// @brief Add dynamic input + /// @param [in] name + /// @param [in] num + /// @return OpDescBuilder + /// + OpDescBuilder& AddDynamicInput(const std::string &name, uint32_t num); + + /// + /// @brief Add dynamic input + /// @param [in] name + /// @param [in] num + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddDynamicInput(const std::string &name, uint32_t num, const GeTensorDesc &tensor); + + /// + /// @brief Add output + /// @param [in] name + /// @return OpDescBuilder + /// + OpDescBuilder& AddOutput(const std::string &name); + + /// + /// @brief Add output + /// @param [in] name + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddOutput(const std::string &name, const GeTensorDesc &tensor); + + /// + /// @brief Add dynamic output + /// @param [in] name + /// @param [in] num + /// @return OpDescBuilder + /// + OpDescBuilder& AddDynamicOutput(const std::string &name, uint32_t num); + + /// + /// @brief Add dynamic output + /// @param [in] name + /// @param [in] num + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddDynamicOutput(const std::string &name, uint32_t num, const GeTensorDesc &tensor); + + /// + /// @brief Build op_desc + /// @return OpDescPtr + /// + OpDescPtr Build(); + + private: + std::string name_; + std::string type_; + std::vector> inputs_; + std::vector> outputs_; +}; +} // namespace ge + +#endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ diff --git a/metadef/inc/graph/utils/tensor_adapter.h b/metadef/inc/graph/utils/tensor_adapter.h new file mode 100644 index 00000000..7161ba3b --- /dev/null +++ b/metadef/inc/graph/utils/tensor_adapter.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ +#define INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ + +#include +#include "graph/ge_tensor.h" +#include "graph/tensor.h" + +namespace ge { +using GeTensorPtr = std::shared_ptr; +using ConstGeTensorPtr = std::shared_ptr; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorAdapter { + public: + static GeTensorDesc TensorDesc2GeTensorDesc(const TensorDesc &tensorDesc); + static TensorDesc GeTensorDesc2TensorDesc(const GeTensorDesc &geTensorDesc); + static GeTensorPtr Tensor2GeTensor(const Tensor &tensor); + static Tensor GeTensor2Tensor(const ConstGeTensorPtr &geTensor); + + static ConstGeTensorPtr AsGeTensorPtr(const Tensor &tensor); // Share value + static GeTensorPtr AsGeTensorPtr(Tensor &tensor); // Share value + static const GeTensor AsGeTensor(const Tensor &tensor); // Share value + static GeTensor AsGeTensor(Tensor &tensor); // Share value + static const Tensor AsTensor(const GeTensor &tensor); // Share value + static Tensor AsTensor(GeTensor &tensor); // Share value +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ diff --git a/metadef/inc/graph/utils/tensor_utils.h b/metadef/inc/graph/utils/tensor_utils.h new file mode 100644 index 00000000..776933a8 --- /dev/null +++ b/metadef/inc/graph/utils/tensor_utils.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_TENSOR_UTILS_H_ +#define INC_GRAPH_UTILS_TENSOR_UTILS_H_ + +#include +#include "graph/def_types.h" +#include "graph/ge_error_codes.h" +#include "graph/ge_tensor.h" + +namespace ge { +class TensorUtils { + public: + static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size); + static void SetSize(GeTensorDesc &tensorDesc, int64_t size); + static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); + static uint32_t GetWeightSize(const GeTensor &tensor); + static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); + static uint8_t *GetWeightAddr(const ConstGeTensorPtr &tensorPtr, uint8_t *base); + static uint8_t *GetWeightAddr(const GeTensor &tensor, uint8_t *base); + static void SetWeightSize(GeTensorDesc &tensorDesc, uint32_t size); + static ge::graphStatus GetReuseInput(const GeTensorDesc &tensorDesc, bool &flag); + static void SetReuseInput(GeTensorDesc &tensorDesc, bool flag); + static ge::graphStatus GetOutputTensor(const GeTensorDesc &tensorDesc, bool &flag); + static void SetOutputTensor(GeTensorDesc &tensorDesc, bool flag); + static graphStatus GetDeviceType(const GeTensorDesc &tensorDesc, DeviceType &type); + static void SetDeviceType(GeTensorDesc &tensorDesc, DeviceType type); + static ge::graphStatus GetInputTensor(const GeTensorDesc &tensorDesc, bool &flag); + static void SetInputTensor(GeTensorDesc &tensorDesc, bool flag); + static ge::graphStatus GetRealDimCnt(const GeTensorDesc &tensorDesc, uint32_t &cnt); + static void SetRealDimCnt(GeTensorDesc &tensorDesc, uint32_t cnt); + static ge::graphStatus GetReuseInputIndex(const GeTensorDesc &tensorDesc, uint32_t &idx); + static void SetReuseInputIndex(GeTensorDesc &tensorDesc, uint32_t idx); + static ge::graphStatus GetDataOffset(const GeTensorDesc &tensorDesc, int64_t &offset); + static void SetDataOffset(GeTensorDesc &tensorDesc, int64_t offset); + static ge::graphStatus GetCmpsSize(const GeTensorDesc &tensorDesc, uint32_t &cmp_size); + static void SetCmpsSize(GeTensorDesc &tensorDesc, uint32_t cmp_size); + static ge::graphStatus GetCmpsTab(const GeTensorDesc &tensorDesc, vector &vec); + static void SetCmpsTab(GeTensorDesc &tensorDesc, const uint8_t *data, size_t size); + static ge::graphStatus GetCmpsTabOffset(const GeTensorDesc &tensorDesc, int64_t &tab_offset); + static void SetCmpsTabOffset(GeTensorDesc &tensorDesc, int64_t tab_offset); + static ge::graphStatus GetCmpsInfo(const GeTensorDesc &tensorDesc, CompressInfo &info); + static void SetCmpsInfo(GeTensorDesc &tensorDesc, const CompressInfo &info); + static bool HasAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc); + static ge::graphStatus GetAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc, AllOffsetQuantizeInfo &info); + static void SetAlloffsetQuantizeInfo(GeTensorDesc &tensorDesc, const AllOffsetQuantizeInfo &info); + static ge::graphStatus GetRC(const GeTensorDesc &tensorDesc, uint32_t &rc); + static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); + + /// + /// calculate tensor mem size. + /// @param shape tensor shape + /// @param format tensor format + /// @param data_type tensor data type + /// @param mem_size -1 means unknown shape,other means mem size + /// @return GRAPH_SUCCESS:success, other:failed + /// + static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); + static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); + static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ diff --git a/metadef/inc/graph/utils/type_utils.h b/metadef/inc/graph/utils/type_utils.h new file mode 100644 index 00000000..81382764 --- /dev/null +++ b/metadef/inc/graph/utils/type_utils.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_GRAPH_UTILS_TYPE_UTILS_H_ +#define INC_GRAPH_UTILS_TYPE_UTILS_H_ + +#include +#include +#include +#include "graph/def_types.h" +#include "graph/ge_error_codes.h" +#include "graph/types.h" +#include "graph/usr_types.h" +#include "register/register_types.h" +#include "external/register/register_fmk_types.h" + +namespace ge { +class TypeUtils { + public: + static bool IsDataTypeValid(DataType dt); + static bool IsFormatValid(Format format); + static bool IsDataTypeValid(std::string dt); // for user json input + static bool IsFormatValid(std::string format); // for user json input + static bool IsInternalFormat(Format format); + + static std::string ImplyTypeToSerialString(domi::ImplyType imply_type); + static std::string DataTypeToSerialString(DataType data_type); + static DataType SerialStringToDataType(const std::string &str); + static std::string FormatToSerialString(Format format); + static Format SerialStringToFormat(const std::string &str); + static Format DataFormatToFormat(const std::string &str); + static Format DomiFormatToFormat(domi::domiTensorFormat_t domi_format); + static std::string FmkTypeToSerialString(domi::FrameworkType fmk_type); + + static graphStatus Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def); + static graphStatus Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr); + + static bool GetDataTypeLength(ge::DataType data_type, uint32_t &length); + static bool CheckUint64MulOverflow(uint64_t a, uint32_t b); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_TYPE_UTILS_H_ diff --git a/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h b/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h new file mode 100644 index 00000000..bfed3549 --- /dev/null +++ b/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h @@ -0,0 +1,80 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_CONSTANT_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_CONSTANT_H_ +#include +#include + +namespace fe { +// add the op pattern +static const std::string TBE_PATTERN_INPUT_NODE = "InputData"; +static const std::string TBE_PATTERN_OP_TYPE_ANY = "OpTypeAny"; +static const std::string TBE_PATTERN_OUTPUT_NODE = "OutputData"; +static const std::string OP_PATTERN_ELEMWISE = "ElemWise"; +static const std::string OP_PATTERN_COMMONREDUCE = "CommReduce"; +static const std::string OP_PATTERN_SEGMENT = "Segment"; +static const std::string OP_PATTERN_MAXPOOL = "MaxPool"; +static const std::string OP_PATTERN_CONV = "Convolution"; +static const std::string OP_PATTERN_MATMUL = "Matmul"; +static const std::string OP_PATTERN_BNUPDATE = "bn_update"; +static const std::string OP_PATTERN_BNREDUCE = "bn_reduce"; +static const std::string OP_PATTERN_CONV_BACKPROP_INPUT = "Conv2d_backprop_input"; +static const std::string OP_PATTERN_DEPTHWISE_CONV = "DepthwiseConvolution"; +static const std::string OP_PATTERN_QUANT = "quant"; +static const std::string OP_PATTERN_DEQUANT = "dequant"; +static const std::string OP_PATTERN_REQUANT = "requant"; +static const std::string OP_PATTERN_POOL2D = "Pool2d"; +static const std::string OP_PATTERN_ANTIQUANT = "anti_quant"; +static const std::string OP_PATTERN_STRIDED_WRITE = "strided_write"; +static const std::string OP_PATTERN_STRIDED_READ = "strided_read"; +static const std::string OP_PATTERN_AIPP = "aipp"; +static const std::string OP_PATTERN_CONFUSION_TRANSPOSE = "confusiontranspose"; +static const std::string OP_PATTERN_DEQUANTS16 = "dequant_s16"; +static const std::string OP_PATTERN_REQUANTS16 = "requant_s16"; +static const std::string OP_PATTERN_READ_SELECT = "read_select"; +static const std::string OP_PATTERN_WRITE_SELECT = "write_select"; +static const std::string OP_PATTERN_BATCH_MATMUL = "BatchMatmul"; +static const std::string OP_PATTERN_CONV3D = "Conv3d"; + +static const std::vector OP_PATTERN_VEC{OP_PATTERN_ELEMWISE, + OP_PATTERN_COMMONREDUCE, + OP_PATTERN_SEGMENT, + OP_PATTERN_MAXPOOL, + OP_PATTERN_CONV, + OP_PATTERN_MATMUL, + OP_PATTERN_BNUPDATE, + OP_PATTERN_BNREDUCE, + OP_PATTERN_CONV_BACKPROP_INPUT, + OP_PATTERN_DEPTHWISE_CONV, + OP_PATTERN_QUANT, + OP_PATTERN_DEQUANT, + OP_PATTERN_REQUANT, + OP_PATTERN_POOL2D, + OP_PATTERN_ANTIQUANT, + OP_PATTERN_STRIDED_WRITE, + OP_PATTERN_STRIDED_READ, + OP_PATTERN_AIPP, + OP_PATTERN_CONFUSION_TRANSPOSE, + OP_PATTERN_DEQUANTS16, + OP_PATTERN_REQUANTS16, + OP_PATTERN_READ_SELECT, + OP_PATTERN_WRITE_SELECT, + OP_PATTERN_BATCH_MATMUL, + OP_PATTERN_CONV3D}; +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_CONSTANT_H_ diff --git a/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h b/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h new file mode 100644 index 00000000..4a860550 --- /dev/null +++ b/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h @@ -0,0 +1,57 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_BASE_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_BASE_H_ + +#include +#include +#include +#include +#include "register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h" +#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +namespace fe { +enum BufferFusionPassType { + BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, + BUILT_IN_VECTOR_CORE_BUFFER_FUSION_PASS, + CUSTOM_AI_CORE_BUFFER_FUSION_PASS, + CUSTOM_VECTOR_CORE_BUFFER_FUSION_PASS, + BUFFER_FUSION_PASS_TYPE_RESERVED +}; + +class BufferFusionPassBase { + public: + explicit BufferFusionPassBase(); + virtual ~BufferFusionPassBase(); + virtual std::vector DefinePatterns() = 0; + virtual Status GetFusionNodes(const BufferFusionMapping &mapping, vector &fusion_nodes); + std::vector GetMatchedNodes(const BufferFusionMapping &mapping); + std::vector GetMatchedNodesByDescName(const std::string &desc_name, const BufferFusionMapping &mapping); + ge::NodePtr GetMatchedHeadNode(const std::vector &matched_nodes); + + void SetName(const string &name) { name_ = name; } + + string GetName() { return name_; } + + private: + string name_; +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_BASE_H_ diff --git a/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h b/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h new file mode 100644 index 00000000..92c6a70e --- /dev/null +++ b/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_REGISTRY_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_REGISTRY_H_ +#include +#include +#include +#include +#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h" + +namespace fe { +class BufferFusionPassRegistry { + public: + using CreateFn = BufferFusionPassBase *(*)(); + ~BufferFusionPassRegistry(); + + static BufferFusionPassRegistry &GetInstance(); + + void RegisterPass(const BufferFusionPassType &pass_type, const std::string &pass_name, CreateFn create_fn); + + std::map GetCreateFnByType(const BufferFusionPassType &pass_type); + + private: + BufferFusionPassRegistry(); + class BufferFusionPassRegistryImpl; + std::unique_ptr impl_; +}; + +class BufferFusionPassRegistrar { + public: + BufferFusionPassRegistrar(const BufferFusionPassType &pass_type, const std::string &pass_name, + BufferFusionPassBase *(*create_fun)()); + ~BufferFusionPassRegistrar() {} +}; + +#define REGISTER_BUFFER_FUSION_PASS(pass_name, pass_type, pass_class) \ + REGISTER_BUFFER_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, pass_type, pass_class) + +#define REGISTER_BUFFER_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, pass_type, pass_class) \ + REGISTER_BUFFER_FUSION_PASS_UNIQ(ctr, pass_name, pass_type, pass_class) + +#define REGISTER_BUFFER_FUSION_PASS_UNIQ(ctr, pass_name, pass_type, pass_class) \ + static ::fe::BufferFusionPassRegistrar register_buffer_fusion_pass##ctr __attribute__((unused)) = \ + ::fe::BufferFusionPassRegistrar( \ + pass_type, pass_name, []() -> ::fe::BufferFusionPassBase * { return new (std::nothrow) pass_class(); }) + +} // namespace fe +#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_REGISTRY_H_ diff --git a/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h b/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h new file mode 100644 index 00000000..989702ea --- /dev/null +++ b/metadef/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h @@ -0,0 +1,96 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PATTERN_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PATTERN_H_ +#include +#include +#include +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" + +namespace fe { +static const int TBE_FUSION_OP_NUM_MAX = 5; +static const int TBE_PATTERN_NUM_MAX = 5; +static const int TBE_PATTERN_NUM_NONE = 0; +static const int TBE_PATTERN_NUM_DEFAULT = 1; +static const int TBE_OUTPUT_BRANCH_DEFAULT = 0; +static const int TBE_OUTPUT_BRANCH_SINGLE = 1; +static const int TBE_OUTPUT_BRANCH_MULTI = 2; +static const int TBE_PATTERN_GROUPID_INVALID = -1; + +enum SkipStatus { DISABLED = 0, AVAILABLE = 1, SKIPPED = 2 }; + +struct BufferFusionOpDesc { + std::string desc_name; // description name + std::vector types; // description type + std::vector inputs; // all input op + std::vector outputs; // all output op + int64_t out_branch_type; // out desc type, 1:single, 2: multi + int64_t repeate_min; // opdesc min repeat num + int64_t repeate_max; // opdesc max repeat num + int64_t repeate_curr; // opdesc current repeat num + bool match_status; + int64_t group_id; // record desc groupid, need one desc matched at least in + // the same group + bool ignore_input_num; + bool ignore_output_num; + // used for two connected op, first opdesc has optional multiple nodes and + // ignore_output_num is true, second opdesc is same pattern type and + // out_branch_type is TBE_OUTPUT_BRANCH_MULTI + std::map multi_output_skip_status; +}; +using BufferFusionMapping = std::map>; +using BufferFusionMappings = std::vector; + +class BufferFusionPattern { + public: + explicit BufferFusionPattern(std::string name = "", int64_t op_max_count = TBE_FUSION_OP_NUM_MAX); + + virtual ~BufferFusionPattern(); + + BufferFusionPattern &AddOpDesc(const std::string &desc_name, const std::vector &patterns, + int64_t repeat_min = TBE_PATTERN_NUM_DEFAULT, + int64_t repeat_max = TBE_PATTERN_NUM_DEFAULT, + int64_t group_id = TBE_PATTERN_GROUPID_INVALID); + + BufferFusionPattern &SetOutputs(const std::string &desc_name, const std::vector &patterns, + int64_t relation = TBE_OUTPUT_BRANCH_SINGLE, bool ignore_input_num = false, + bool ignore_output_num = false); + + BufferFusionPattern &SetHead(const std::vector &op_patterns); + + std::string GetName(); + int64_t GetOpMaxCount(); + std::vector GetOpDescs(); + bool GetOutputs(BufferFusionOpDesc *op_desc, std::vector &outputs, bool ignore_repeat = false); + std::vector GetHead(); + int64_t GetErrorCnt(); + void InitRepeatCurr(const BufferFusionPattern &pattern); + + private: + BufferFusionOpDesc *GetOpDesc(const std::string &desc_name); + void UpdateSkipStatus(BufferFusionOpDesc *op_desc); + std::string name_; + int64_t op_max_count_; + std::vector ops_; + std::map op_map_; + std::vector head_; + int64_t error_count_; +}; +} // namespace fe +#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PATTERN_H_ diff --git a/metadef/inc/register/graph_optimizer/fusion_common/fusion_statistic_recorder.h b/metadef/inc/register/graph_optimizer/fusion_common/fusion_statistic_recorder.h new file mode 100644 index 00000000..0927ec36 --- /dev/null +++ b/metadef/inc/register/graph_optimizer/fusion_common/fusion_statistic_recorder.h @@ -0,0 +1,96 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_STATISTIC_RECORDER_H +#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_STATISTIC_RECORDER_H + +#include +#include +#include +#include +#include +#include + +namespace fe { + +class FusionInfo { + public: + explicit FusionInfo(uint64_t session_id = 0, std::string graph_id = "", std::string pass_name = "", + int32_t match_times = 0, int32_t effect_times = 0); + + virtual ~FusionInfo(); + + void AddMatchTimes(int32_t match_times); + + void AddEffectTimes(int32_t effect_times); + + int32_t GetMatchTimes(); + + int32_t GetEffectTimes(); + + std::string GetGraphId(); + + std::string GetPassName(); + + uint64_t GetSessionId(); + + void SetMatchTimes(int32_t match_times); + + void SetEffectTimes(int32_t effect_times); + + private: + uint64_t session_id_; + std::string graph_id_; + std::string pass_name_; + int32_t match_times_; + int32_t effect_times_; +}; + +using FusionStatisticMap = std::map>; + +class FusionStatisticRecorder { + public: + FusionStatisticRecorder(const FusionStatisticRecorder &) = delete; + + FusionStatisticRecorder &operator=(const FusionStatisticRecorder &) = delete; + + static FusionStatisticRecorder &Instance(); + + void UpdateGraphFusionMatchTimes(FusionInfo &fusion_info); + + void UpdateGraphFusionEffectTimes(FusionInfo &fusion_info); + + void UpdateBufferFusionMatchTimes(FusionInfo &fusion_info); + + void UpdateBufferFusionEffectTimes(FusionInfo &fusion_info); + + void GetAndClearFusionInfo(const std::string &session_graph_id, + std::map &graph_fusion_info_map, + std::map &buffer_fusion_info_map); + + private: + FusionStatisticRecorder(); + virtual ~FusionStatisticRecorder(); + FusionStatisticMap graph_fusion_info_map_; + FusionStatisticMap buffer_fusion_info_map_; + void GetFusionInfo(const std::string &session_graph_id, std::map &graph_fusion_info_map, + std::map &buffer_fusion_info_map); + void ClearFusionInfo(std::string session_graph_id); + std::recursive_mutex mutex_; +}; +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_STATISTIC_RECORDER_H diff --git a/metadef/inc/register/graph_optimizer/fusion_common/graph_pass_util.h b/metadef/inc/register/graph_optimizer/fusion_common/graph_pass_util.h new file mode 100644 index 00000000..a1617c43 --- /dev/null +++ b/metadef/inc/register/graph_optimizer/fusion_common/graph_pass_util.h @@ -0,0 +1,224 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_UTIL_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_UTIL_H_ +#include "graph/compute_graph.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/type_utils.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +#include +#include +#include +#include +#include + +namespace fe { +using NodeTypeMap = std::unordered_map>; +using NodeTypeMapPtr = std::shared_ptr; +struct NodeMapInfo { + int64_t run_count; + NodeTypeMapPtr node_type_map; +}; +using NodeMapInfoPtr = std::shared_ptr; + +/** @brief define graph pass, which provides two interface: 1. run pass; +* 2. record op names before fusion */ +class GraphPassUtil { + public: + /** set outputdesc attr for data dump + * + * @param origin_index,usually is origin node output index + * + * @param fusion_index,usually is fusion node output index + * + * @param origin_node, usually is origin node + * + * @param fusion_node, usually is fusion node + */ + static void SetOutputDescAttr(uint32_t origin_index, uint32_t fusion_index, ge::NodePtr origin_node, + ge::NodePtr fusion_node) { + if (fusion_node->GetOpDesc() == nullptr) { + return; + } + + auto fusion_node_output_desc = fusion_node->GetOpDesc()->MutableOutputDesc(fusion_index); + if (fusion_node_output_desc == nullptr) { + return; + } + if (origin_node->GetOpDesc() == nullptr) { + return; + } + auto origin_node_output_desc = origin_node->GetOpDesc()->MutableOutputDesc(origin_index); + if (origin_node_output_desc == nullptr) { + return; + } + + std::vector original_names; + if (ge::AttrUtils::GetListStr(origin_node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names) && + original_names.size() > 0) { + std::string original_name; + if (ge::AttrUtils::GetStr(origin_node_output_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, original_name)) { + (void)ge::AttrUtils::SetStr(fusion_node_output_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, original_name); + + std::int64_t origin_output_index = 0; + if (ge::AttrUtils::GetInt(origin_node_output_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, + origin_output_index)) { + (void)ge::AttrUtils::SetInt(fusion_node_output_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, + origin_output_index); + } + + ge::DataType origin_data_type = GetDataDumpOriginDataType(origin_node_output_desc); + if (origin_data_type != ge::DT_UNDEFINED) { + SetDataDumpOriginDataType(origin_data_type, fusion_node_output_desc); + } + ge::Format origin_format = GetDataDumpOriginFormat(origin_node_output_desc); + if (origin_format != ge::FORMAT_RESERVED) { + SetDataDumpOriginFormat(origin_format, fusion_node_output_desc); + } + } + } else { + (void)ge::AttrUtils::SetStr(fusion_node_output_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_node->GetName()); + (void)ge::AttrUtils::SetInt(fusion_node_output_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_index); + SetDataDumpOriginDataType(origin_node_output_desc->GetOriginDataType(), fusion_node_output_desc); + SetDataDumpOriginFormat(origin_node_output_desc->GetOriginFormat(), fusion_node_output_desc); + } + } + + /** get origin format for data dump + * + * @param tensor_desc,usually is output_desc + * + * @return format of this tensor_desc + */ + static ge::Format GetDataDumpOriginFormat(ge::GeTensorDescPtr tensor_desc) { + std::string origin_format_str; + if (!ge::AttrUtils::GetStr(tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, origin_format_str)) { + // Can not get the certificate and it's not set,return directly + return ge::FORMAT_RESERVED; + } + if (origin_format_str == "RESERVED") { + return ge::FORMAT_RESERVED; + } + return ge::TypeUtils::SerialStringToFormat(origin_format_str); + } + + /** set origin format for data dump + * + * @param origin format + * + * @param tensor_desc,usually is output_desc + */ + static void SetDataDumpOriginFormat(ge::Format origin_format, ge::GeTensorDescPtr tensor_desc) { + std::string origin_format_str = "RESERVED"; + if (origin_format != ge::FORMAT_RESERVED) { + origin_format_str = ge::TypeUtils::FormatToSerialString(origin_format); + } + (void)ge::AttrUtils::SetStr(tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, origin_format_str); + } + + /** set origin datatype for data dump + * + * @param origin datatype + * + * @param tensor_desc,usually is output_desc + */ + static void SetDataDumpOriginDataType(ge::DataType origin_data_type, ge::GeTensorDescPtr tensor_desc) { + std::string origin_data_type_str = "RESERVED"; + if (origin_data_type != ge::DT_UNDEFINED) { + origin_data_type_str = ge::TypeUtils::DataTypeToSerialString(origin_data_type); + } + (void)ge::AttrUtils::SetStr(tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, origin_data_type_str); + } + + /** get origin datatype for data dump + * + * @param tensor_desc,usually is output_desc + * + * @return format of this tensor_desc + */ + static ge::DataType GetDataDumpOriginDataType(ge::GeTensorDescPtr tensor_desc) { + std::string origin_data_type_str; + if (!ge::AttrUtils::GetStr(tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, origin_data_type_str)) { + return ge::DT_UNDEFINED; + } + if (origin_data_type_str == "RESERVED") { + return ge::DT_UNDEFINED; + } + return ge::TypeUtils::SerialStringToDataType(origin_data_type_str); + } + + static void AddNodeFromOpTypeMap(NodeMapInfoPtr &node_map_info, ge::NodePtr &node_ptr) { + if (node_map_info == nullptr || node_ptr == nullptr) { + return; + } + NodeTypeMapPtr node_type_map = node_map_info->node_type_map; + string real_op_type = ge::NodeUtils::GetNodeType(*node_ptr); + auto iter = node_type_map->find(real_op_type); + if (iter != node_type_map->end()) { + iter->second.insert(node_ptr); + } else { + node_type_map->emplace(std::make_pair(real_op_type, std::unordered_set{node_ptr})); + } + } + + static Status GetOpTypeMapToGraph(NodeMapInfoPtr &node_map_info, const ge::ComputeGraph &graph) { + node_map_info = graph.TryGetExtAttr("NodeMapInfo", node_map_info); + if (node_map_info == nullptr) { + return FAILED; + } + return SUCCESS; + } + + static void RecordOriginalNames(std::vector original_nodes, ge::NodePtr node) { + // 1. get the original_names + std::vector original_names; + for (ge::NodePtr original_node : original_nodes) { + if (original_node == nullptr || original_node->GetOpDesc() == nullptr) { + return; + } + + ge::OpDescPtr origin_op_desc_ptr = original_node->GetOpDesc(); + std::vector names_tmp; + bool is_has_attr = ge::AttrUtils::GetListStr(origin_op_desc_ptr, + ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, + names_tmp); + if (is_has_attr) { + for (const auto &node_name : names_tmp) { + if (!node_name.empty()) { + original_names.push_back(node_name); + } + } + } else { + original_names.push_back(origin_op_desc_ptr->GetName()); + } + } + + // 2. set the dump attr + if (node == nullptr || node->GetOpDesc() == nullptr) { + return; + } + ge::OpDescPtr node_op_desc_ptr = node->GetOpDesc(); + (void)ge::AttrUtils::SetListStr(node_op_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); + } +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_UTIL_H_ diff --git a/metadef/inc/register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h b/metadef/inc/register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h new file mode 100644 index 00000000..f507ad4e --- /dev/null +++ b/metadef/inc/register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h @@ -0,0 +1,106 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_PATTERN_FUSION_BASE_PASS_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_PATTERN_FUSION_BASE_PASS_H_ + +#include +#include +#include +#include +#include +#include "common/opskernel/ops_kernel_info_store.h" +#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" +#include "register/graph_optimizer/graph_fusion/graph_pass.h" + +using std::initializer_list; +using std::map; +using std::string; +using std::vector; + +using namespace std; + +namespace fe { +using OpsKernelInfoStorePtr = std::shared_ptr; +class PatternFusionBasePassImpl; +using PatternFusionBasePassImplPtr = std::shared_ptr; + +/** Pass based on pattern + * @ingroup FUSION_PASS_GROUP + * @note New virtual methods should be append at the end of this class + */ +class PatternFusionBasePass : public GraphPass { + public: + using OpDesc = FusionPattern::OpDesc; + using Mapping = map, vector>; + using Mappings = vector; + + PatternFusionBasePass(); + virtual ~PatternFusionBasePass(); + + /** execute pass + * + * @param [in] graph, the graph waiting for pass level optimization + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + Status Run(ge::ComputeGraph &graph) override; + + /** execute pass + * + * @param [in] graph, the graph waiting for pass level optimization + * @param [ops_kernel_info_store_ptr, OP info kernel instance + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + virtual Status Run(ge::ComputeGraph &graph, OpsKernelInfoStorePtr ops_kernel_info_store_ptr); + + protected: + virtual vector DefinePatterns() = 0; + virtual Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) = 0; + + std::vector GetNodesFromMapping(const Mapping &mapping); + ge::NodePtr GetNodeFromMapping(const string &id, const Mapping &mapping); + + void RecordOutputAnchorMap(ge::NodePtr output_node); + void ClearOutputAnchorMap(); + + Status SetDataDumpAttr(vector &original_nodes, vector &fus_nodes); + + bool CheckOpSupported(const ge::OpDescPtr &op_desc_ptr); + + private: + /** match all nodes in graph according to pattern + * + * @param pattern fusion pattern defined + * @param mappings match result + * @return SUCCESS, successfully add edge + * @return FAILED, fail + */ + bool MatchAll(ge::ComputeGraph &graph, const FusionPattern &pattern, Mappings &mappings); + + Status RunOnePattern(ge::ComputeGraph &graph, const FusionPattern &pattern, bool &changed); // lint !e148 + + /** Internal implement class ptr */ + std::shared_ptr pattern_fusion_base_pass_impl_ptr_; + + std::unordered_map> origin_op_anchors_map_; +}; +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_PATTERN_FUSION_BASE_PASS_H_ diff --git a/metadef/inc/register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h b/metadef/inc/register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h new file mode 100644 index 00000000..ba37ffaf --- /dev/null +++ b/metadef/inc/register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PASS_REGISTRY_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PASS_REGISTRY_H_ + +#include +#include +#include +#include +#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" + +namespace fe { +class FusionPassRegistry { + public: + using CreateFn = GraphPass *(*)(); + ~FusionPassRegistry(); + + static FusionPassRegistry &GetInstance(); + + void RegisterPass(const GraphFusionPassType &pass_type, const std::string &pass_name, CreateFn create_fn); + + std::map GetCreateFnByType(const GraphFusionPassType &pass_type); + + private: + FusionPassRegistry(); + class FusionPassRegistryImpl; + std::unique_ptr impl_; +}; + +class FusionPassRegistrar { + public: + FusionPassRegistrar(const GraphFusionPassType &pass_type, const std::string &pass_name, GraphPass *(*create_fun)()); + ~FusionPassRegistrar() {} +}; + +#define REGISTER_PASS(pass_name, pass_type, pass_class) \ + REGISTER_PASS_UNIQ_HELPER(__COUNTER__, pass_name, pass_type, pass_class) + +#define REGISTER_PASS_UNIQ_HELPER(ctr, pass_name, pass_type, pass_class) \ + REGISTER_PASS_UNIQ(ctr, pass_name, pass_type, pass_class) + +#define REGISTER_PASS_UNIQ(ctr, pass_name, pass_type, pass_class) \ + static ::fe::FusionPassRegistrar register_fusion_pass##ctr __attribute__((unused)) = ::fe::FusionPassRegistrar( \ + pass_type, pass_name, []() -> ::fe::GraphPass * { return new (std::nothrow) pass_class(); }) + +} // namespace fe +#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PASS_REGISTRY_H_ diff --git a/metadef/inc/register/graph_optimizer/graph_fusion/fusion_pattern.h b/metadef/inc/register/graph_optimizer/graph_fusion/fusion_pattern.h new file mode 100644 index 00000000..7de5eb78 --- /dev/null +++ b/metadef/inc/register/graph_optimizer/graph_fusion/fusion_pattern.h @@ -0,0 +1,173 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_ +#include +#include +#include +#include +#include + +using std::initializer_list; +using std::map; +using std::string; +using std::vector; + +using namespace std; + +namespace fe { + +/** Fusion pattern + * @ingroup FUSION_PASS_GROUP + * Describe Pattern of Ops waiting for fusion(Op type, etc) + */ +class FusionPattern { + public: + struct OpDesc; + using OpDescPtr = std::shared_ptr; + /** + * @ingroup fe + * @brief description of Ops + */ + struct OpDesc { + string id; // Identifier + std::vector types; // the Op types of Ops + std::vector inputs; // all input Ops + bool repeatable; // flag to show if match multiple Ops or not + bool is_output; // flag to show if the op is output node + }; + + public: + explicit FusionPattern(string name = ""); + ~FusionPattern(); + + /** set pattern name + * + * @param name pattern name + * @return FusionPattern + */ + FusionPattern &SetName(const string &name); + + /** add Op description with unknown number of args + * + * @param id pattern id + * @param types op type list + * @return FusionPattern + */ + FusionPattern &AddOpDesc(const string &id, const initializer_list &types = {}); + + /** add Op description with vector + * + * @param id pattern id + * @param types op type list + * + * @return FusionPattern + */ + FusionPattern &AddOpDesc(const string &id, const vector &types); + + /** set input Ops with unknown number of args + * + * @param id pattern id + * + * @param input_ids inputs to id op + * + * @return FusionPattern + */ + FusionPattern &SetInputs(const string &id, const initializer_list &input_ids); + + /** set input Ops with unknown number of args + * + * @param id pattern id + * + * @param input_ids inputs to id op + * + * @return FusionPattern + */ + FusionPattern &SetInputs(const string &id, const vector &input_ids); + + /** set output Op + * + * @param id pattern id + * + * @return FusionPattern + */ + FusionPattern &SetOutput(const string &id); + + /** build pattern and check if error exists + * + * @return True or False + */ + bool Build(); + + /** get pattern name + * + * @param id pattern id + * + * @return fusion pattern name + */ + const string &GetName() const; + + /** get the OpDesc of input Ops (const) + * + * @param op_desc op_desc for getting inputs + * + * @return op_desc's iniput opdesc list + */ + static const vector> *GetInputs(std::shared_ptr op_desc); + + /** get the OpDesc of output Op + * + * @return pattern's output opdesc list + */ + const std::shared_ptr GetOutput() const; + + /** print pattern + * + */ + void Dump() const; + + void GetOpDescList(vector> &op_desc_list); + + /** get OpDesc based on ID, return nullptr if failed + * + * @param id pattern id + * + * @return pattern's output opdesc list + */ + std::shared_ptr GetOpDesc(const string &id) const; + + private: + FusionPattern(const FusionPattern &) = default; + FusionPattern &operator=(const FusionPattern &) = default; + + void SetError(); + + private: + string name_; + + vector> ops_; + + map> op_map_; + + std::shared_ptr output_; + + bool has_error_; +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_ diff --git a/metadef/inc/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h b/metadef/inc/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h new file mode 100644 index 00000000..bd957cf8 --- /dev/null +++ b/metadef/inc/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h @@ -0,0 +1,113 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_PASS_BASE_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_PASS_BASE_H_ + +#include +#include +#include +#include +#include + +#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" +#include "register/graph_optimizer/graph_fusion/graph_pass.h" + +using std::initializer_list; +using std::map; +using std::string; +using std::vector; + +using namespace std; + +namespace fe { +enum GraphFusionPassType { + BUILT_IN_GRAPH_PASS = 0, + BUILT_IN_VECTOR_CORE_GRAPH_PASS, + CUSTOM_AI_CORE_GRAPH_PASS, + CUSTOM_VECTOR_CORE_GRAPH_PASS, + SECOND_ROUND_BUILT_IN_GRAPH_PASS, + BUILT_IN_BEFORE_TRANSNODE_INSERTION_GRAPH_PASS, + GRAPH_FUSION_PASS_TYPE_RESERVED +}; +class PatternFusionBasePassImpl; +using PatternFusionBasePassImplPtr = std::shared_ptr; + +/** Pass based on pattern + * @ingroup FUSION_PASS_GROUP + * @note New virtual methods should be append at the end of this class + */ +class GraphFusionPassBase : public GraphPass { + public: + using OpDesc = FusionPattern::OpDesc; + using Mapping = map, vector>; + using Mappings = vector; + + GraphFusionPassBase(); + virtual ~GraphFusionPassBase(); + + /** execute pass + * + * @param [in] graph, the graph waiting for pass level optimization + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + Status Run(ge::ComputeGraph &graph) override; + + protected: + /** define pattern + * + * @return NA + */ + virtual vector DefinePatterns() = 0; + + /** do fusion according to nodes matched + * + * @param graph the graph waiting for pass level optimization + * @param new_nodes fusion result node(s) + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + virtual Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) = 0; // lint !e148 + + /** get nodes from matched result + * + * @param mapping match result + * @return nodes result + */ + static ge::NodePtr GetNodeFromMapping(const string &id, const Mapping &mapping); + + private: + /** match all nodes in graph according to pattern + * + * @param pattern fusion pattern defined + * @param mappings match result + * @return SUCCESS, successfully add edge + * @return FAILED, fail + */ + bool MatchAll(ge::ComputeGraph &graph, const FusionPattern &pattern, Mappings &mappings); + + Status RunOnePattern(ge::ComputeGraph &graph, const FusionPattern &pattern, bool &changed); // lint !e148 + + /** Internal implement class ptr */ + std::shared_ptr pattern_fusion_base_pass_impl_ptr_; +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_PASS_BASE_H_ diff --git a/metadef/inc/register/graph_optimizer/graph_fusion/graph_pass.h b/metadef/inc/register/graph_optimizer/graph_fusion/graph_pass.h new file mode 100644 index 00000000..dc4c6640 --- /dev/null +++ b/metadef/inc/register/graph_optimizer/graph_fusion/graph_pass.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_H_ + +#include +#include "register/graph_optimizer/graph_fusion/pass.h" + +namespace fe { + +/** graph pass + * @ingroup GRAPH_PASS_GROUP + * graph level pass + */ +class GraphPass : public Pass { + public: + /** execute pass + * + * @param [in] graph, the graph waiting for pass level optimization + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + virtual Status Run(ge::ComputeGraph &graph) = 0; +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_H_ diff --git a/metadef/inc/register/graph_optimizer/graph_fusion/pass.h b/metadef/inc/register/graph_optimizer/graph_fusion/pass.h new file mode 100644 index 00000000..4c1ebd60 --- /dev/null +++ b/metadef/inc/register/graph_optimizer/graph_fusion/pass.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** @defgroup FUSION_PASS_GROUP Fusion Pass Interface */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_PASS_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_PASS_H_ + +#include "graph/compute_graph.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +namespace fe { + +/** fusion pass + * @ingroup GRAPH_PASS_GROUP + * network level pass + */ +template +class Pass { + public: + virtual ~Pass() {} + + /** execute pass + * + * @param [in] graph, the graph waiting for pass level optimization + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + virtual Status Run(ge::ComputeGraph &graph) = 0; + + void SetName(const string &name) { name_ = name; } + + string GetName() { return name_; } + + private: + string name_; +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_PASS_H_ diff --git a/metadef/inc/register/graph_optimizer/graph_optimize_register_error_codes.h b/metadef/inc/register/graph_optimizer/graph_optimize_register_error_codes.h new file mode 100644 index 00000000..c6d95a9c --- /dev/null +++ b/metadef/inc/register/graph_optimizer/graph_optimize_register_error_codes.h @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZE_REGISTER_ERROR_CODES_H_ +#define INC_REGISTER_GRAPH_OPTIMIZE_REGISTER_ERROR_CODES_H_ + +#include +#include + +/** Assigned SYS ID */ +const uint8_t SYSID_FE = 3; + +/** Common module ID */ +const uint8_t FE_MODID_COMMON = 50; + +namespace fe { + +/** FE error code definiton Macro +* Build error code +*/ +#define FE_DEF_ERRORNO(sysid, modid, name, value, desc) \ + static constexpr fe::Status name = \ + (((((uint32_t)(0xFF & ((uint8_t)(sysid)))) << 24) | (((uint32_t)(0xFF & ((uint8_t)(modid)))) << 16)) | \ + (0xFFFF & ((uint16_t)(value)))); + +using Status = uint32_t; + +#define FE_DEF_ERRORNO_COMMON(name, value, desc) FE_DEF_ERRORNO(SYSID_FE, FE_MODID_COMMON, name, value, desc) + +using Status = uint32_t; + +FE_DEF_ERRORNO(0, 0, SUCCESS, 0, "success"); +FE_DEF_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFF, "failed"); +FE_DEF_ERRORNO_COMMON(NOT_CHANGED, 201, "The nodes of the graph not changed."); +FE_DEF_ERRORNO_COMMON(PARAM_INVALID, 1, "Parameter's invalid!"); + +} // namespace fe +#endif // INC_REGISTER_GRAPH_OPTIMIZE_REGISTER_ERROR_CODES_H_ diff --git a/metadef/inc/register/host_cpu_context.h b/metadef/inc/register/host_cpu_context.h new file mode 100644 index 00000000..f7d4f52f --- /dev/null +++ b/metadef/inc/register/host_cpu_context.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_HOST_CPU_CONTEXT_H_ +#define INC_REGISTER_HOST_CPU_CONTEXT_H_ + +#include "external/ge/ge_api_error_codes.h" +#include "register/register_types.h" + +namespace ge { +class HostCpuContext { + public: + HostCpuContext() = default; + ~HostCpuContext() = default; + private: + class Impl; + Impl *impl_; +}; +} // namespace ge + +extern "C" { +// Unified definition for registering host_cpu_kernel_wrapper when so is opened +FMK_FUNC_HOST_VISIBILITY ge::Status Initialize(const ge::HostCpuContext &ctx); +} + +#endif //INC_REGISTER_HOST_CPU_CONTEXT_H_ diff --git a/metadef/inc/register/infer_data_slice_registry.h b/metadef/inc/register/infer_data_slice_registry.h new file mode 100644 index 00000000..d9ccb34e --- /dev/null +++ b/metadef/inc/register/infer_data_slice_registry.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_INFER_DATA_SLICE_REGISTRY_H_ +#define INC_REGISTER_INFER_DATA_SLICE_REGISTRY_H_ + +#include "external/graph/ge_error_codes.h" +#include "external/graph/operator.h" + +namespace ge { +using InferDataSliceFunc = std::function; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferDataSliceFuncRegister { + public: + InferDataSliceFuncRegister(const char *operator_type, const InferDataSliceFunc &infer_data_slice_func); + ~InferDataSliceFuncRegister() = default; +}; + +// infer data slice func register +#define IMPLEMT_INFER_DATA_SLICE(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) + +#define INFER_DATA_SLICE_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } + +#define __INFER_DATA_SLICE_FUNC_REG_IMPL__(op_name, x, n) \ + static const InferDataSliceFuncRegister PASTE(ids_register, n)(#op_name, x) + +#define INFER_DATA_SLICE_FUNC_REG(op_name, x) \ + __INFER_DATA_SLICE_FUNC_REG_IMPL__(op_name, INFER_DATA_SLICE_FUNC(op_name, x), __COUNTER__) +} // namespace ge + +#endif // INC_REGISTER_INFER_DATA_SLICE_REGISTRY_H_ diff --git a/metadef/inc/register/op_kernel_registry.h b/metadef/inc/register/op_kernel_registry.h new file mode 100644 index 00000000..5fed8960 --- /dev/null +++ b/metadef/inc/register/op_kernel_registry.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_OP_KERNEL_REGISTRY_H_ +#define INC_REGISTER_OP_KERNEL_REGISTRY_H_ +#include +#include +#include "register/register_types.h" +#include "register.h" + +namespace ge { +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpKernelRegistry { + public: + using CreateFn = HostCpuOp* (*)(); + ~OpKernelRegistry(); + + static OpKernelRegistry& GetInstance() { + static OpKernelRegistry instance; + return instance; + } + + bool IsRegistered(const std::string &op_type); + + void RegisterHostCpuOp(const std::string &op_type, CreateFn create_fn); + + std::unique_ptr CreateHostCpuOp(const std::string &op_type); + + private: + OpKernelRegistry(); + class OpKernelRegistryImpl; + /*lint -e148*/ + std::unique_ptr impl_; +}; +} // namespace ge + +#endif // INC_REGISTER_OP_KERNEL_REGISTRY_H_ diff --git a/metadef/inc/register/op_registry.h b/metadef/inc/register/op_registry.h new file mode 100644 index 00000000..318eb3ba --- /dev/null +++ b/metadef/inc/register/op_registry.h @@ -0,0 +1,96 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_OP_REGISTRY_H_ +#define INC_REGISTER_OP_REGISTRY_H_ + +#include +#include +#include +#include +#include + +#include "register/register.h" + +namespace domi { +enum RemoveInputType { + OMG_MOVE_TYPE_DTYPE = 0, + OMG_MOVE_TYPE_VALUE, + OMG_MOVE_TYPE_SHAPE, + OMG_MOVE_TYPE_FORMAT, + OMG_MOVE_TYPE_AXIS, + OMG_MOVE_TYPE_SCALAR_VALUE, + OMG_REMOVE_TYPE_WITH_COND = 1000, + OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE, + OMG_INPUT_REORDER, +}; + +struct RemoveInputConfigure { + int inputIdx = INT_MAX; + std::string attrName; + RemoveInputType moveType; + bool attrValue = false; + std::string originalType; + std::vector input_order; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { + public: + static OpRegistry *Instance(); + + std::vector registrationDatas; + + bool Register(const OpRegistrationData ®_data); + + domi::ImplyType GetImplyType(const std::string &op_type); + + void GetOpTypeByImplyType(std::vector &vec_op_type, const domi::ImplyType &imply_type); + + domi::ParseParamFunc GetParseParamFunc(const std::string &op_type, const std::string &ori_type); + + domi::ParseParamByOpFunc GetParseParamByOperatorFunc(const std::string &ori_type); + + domi::FusionParseParamFunc GetFusionParseParamFunc(const std::string &op_type, const std::string &ori_type); + + domi::FusionParseParamByOpFunc GetFusionParseParamByOpFunc(const std::string &op_type, + const std::string &ori_type); + + domi::ParseSubgraphFunc GetParseSubgraphPostFunc(const std::string &op_type); + + Status GetParseSubgraphPostFunc(const std::string &op_type, domi::ParseSubgraphFuncV2 &parse_subgraph_func); + + domi::ImplyType GetImplyTypeByOriOpType(const std::string &ori_optype); + + const std::vector &GetRemoveInputConfigure(const std::string &ori_optype) const; + + bool GetOmTypeByOriOpType(const std::string &ori_optype, std::string &om_type); + + ParseOpToGraphFunc GetParseOpToGraphFunc(const std::string &op_type, const std::string &ori_type); + + private: + std::unordered_map op_run_mode_map_; + std::unordered_map op_parse_params_fn_map_; + std::unordered_map parse_params_by_op_func_map_; + std::unordered_map fusion_op_parse_params_fn_map_; + std::unordered_map fusion_parse_params_by_op_fn_map_; + std::unordered_map op_types_to_parse_subgraph_post_func_; + std::unordered_map> remove_input_configure_map_; + std::unordered_map origin_type_to_om_type_; + std::unordered_map parse_op_to_graph_fn_map_; + std::unordered_map op_types_to_parse_subgraph_post_func_v2_; +}; +} // namespace domi +#endif // INC_REGISTER_OP_REGISTRY_H_ diff --git a/metadef/inc/register/op_tiling.h b/metadef/inc/register/op_tiling.h new file mode 100644 index 00000000..ae26c378 --- /dev/null +++ b/metadef/inc/register/op_tiling.h @@ -0,0 +1,31 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_OP_TILING_H_ +#define INC_REGISTER_OP_TILING_H_ + +#include "graph/debug/ge_attr_define.h" +#include "graph/node.h" +#include "register/op_tiling_registry.h" + +namespace optiling { + +extern "C" ge::graphStatus OpParaCalculate(const ge::Node &node, OpRunInfo &run_info); +extern "C" ge::graphStatus OpAtomicCalculate(const ge::Node &node, OpRunInfo &run_info); + +} // namespace optiling + +#endif // INC_REGISTER_OP_TILING_H_ diff --git a/metadef/inc/register/ops_kernel_builder_registry.h b/metadef/inc/register/ops_kernel_builder_registry.h new file mode 100644 index 00000000..8a8f3a18 --- /dev/null +++ b/metadef/inc/register/ops_kernel_builder_registry.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_OPS_KERNEL_BUILDER_REGISTRY_H_ +#define INC_REGISTER_OPS_KERNEL_BUILDER_REGISTRY_H_ + +#include +#include "register/register_types.h" +#include "common/opskernel/ops_kernel_builder.h" + +namespace ge { +using OpsKernelBuilderPtr = std::shared_ptr; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpsKernelBuilderRegistry { + public: + ~OpsKernelBuilderRegistry(); + static OpsKernelBuilderRegistry &GetInstance(); + + void Register(const std::string &lib_name, const OpsKernelBuilderPtr &instance); + + void Unregister(const std::string &lib_name); + + void UnregisterAll(); + + const std::map &GetAll() const; + + private: + OpsKernelBuilderRegistry() = default; + std::map kernel_builders_; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpsKernelBuilderRegistrar { + public: + using CreateFn = OpsKernelBuilder *(*)(); + OpsKernelBuilderRegistrar(const std::string &kernel_lib_name, CreateFn fn); + ~OpsKernelBuilderRegistrar(); + +private: + std::string kernel_lib_name_; +}; + +#define REGISTER_OPS_KERNEL_BUILDER(kernel_lib_name, builder) \ + REGISTER_OPS_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_lib_name, builder) + +#define REGISTER_OPS_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_lib_name, builder) \ + REGISTER_OPS_KERNEL_BUILDER_UNIQ(ctr, kernel_lib_name, builder) + +#define REGISTER_OPS_KERNEL_BUILDER_UNIQ(ctr, kernel_lib_name, builder) \ + static ::ge::OpsKernelBuilderRegistrar register_op_kernel_builder_##ctr \ + __attribute__((unused)) = \ + ::ge::OpsKernelBuilderRegistrar(kernel_lib_name, []()->::ge::OpsKernelBuilder* { \ + return new (std::nothrow) builder(); \ + }) +} // namespace ge + +#endif // INC_REGISTER_OPS_KERNEL_BUILDER_REGISTRY_H_ diff --git a/metadef/inc/register/proto/caffe/caffe.proto b/metadef/inc/register/proto/caffe/caffe.proto new file mode 100644 index 00000000..3f45aae2 --- /dev/null +++ b/metadef/inc/register/proto/caffe/caffe.proto @@ -0,0 +1,1821 @@ +syntax = "proto2"; + +package domi.caffe; + +// Specifies the shape (dimensions) of a Blob. +message BlobShape { + repeated int64 dim = 1 [packed = true]; +} + +message BlobProto { + optional BlobShape shape = 7; + repeated float data = 5 [packed = true]; + repeated float diff = 6 [packed = true]; + repeated double double_data = 8 [packed = true]; + repeated double double_diff = 9 [packed = true]; + optional bytes int8_data = 10; + repeated int32 int32_data = 11 [packed = true]; + repeated uint64 uint64_data = 12 [packed = true]; + // 4D dimensions -- deprecated. Use "shape" instead. + optional int32 num = 1 [default = 0]; + optional int32 channels = 2 [default = 0]; + optional int32 height = 3 [default = 0]; + optional int32 width = 4 [default = 0]; +} + +// The BlobProtoVector is simply a way to pass multiple blobproto instances +// around. +message BlobProtoVector { + repeated BlobProto blobs = 1; +} + +message Datum { + optional int32 channels = 1; + optional int32 height = 2; + optional int32 width = 3; + // the actual image data, in bytes + optional bytes data = 4; + optional int32 label = 5; + // Optionally, the datum could also hold float data. + repeated float float_data = 6; + // If true data contains an encoded image that need to be decoded + optional bool encoded = 7 [default = false]; +} + +message FillerParameter { + // The filler type. + optional string type = 1 [default = 'constant']; + optional float value = 2 [default = 0]; // the value in constant filler + optional float min = 3 [default = 0]; // the min value in uniform filler + optional float max = 4 [default = 1]; // the max value in uniform filler + optional float mean = 5 [default = 0]; // the mean value in Gaussian filler + optional float std = 6 [default = 1]; // the std value in Gaussian filler + // The expected number of non-zero output weights for a given input in + // Gaussian filler -- the default -1 means don't perform sparsification. + optional int32 sparse = 7 [default = -1]; + // Normalize the filler variance by fan_in, fan_out, or their average. + // Applies to 'xavier' and 'msra' fillers. + enum VarianceNorm { + FAN_IN = 0; + FAN_OUT = 1; + AVERAGE = 2; + } + optional VarianceNorm variance_norm = 8 [default = FAN_IN]; +} + +message NetParameter { + optional string name = 1; // consider giving the network a name + // DEPRECATED. See InputParameter. The input blobs to the network. + repeated string input = 3; + // DEPRECATED. See InputParameter. The shape of the input blobs. + repeated BlobShape input_shape = 8; + + // 4D input dimensions -- deprecated. Use "input_shape" instead. + // If specified, for each input blob there should be four + // values specifying the num, channels, height and width of the input blob. + // Thus, there should be a total of (4 * #input) numbers. + repeated int32 input_dim = 4; + + // Whether the network will force every layer to carry out backward operation. + // If set False, then whether to carry out backward is determined + // automatically according to the net structure and learning rates. + optional bool force_backward = 5 [default = false]; + // The current "state" of the network, including the phase, level, and stage. + // Some layers may be included/excluded depending on this state and the states + // specified in the layers' include and exclude fields. + optional NetState state = 6; + + // Print debugging information about results while running Net::Forward, + // Net::Backward, and Net::Update. + optional bool debug_info = 7 [default = false]; + + // The layers that make up the net. Each of their configurations, including + // connectivity and behavior, is specified as a LayerParameter. + repeated LayerParameter layer = 100; // ID 100 so layers are printed last. + + // DEPRECATED: use 'layer' instead. + repeated V1LayerParameter layers = 2; +} + +// NOTE +// Update the next available ID when you add a new SolverParameter field. +// +// SolverParameter next available ID: 42 (last added: layer_wise_reduce) +message SolverParameter { + ////////////////////////////////////////////////////////////////////////////// + // Specifying the train and test networks + // + // Exactly one train net must be specified using one of the following fields: + // train_net_param, train_net, net_param, net + // One or more test nets may be specified using any of the following fields: + // test_net_param, test_net, net_param, net + // If more than one test net field is specified (e.g., both net and + // test_net are specified), they will be evaluated in the field order given + // above: (1) test_net_param, (2) test_net, (3) net_param/net. + // A test_iter must be specified for each test_net. + // A test_level and/or a test_stage may also be specified for each test_net. + ////////////////////////////////////////////////////////////////////////////// + + // Proto filename for the train net, possibly combined with one or more + // test nets. + optional string net = 24; + // Inline train net param, possibly combined with one or more test nets. + optional NetParameter net_param = 25; + + optional string train_net = 1; // Proto filename for the train net. + repeated string test_net = 2; // Proto filenames for the test nets. + optional NetParameter train_net_param = 21; // Inline train net params. + repeated NetParameter test_net_param = 22; // Inline test net params. + + // The states for the train/test nets. Must be unspecified or + // specified once per net. + // + // By default, all states will have solver = true; + // train_state will have phase = TRAIN, + // and all test_state's will have phase = TEST. + // Other defaults are set according to the NetState defaults. + optional NetState train_state = 26; + repeated NetState test_state = 27; + + // The number of iterations for each test net. + repeated int32 test_iter = 3; + + // The number of iterations between two testing phases. + optional int32 test_interval = 4 [default = 0]; + optional bool test_compute_loss = 19 [default = false]; + // If true, run an initial test pass before the first iteration, + // ensuring memory availability and printing the starting value of the loss. + optional bool test_initialization = 32 [default = true]; + optional float base_lr = 5; // The base learning rate + // the number of iterations between displaying info. If display = 0, no info + // will be displayed. + optional int32 display = 6; + // Display the loss averaged over the last average_loss iterations + optional int32 average_loss = 33 [default = 1]; + optional int32 max_iter = 7; // the maximum number of iterations + // accumulate gradients over `iter_size` x `batch_size` instances + optional int32 iter_size = 36 [default = 1]; + + // The learning rate decay policy. The currently implemented learning rate + // policies are as follows: + // - fixed: always return base_lr. + // - step: return base_lr * gamma ^ (floor(iter / step)) + // - exp: return base_lr * gamma ^ iter + // - inv: return base_lr * (1 + gamma * iter) ^ (- power) + // - multistep: similar to step but it allows non uniform steps defined by + // stepvalue + // - poly: the effective learning rate follows a polynomial decay, to be + // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) + // - sigmoid: the effective learning rate follows a sigmod decay + // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) + // + // where base_lr, max_iter, gamma, step, stepvalue and power are defined + // in the solver parameter protocol buffer, and iter is the current iteration. + optional string lr_policy = 8; + optional float gamma = 9; // The parameter to compute the learning rate. + optional float power = 10; // The parameter to compute the learning rate. + optional float momentum = 11; // The momentum value. + optional float weight_decay = 12; // The weight decay. + // regularization types supported: L1 and L2 + // controlled by weight_decay + optional string regularization_type = 29 [default = "L2"]; + // the stepsize for learning rate policy "step" + optional int32 stepsize = 13; + // the stepsize for learning rate policy "multistep" + repeated int32 stepvalue = 34; + + // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, + // whenever their actual L2 norm is larger. + optional float clip_gradients = 35 [default = -1]; + + optional int32 snapshot = 14 [default = 0]; // The snapshot interval + optional string snapshot_prefix = 15; // The prefix for the snapshot. + // whether to snapshot diff in the results or not. Snapshotting diff will help + // debugging but the final protocol buffer size will be much larger. + optional bool snapshot_diff = 16 [default = false]; + enum SnapshotFormat { + HDF5 = 0; + BINARYPROTO = 1; + } + optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; + // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. + enum SolverMode { + CPU = 0; + GPU = 1; + } + optional SolverMode solver_mode = 17 [default = GPU]; + // the device_id will that be used in GPU mode. Use device_id = 0 in default. + optional int32 device_id = 18 [default = 0]; + // If non-negative, the seed with which the Solver will initialize the Caffe + // random number generator -- useful for reproducible results. Otherwise, + // (and by default) initialize using a seed derived from the system clock. + optional int64 random_seed = 20 [default = -1]; + + // type of the solver + optional string type = 40 [default = "SGD"]; + + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam + optional float delta = 31 [default = 1e-8]; + // parameters for the Adam solver + optional float momentum2 = 39 [default = 0.999]; + + // RMSProp decay value + // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) + optional float rms_decay = 38 [default = 0.99]; + + // If true, print information about the state of the net that may help with + // debugging learning problems. + optional bool debug_info = 23 [default = false]; + + // If false, don't save a snapshot after training finishes. + optional bool snapshot_after_train = 28 [default = true]; + + // DEPRECATED: old solver enum types, use string instead + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + RMSPROP = 3; + ADADELTA = 4; + ADAM = 5; + } + // DEPRECATED: use type instead of solver_type + optional SolverType solver_type = 30 [default = SGD]; + + // Overlap compute and communication for data parallel training + optional bool layer_wise_reduce = 41 [default = true]; +} + +// A message that stores the solver snapshots +message SolverState { + optional int32 iter = 1; // The current iteration + optional string learned_net = 2; // The file that stores the learned net. + repeated BlobProto history = 3; // The history for sgd solvers + optional int32 current_step = 4 [default = 0]; // The current step for learning rate +} + +enum Phase { + TRAIN = 0; + TEST = 1; +} + +message NetState { + optional Phase phase = 1 [default = TEST]; + optional int32 level = 2 [default = 0]; + repeated string stage = 3; +} + +message NetStateRule { + // Set phase to require the NetState have a particular phase (TRAIN or TEST) + // to meet this rule. + optional Phase phase = 1; + + // Set the minimum and/or maximum levels in which the layer should be used. + // Leave undefined to meet the rule regardless of level. + optional int32 min_level = 2; + optional int32 max_level = 3; + + // Customizable sets of stages to include or exclude. + // The net must have ALL of the specified stages and NONE of the specified + // "not_stage"s to meet the rule. + // (Use multiple NetStateRules to specify conjunctions of stages.) + repeated string stage = 4; + repeated string not_stage = 5; +} + +// Specifies training parameters (multipliers on global learning constants, +// and the name and other settings used for weight sharing). +message ParamSpec { + // The names of the parameter blobs -- useful for sharing parameters among + // layers, but never required otherwise. To share a parameter between two + // layers, give it a (non-empty) name. + optional string name = 1; + + // Whether to require shared weights to have the same shape, or just the same + // count -- defaults to STRICT if unspecified. + optional DimCheckMode share_mode = 2; + enum DimCheckMode { + // STRICT (default) requires that num, channels, height, width each match. + STRICT = 0; + // PERMISSIVE requires only the count (num*channels*height*width) to match. + PERMISSIVE = 1; + } + + // The multiplier on the global learning rate for this parameter. + optional float lr_mult = 3 [default = 1.0]; + + // The multiplier on the global weight decay for this parameter. + optional float decay_mult = 4 [default = 1.0]; +} + +// NOTE +// Update the next available ID when you add a new LayerParameter field. +// +// LayerParameter next available layer-specific ID: 151 (last added: smooth_l1_loss_param) +message LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the layer type + repeated string bottom = 3; // the name of each bottom blob + repeated string top = 4; // the name of each top blob + + // The train / test phase for computation. + optional Phase phase = 10; + + // The amount of weight to assign each top blob in the objective. + // Each layer assigns a default value, usually of either 0 or 1, + // to each top blob. + repeated float loss_weight = 5; + + // Specifies training parameters (multipliers on global learning constants, + // and the name and other settings used for weight sharing). + repeated ParamSpec param = 6; + + // The blobs containing the numeric parameters of the layer. + repeated BlobProto blobs = 7; + + // Specifies whether to backpropagate to each bottom. If unspecified, + // Caffe will automatically infer whether each input needs backpropagation + // to compute parameter gradients. If set to true for some inputs, + // backpropagation to those inputs is forced; if set false for some inputs, + // backpropagation to those inputs is skipped. + // + // The size must be either 0 or equal to the number of bottoms. + repeated bool propagate_down = 11; + + // Rules controlling whether and when a layer is included in the network, + // based on the current NetState. You may specify a non-zero number of rules + // to include OR exclude, but not both. If no include or exclude rules are + // specified, the layer is always included. If the current NetState meets + // ANY (i.e., one or more) of the specified rules, the layer is + // included/excluded. + repeated NetStateRule include = 8; + repeated NetStateRule exclude = 9; + + // Parameters for data pre-processing. + optional TransformationParameter transform_param = 100; + + // Parameters shared by loss layers. + optional LossParameter loss_param = 101; + + // Layer type-specific parameters. + // + // Note: certain layers may have more than one computational engine + // for their implementation. These layers include an Engine type and + // engine parameter for selecting the implementation. + // The default for the engine is set by the ENGINE switch at compile-time. + optional AccuracyParameter accuracy_param = 102; + optional ArgMaxParameter argmax_param = 103; + optional BatchNormParameter batch_norm_param = 139; + optional BiasParameter bias_param = 141; + optional ConcatParameter concat_param = 104; + optional ContrastiveLossParameter contrastive_loss_param = 105; + optional ConvolutionParameter convolution_param = 106; + optional CropParameter crop_param = 144; + optional DataParameter data_param = 107; + optional DetectionOutputParameter detection_output_param = 150; + optional DropoutParameter dropout_param = 108; + optional DummyDataParameter dummy_data_param = 109; + optional EltwiseParameter eltwise_param = 110; + optional ELUParameter elu_param = 140; + optional EmbedParameter embed_param = 137; + optional ExpParameter exp_param = 111; + optional FlattenParameter flatten_param = 135; + optional HDF5DataParameter hdf5_data_param = 112; + optional HDF5OutputParameter hdf5_output_param = 113; + optional HingeLossParameter hinge_loss_param = 114; + optional ImageDataParameter image_data_param = 115; + optional InfogainLossParameter infogain_loss_param = 116; + optional InnerProductParameter inner_product_param = 117; + optional InputParameter input_param = 143; + optional LogParameter log_param = 134; + optional LRNParameter lrn_param = 118; + optional MemoryDataParameter memory_data_param = 119; + optional MVNParameter mvn_param = 120; + optional ParameterParameter parameter_param = 145; + optional PoolingParameter pooling_param = 121; + optional PowerParameter power_param = 122; + optional PReLUParameter prelu_param = 131; + optional PythonParameter python_param = 130; + optional RecurrentParameter recurrent_param = 146; + optional ReductionParameter reduction_param = 136; + optional ReLUParameter relu_param = 123; + optional ReshapeParameter reshape_param = 133; + optional ScaleParameter scale_param = 142; + optional SigmoidParameter sigmoid_param = 124; + optional SmoothL1LossParameter smooth_l1_loss_param = 148; + optional SoftmaxParameter softmax_param = 125; + optional SPPParameter spp_param = 132; + optional SliceParameter slice_param = 126; + optional TanHParameter tanh_param = 127; + optional ThresholdParameter threshold_param = 128; + optional TileParameter tile_param = 138; + optional WindowDataParameter window_data_param = 129; + optional PermuteParameter permute_param = 202; + optional PriorBoxParameter prior_box_param = 203; + optional NormalizeParameter norm_param = 206; + optional PSROIPoolingParameter psroi_pooling_param = 207; + optional FreespaceExtractParameter freespace_extract_param = 151; + optional PostprocessParameter postprocess_param = 152; + optional SpatialTransformParameter spatial_transform_param = 153; + optional ROIAlignParameter roi_align_param = 154; + optional ReorgParameter reorg_param = 155; + optional RegionParameter region_param = 156; + optional ReverseParameter reverse_param = 157; + optional InterpParameter interp_param = 158; + optional ShuffleChannelParameter shuffle_channel_param = 159; + optional UpsampleParameter upsample_param = 160; + optional ROIPoolingParameter roi_pooling_param = 161; + optional YoloParameter yolo_param = 199; + optional YoloV3DetectionOutputParameter yolov3_detection_output_param = 200; + optional ProposalParameter proposal_param = 201; + optional FSRDetectionOutputParameter fsrdetectionoutput_param = 222; + optional SSDDetectionOutputParameter ssddetectionoutput_param = 232; + optional YoloV2DetectionOutputParameter yolov2_detection_output_param = 204; + optional QuantParameter quant_param = 208; + optional CondTakeParameter condtake_param = 233; + optional MatrixInverseParameter matrix_inverse_param = 210; + optional WarpPerspectiveParameter warp_perspective_param = 234; + optional BatchMatMulParameter batch_matmul_param = 235; + optional SpatialTransformerParameter st_param = 5000; + optional YoloV3DetectionOutputV2Parameter yolov3_detection_output_v2_param = 5001; +} + +// Message that stores parameters used to apply transformation +// to the data layer's data +message TransformationParameter { + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 1 [default = 1]; + // Specify if we want to randomly mirror data. + optional bool mirror = 2 [default = false]; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 3 [default = 0]; + // mean_file and mean_value cannot be specified at the same time + optional string mean_file = 4; + // if specified can be repeated once (would substract it from all the channels) + // or can be repeated the same number of times as channels + // (would subtract them from the corresponding channel) + repeated float mean_value = 5; + // Force the decoded image to have 3 color channels. + optional bool force_color = 6 [default = false]; + // Force the decoded image to have 1 color channels. + optional bool force_gray = 7 [default = false]; +} + +// Message that stores parameters shared by loss layers +message LossParameter { + // If specified, ignore instances with the given label. + optional int32 ignore_label = 1; + // How to normalize the loss for loss layers that aggregate across batches, + // spatial dimensions, or other dimensions. Currently only implemented in + // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers. + enum NormalizationMode { + // Divide by the number of examples in the batch times spatial dimensions. + // Outputs that receive the ignore label will NOT be ignored in computing + // the normalization factor. + FULL = 0; + // Divide by the total number of output locations that do not take the + // ignore_label. If ignore_label is not set, this behaves like FULL. + VALID = 1; + // Divide by the batch size. + BATCH_SIZE = 2; + // Do not normalize the loss. + NONE = 3; + } + // For historical reasons, the default normalization for + // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID. + optional NormalizationMode normalization = 3 [default = VALID]; + // Deprecated. Ignored if normalization is specified. If normalization + // is not specified, then setting this to false will be equivalent to + // normalization = BATCH_SIZE to be consistent with previous behavior. + optional bool normalize = 2; +} + +// Messages that store parameters used by individual layer types follow, in +// alphabetical order. + +message AccuracyParameter { + // When computing accuracy, count as correct by comparing the true label to + // the top k scoring classes. By default, only compare to the top scoring + // class (i.e. argmax). + optional uint32 top_k = 1 [default = 1]; + + // The "label" axis of the prediction blob, whose argmax corresponds to the + // predicted label -- may be negative to index from the end (e.g., -1 for the + // last axis). For example, if axis == 1 and the predictions are + // (N x C x H x W), the label blob is expected to contain N*H*W ground truth + // labels with integer values in {0, 1, ..., C-1}. + optional int32 axis = 2 [default = 1]; + + // If specified, ignore instances with the given label. + optional int32 ignore_label = 3; +} + +message ArgMaxParameter { + // If true produce pairs (argmax, maxval) + optional bool out_max_val = 1 [default = false]; + optional uint32 top_k = 2 [default = 1]; + // The axis along which to maximise -- may be negative to index from the + // end (e.g., -1 for the last axis). + // By default ArgMaxLayer maximizes over the flattened trailing dimensions + // for each index of the first / num dimension. + optional int32 axis = 3; +} + +message ConcatParameter { + // The axis along which to concatenate -- may be negative to index from the + // end (e.g., -1 for the last axis). Other axes must have the + // same dimension for all the bottom blobs. + // By default, ConcatLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 2 [default = 1]; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 concat_dim = 1 [default = 1]; +} + +message BatchNormParameter { + // If false, normalization is performed over the current mini-batch + // and global statistics are accumulated (but not yet used) by a moving + // average. + // If true, those accumulated mean and variance values are used for the + // normalization. + // By default, it is set to false when the network is in the training + // phase and true when the network is in the testing phase. + optional bool use_global_stats = 1; + // What fraction of the moving average remains each iteration? + // Smaller values make the moving average decay faster, giving more + // weight to the recent values. + // Each iteration updates the moving average @f$S_{t-1}@f$ with the + // current mean @f$ Y_t @f$ by + // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ + // is the moving_average_fraction parameter. + optional float moving_average_fraction = 2 [default = .999]; + // Small value to add to the variance estimate so that we don't divide by + // zero. + optional float eps = 3 [default = 1e-5]; +} + +message BiasParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar bias. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the bias + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to add a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer.) + // The initialization for the learned bias parameter. + // Default is the zero (0) initialization, resulting in the BiasLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + optional bool bias_from_blob = 4 [default = true]; +} + +message ContrastiveLossParameter { + // margin for dissimilar pair + optional float margin = 1 [default = 1.0]; + // The first implementation of this cost did not exactly match the cost of + // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. + // legacy_version = false (the default) uses (margin - d)^2 as proposed in the + // Hadsell paper. New models should probably use this version. + // legacy_version = true uses (margin - d^2). This is kept to support / + // reproduce existing models and results + optional bool legacy_version = 2 [default = false]; +} + +message ConvolutionParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in all spatial dimensions, or once per spatial dimension. + repeated uint32 pad = 3; // The padding size; defaults to 0 + repeated uint32 kernel_size = 4; // The kernel size + repeated uint32 stride = 6; // The stride; defaults to 1 + // Factor used to dilate the kernel, (implicitly) zero-filling the resulting + // holes. (Kernel dilation is sometimes referred to by its use in the + // algorithme à trous from Holschneider et al. 1987.) + repeated uint32 dilation = 18; // The dilation; defaults to 1 + + // For 2D convolution only, the *_h and *_w versions may also be used to + // specify both spatial dimensions. + optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) + optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) + optional uint32 kernel_h = 11; // The kernel height (2D only) + optional uint32 kernel_w = 12; // The kernel width (2D only) + optional uint32 stride_h = 13; // The stride height (2D only) + optional uint32 stride_w = 14; // The stride width (2D only) + + optional uint32 group = 5 [default = 1]; // The group size for group conv + + optional FillerParameter weight_filler = 7; // The filler for the weight + optional FillerParameter bias_filler = 8; // The filler for the bias + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; + + // The axis to interpret as "channels" when performing convolution. + // Preceding dimensions are treated as independent inputs; + // succeeding dimensions are treated as "spatial". + // With (N, C, H, W) inputs, and axis == 1 (the default), we perform + // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for + // groups g>1) filters across the spatial axes (H, W) of the input. + // With (N, C, D, H, W) inputs, and axis == 1, we perform + // N independent 3D convolutions, sliding (C/g)-channels + // filters across the spatial axes (D, H, W) of the input. + optional int32 axis = 16 [default = 1]; + + // Whether to force use of the general ND convolution, even if a specific + // implementation for blobs of the appropriate number of spatial dimensions + // is available. (Currently, there is only a 2D-specific convolution + // implementation; for input blobs with num_axes != 2, this option is + // ignored and the ND implementation will be used.) + optional bool force_nd_im2col = 17 [default = false]; +} + +message CropParameter { + // To crop, elements of the first bottom are selected to fit the dimensions + // of the second, reference bottom. The crop is configured by + // - the crop `axis` to pick the dimensions for cropping + // - the crop `offset` to set the shift for all/each dimension + // to align the cropped bottom with the reference bottom. + // All dimensions up to but excluding `axis` are preserved, while + // the dimensions including and trailing `axis` are cropped. + // If only one `offset` is set, then all dimensions are offset by this amount. + // Otherwise, the number of offsets must equal the number of cropped axes to + // shift the crop in each dimension accordingly. + // Note: standard dimensions are N,C,H,W so the default is a spatial crop, + // and `axis` may be negative to index from the end (e.g., -1 for the last + // axis). + optional int32 axis = 1 [default = 2]; + repeated uint32 offset = 2; +} + +message DataParameter { + enum DB { + LEVELDB = 0; + LMDB = 1; + } + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + // DEPRECATED. Each solver accesses a different subset of the database. + optional uint32 rand_skip = 7 [default = 0]; + optional DB backend = 8 [default = LEVELDB]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + // Force the encoded image to have 3 color channels + optional bool force_encoded_color = 9 [default = false]; + // Prefetch queue (Increase if data feeding bandwidth varies, within the + // limit of device memory for GPU training) + optional uint32 prefetch = 10 [default = 4]; +} + +message DropoutParameter { + optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio + optional bool scale_train = 2 [default = true]; // scale train or test phase +} + +// DummyDataLayer fills any number of arbitrarily shaped blobs with random +// (or constant) data generated by "Fillers" (see "message FillerParameter"). +message DummyDataParameter { + // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N + // shape fields, and 0, 1 or N data_fillers. + // + // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. + // If 1 data_filler is specified, it is applied to all top blobs. If N are + // specified, the ith is applied to the ith top blob. + repeated FillerParameter data_filler = 1; + repeated BlobShape shape = 6; + + // 4D dimensions -- deprecated. Use "shape" instead. + repeated uint32 num = 2; + repeated uint32 channels = 3; + repeated uint32 height = 4; + repeated uint32 width = 5; +} + +message EltwiseParameter { + enum EltwiseOp { + PROD = 0; + SUM = 1; + MAX = 2; + } + optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation + repeated float coeff = 2; // blob-wise coefficient for SUM operation + + // Whether to use an asymptotically slower (for >2 inputs) but stabler method + // of computing the gradient for the PROD operation. (No effect for SUM op.) + optional bool stable_prod_grad = 3 [default = true]; +} + +// Message that stores parameters used by ELULayer +message ELUParameter { + // Described in: + // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate + // Deep Network Learning by Exponential Linear Units (ELUs). arXiv + optional float alpha = 1 [default = 1]; +} + +// Message that stores parameters used by EmbedLayer +message EmbedParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + // The input is given as integers to be interpreted as one-hot + // vector indices with dimension num_input. Hence num_input should be + // 1 greater than the maximum possible input value. + optional uint32 input_dim = 2; + + optional bool bias_term = 3 [default = true]; // Whether to use a bias term + optional FillerParameter weight_filler = 4; // The filler for the weight + optional FillerParameter bias_filler = 5; // The filler for the bias + +} + +// Message that stores parameters used by ExpLayer +message ExpParameter { + // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = exp(shift + scale * x). + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +/// Message that stores parameters used by FlattenLayer +message FlattenParameter { + // The first axis to flatten: all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 1 [default = 1]; + + // The last axis to flatten: all following axes are retained in the output. + // May be negative to index from the end (e.g., the default -1 for the last + // axis). + optional int32 end_axis = 2 [default = -1]; +} + +// Message that stores parameters used by HDF5DataLayer +message HDF5DataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 2; + + // Specify whether to shuffle the data. + // If shuffle == true, the ordering of the HDF5 files is shuffled, + // and the ordering of data within any given HDF5 file is shuffled, + // but data between different files are not interleaved; all of a file's + // data are output (in a random order) before moving onto another file. + optional bool shuffle = 3 [default = false]; +} + +message HDF5OutputParameter { + optional string file_name = 1; +} + +message HingeLossParameter { + enum Norm { + L1 = 1; + L2 = 2; + } + // Specify the Norm to use L1 or L2 + optional Norm norm = 1 [default = L1]; +} + +message ImageDataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4 [default = 1]; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 7 [default = 0]; + // Whether or not ImageLayer should shuffle the list of files at every epoch. + optional bool shuffle = 8 [default = false]; + // It will also resize images if new_height or new_width are not zero. + optional uint32 new_height = 9 [default = 0]; + optional uint32 new_width = 10 [default = 0]; + // Specify if the images are color or gray + optional bool is_color = 11 [default = true]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + optional string root_folder = 12 [default = ""]; +} + +message InfogainLossParameter { + // Specify the infogain matrix source. + optional string source = 1; + optional int32 axis = 2 [default = 1]; // axis of prob +} + +message InnerProductParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 3; // The filler for the weight + optional FillerParameter bias_filler = 4; // The filler for the bias + + // The first axis to be lumped into a single inner product computation; + // all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 5 [default = 1]; + // Specify whether to transpose the weight matrix or not. + // If transpose == true, any operations will be performed on the transpose + // of the weight matrix. The weight matrix itself is not going to be transposed + // but rather the transfer flag of operations will be toggled accordingly. + optional bool transpose = 6 [default = false]; +} + +message InputParameter { + // This layer produces N >= 1 top blob(s) to be assigned manually. + // Define N shapes to set a shape for each top. + // Define 1 shape to set the same shape for every top. + // Define no shape to defer to reshaping manually. + repeated BlobShape shape = 1; +} + +// Message that stores parameters used by LogLayer +message LogParameter { + // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = ln(shift + scale * x) = log_e(shift + scale * x) + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by LRNLayer +message LRNParameter { + optional uint32 local_size = 1 [default = 5]; + optional float alpha = 2 [default = 1.]; + optional float beta = 3 [default = 0.75]; + enum NormRegion { + ACROSS_CHANNELS = 0; + WITHIN_CHANNEL = 1; + } + optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; + optional float k = 5 [default = 1.]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +message MemoryDataParameter { + optional uint32 batch_size = 1; + optional uint32 channels = 2; + optional uint32 height = 3; + optional uint32 width = 4; +} + +message MVNParameter { + // This parameter can be set to false to normalize mean only + optional bool normalize_variance = 1 [default = true]; + + // This parameter can be set to true to perform DNN-like MVN + optional bool across_channels = 2 [default = false]; + + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 3 [default = 1e-9]; +} + +message ParameterParameter { + optional BlobShape shape = 1; +} + +message PoolingParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 1 [default = MAX]; // The pooling method + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 2; // The kernel size (square) + optional uint32 kernel_h = 5; // The kernel height + optional uint32 kernel_w = 6; // The kernel width + optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 7; // The stride height + optional uint32 stride_w = 8; // The stride width + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 11 [default = DEFAULT]; + // If global_pooling then it will pool over the size of the bottom by doing + // kernel_h = bottom->height and kernel_w = bottom->width + optional bool global_pooling = 12 [default = false]; + optional bool ceil_mode = 13 [default = true]; + // How to calculate the output size - using ceil (default) or floor rounding. + enum RoundMode { + CEIL = 0; + FLOOR = 1; + } + optional RoundMode round_mode = 14 [default = CEIL]; +} + +message PowerParameter { + // PowerLayer computes outputs y = (shift + scale * x) ^ power. + optional float power = 1 [default = 1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +message PythonParameter { + optional string module = 1; + optional string layer = 2; + // This value is set to the attribute `param_str` of the `PythonLayer` object + // in Python before calling the `setup()` method. This could be a number, + // string, dictionary in Python dict format, JSON, etc. You may parse this + // string in `setup` method and use it in `forward` and `backward`. + optional string param_str = 3 [default = '']; + // Whether this PythonLayer is shared among worker solvers during data parallelism. + // If true, each worker solver sequentially run forward from this layer. + // This value should be set true if you are using it as a data layer. + optional bool share_in_parallel = 4 [default = false]; +} + +// Message that stores parameters used by RecurrentLayer +message RecurrentParameter { + // The dimension of the output (and usually hidden state) representation -- + // must be explicitly set to non-zero. + optional uint32 num_output = 1 [default = 0]; + + optional FillerParameter weight_filler = 2; // The filler for the weight + optional FillerParameter bias_filler = 3; // The filler for the bias + + // Whether to enable displaying debug_info in the unrolled recurrent net. + optional bool debug_info = 4 [default = false]; + + // Whether to add as additional inputs (bottoms) the initial hidden state + // blobs, and add as additional outputs (tops) the final timestep hidden state + // blobs. The number of additional bottom/top blobs required depends on the + // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. + optional bool expose_hidden = 5 [default = false]; +} + +// Message that stores parameters used by ReductionLayer +message ReductionParameter { + enum ReductionOp { + SUM = 1; + ASUM = 2; + SUMSQ = 3; + MEAN = 4; + } + + optional ReductionOp operation = 1 [default = SUM]; // reduction operation + + // The first axis to reduce to a scalar -- may be negative to index from the + // end (e.g., -1 for the last axis). + // (Currently, only reduction along ALL "tail" axes is supported; reduction + // of axis M through N, where N < num_axes - 1, is unsupported.) + // Suppose we have an n-axis bottom Blob with shape: + // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). + // If axis == m, the output Blob will have shape + // (d0, d1, d2, ..., d(m-1)), + // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) + // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. + // If axis == 0 (the default), the output Blob always has the empty shape + // (count 1), performing reduction across the entire input -- + // often useful for creating new loss functions. + optional int32 axis = 2 [default = 0]; + + optional float coeff = 3 [default = 1.0]; // coefficient for output +} + +// Message that stores parameters used by ReLULayer +message ReLUParameter { + // Allow non-zero slope for negative inputs to speed up optimization + // Described in: + // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities + // improve neural network acoustic models. In ICML Workshop on Deep Learning + // for Audio, Speech, and Language Processing. + optional float negative_slope = 1 [default = 0]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 2 [default = DEFAULT]; +} + +message ReshapeParameter { + // Specify the output dimensions. If some of the dimensions are set to 0, + // the corresponding dimension from the bottom layer is used (unchanged). + // Exactly one dimension may be set to -1, in which case its value is + // inferred from the count of the bottom blob and the remaining dimensions. + // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: + // + // layer { + // type: "Reshape" bottom: "input" top: "output" + // reshape_param { ... } + // } + // + // If "input" is 2D with shape 2 x 8, then the following reshape_param + // specifications are all equivalent, producing a 3D blob "output" with shape + // 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } + // reshape_param { shape { dim: 0 dim:-1 dim: 4 } } + // + optional BlobShape shape = 1; + + // axis and num_axes control the portion of the bottom blob's shape that are + // replaced by (included in) the reshape. By default (axis == 0 and + // num_axes == -1), the entire bottom blob shape is included in the reshape, + // and hence the shape field must specify the entire output shape. + // + // axis may be non-zero to retain some portion of the beginning of the input + // shape (and may be negative to index from the end; e.g., -1 to begin the + // reshape after the last axis, including nothing in the reshape, + // -2 to include only the last axis, etc.). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are all equivalent, + // producing a blob "output" with shape 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } + // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } + // + // num_axes specifies the extent of the reshape. + // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on + // input axes in the range [axis, axis+num_axes]. + // num_axes may also be -1, the default, to include all remaining axes + // (starting from axis). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are equivalent, + // producing a blob "output" with shape 1 x 2 x 8. + // + // reshape_param { shape { dim: 1 dim: 2 dim: 8 } } + // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 } + // reshape_param { shape { dim: 1 } num_axes: 0 } + // + // On the other hand, these would produce output blob shape 2 x 1 x 8: + // + // reshape_param { shape { dim: 2 dim: 1 dim: 8 } } + // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 } + // + optional int32 axis = 2 [default = 0]; + optional int32 num_axes = 3 [default = -1]; +} + + +message ScaleParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar multiplier. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the scale + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer.) + // The initialization for the learned scale parameter. + // Default is the unit (1) initialization, resulting in the ScaleLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + + // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but + // may be more efficient). Initialized with bias_filler (defaults to 0). + optional bool bias_term = 4 [default = false]; + optional FillerParameter bias_filler = 5; + optional bool scale_from_blob = 6 [default = true]; +} + +message SigmoidParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +message SliceParameter { + // The axis along which to slice -- may be negative to index from the end + // (e.g., -1 for the last axis). + // By default, SliceLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 3 [default = 1]; + repeated uint32 slice_point = 2; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 slice_dim = 1 [default = 1]; +} + +message SmoothL1LossParameter { + // SmoothL1Loss(x) = + // 0.5 * (sigma * x) ** 2 -- if x < 1.0 / sigma / sigma + // |x| - 0.5 / sigma / sigma -- otherwise + optional float sigma = 1 [default = 1]; +} + +// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer +message SoftmaxParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; + + // The axis along which to perform the softmax -- may be negative to index + // from the end (e.g., -1 for the last axis). + // Any other axes will be evaluated as independent softmaxes. + optional int32 axis = 2 [default = 1]; +} + +message TanHParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by TileLayer +message TileParameter { + // The index of the axis to tile. + optional int32 axis = 1 [default = 1]; + + // The number of copies (tiles) of the blob to output. + optional int32 tiles = 2; +} + +// Message that stores parameters used by ThresholdLayer +message ThresholdParameter { + optional float threshold = 1 [default = 0]; // Strictly positive values +} + +message WindowDataParameter { + // Specify the data source. + optional string source = 1; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // Specify the batch size. + optional uint32 batch_size = 4; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 5 [default = 0]; + // Specify if we want to randomly mirror data. + optional bool mirror = 6 [default = false]; + // Foreground (object) overlap threshold + optional float fg_threshold = 7 [default = 0.5]; + // Background (non-object) overlap threshold + optional float bg_threshold = 8 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float fg_fraction = 9 [default = 0.25]; + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 context_pad = 10 [default = 0]; + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string crop_mode = 11 [default = "warp"]; + // cache_images: will load all images in memory for faster access + optional bool cache_images = 12 [default = false]; + // append root_folder to locate images + optional string root_folder = 13 [default = ""]; +} + +message SPPParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional uint32 pyramid_height = 1; + optional PoolMethod pool = 2 [default = MAX]; // The pooling method + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +// DEPRECATED: use LayerParameter. +message V1LayerParameter { + repeated string bottom = 2; + repeated string top = 3; + optional string name = 4; + repeated NetStateRule include = 32; + repeated NetStateRule exclude = 33; + enum LayerType { + NONE = 0; + ABSVAL = 35; + ACCURACY = 1; + ARGMAX = 30; + BNLL = 2; + CONCAT = 3; + CONTRASTIVE_LOSS = 37; + CONVOLUTION = 4; + DATA = 5; + DECONVOLUTION = 39; + DROPOUT = 6; + DUMMY_DATA = 32; + EUCLIDEAN_LOSS = 7; + ELTWISE = 25; + EXP = 38; + FLATTEN = 8; + HDF5_DATA = 9; + HDF5_OUTPUT = 10; + HINGE_LOSS = 28; + IM2COL = 11; + IMAGE_DATA = 12; + INFOGAIN_LOSS = 13; + INNER_PRODUCT = 14; + LRN = 15; + MEMORY_DATA = 29; + MULTINOMIAL_LOGISTIC_LOSS = 16; + MVN = 34; + POOLING = 17; + POWER = 26; + RELU = 18; + SIGMOID = 19; + SIGMOID_CROSS_ENTROPY_LOSS = 27; + SILENCE = 36; + SOFTMAX = 20; + SOFTMAX_LOSS = 21; + SPLIT = 22; + SLICE = 33; + TANH = 23; + WINDOW_DATA = 24; + THRESHOLD = 31; + QUANT = 208; + DEQUANT = 209; + } + optional LayerType type = 5; + repeated BlobProto blobs = 6; + repeated string param = 1001; + repeated DimCheckMode blob_share_mode = 1002; + enum DimCheckMode { + STRICT = 0; + PERMISSIVE = 1; + } + repeated float blobs_lr = 7; + repeated float weight_decay = 8; + repeated float loss_weight = 35; + optional AccuracyParameter accuracy_param = 27; + optional ArgMaxParameter argmax_param = 23; + optional ConcatParameter concat_param = 9; + optional ContrastiveLossParameter contrastive_loss_param = 40; + optional ConvolutionParameter convolution_param = 10; + optional DataParameter data_param = 11; + optional DropoutParameter dropout_param = 12; + optional DummyDataParameter dummy_data_param = 26; + optional EltwiseParameter eltwise_param = 24; + optional ExpParameter exp_param = 41; + optional HDF5DataParameter hdf5_data_param = 13; + optional HDF5OutputParameter hdf5_output_param = 14; + optional HingeLossParameter hinge_loss_param = 29; + optional ImageDataParameter image_data_param = 15; + optional InfogainLossParameter infogain_loss_param = 16; + optional InnerProductParameter inner_product_param = 17; + optional LRNParameter lrn_param = 18; + optional MemoryDataParameter memory_data_param = 22; + optional MVNParameter mvn_param = 34; + optional PoolingParameter pooling_param = 19; + optional PowerParameter power_param = 21; + optional ReLUParameter relu_param = 30; + optional SigmoidParameter sigmoid_param = 38; + optional SoftmaxParameter softmax_param = 39; + optional SliceParameter slice_param = 31; + optional TanHParameter tanh_param = 37; + optional ThresholdParameter threshold_param = 25; + optional WindowDataParameter window_data_param = 20; + optional TransformationParameter transform_param = 36; + optional LossParameter loss_param = 42; + optional V0LayerParameter layer = 1; +} + +// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters +// in Caffe. We keep this message type around for legacy support. +message V0LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the string to specify the layer type + + // Parameters to specify layers with inner products. + optional uint32 num_output = 3; // The number of outputs for the layer + optional bool biasterm = 4 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 5; // The filler for the weight + optional FillerParameter bias_filler = 6; // The filler for the bias + + optional uint32 pad = 7 [default = 0]; // The padding size + optional uint32 kernelsize = 8; // The kernel size + optional uint32 group = 9 [default = 1]; // The group size for group conv + optional uint32 stride = 10 [default = 1]; // The stride + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 11 [default = MAX]; // The pooling method + optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio + + optional uint32 local_size = 13 [default = 5]; // for local response norm + optional float alpha = 14 [default = 1.]; // for local response norm + optional float beta = 15 [default = 0.75]; // for local response norm + optional float k = 22 [default = 1.]; + + // For data layers, specify the data source + optional string source = 16; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 17 [default = 1]; + optional string meanfile = 18; + // For data layers, specify the batch size. + optional uint32 batchsize = 19; + // For data layers, specify if we would like to randomly crop an image. + optional uint32 cropsize = 20 [default = 0]; + // For data layers, specify if we want to randomly mirror data. + optional bool mirror = 21 [default = false]; + + // The blobs containing the numeric parameters of the layer + repeated BlobProto blobs = 50; + // The ratio that is multiplied on the global learning rate. If you want to + // set the learning ratio for one blob, you need to set it for all blobs. + repeated float blobs_lr = 51; + // The weight decay that is multiplied on the global weight decay. + repeated float weight_decay = 52; + + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 53 [default = 0]; + + // Fields related to detection (det_*) + // foreground (object) overlap threshold + optional float det_fg_threshold = 54 [default = 0.5]; + // background (non-object) overlap threshold + optional float det_bg_threshold = 55 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float det_fg_fraction = 56 [default = 0.25]; + + // optional bool OBSOLETE_can_clobber = 57 [default = true]; + + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 det_context_pad = 58 [default = 0]; + + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string det_crop_mode = 59 [default = "warp"]; + + // For ReshapeLayer, one needs to specify the new dimensions. + optional int32 new_num = 60 [default = 0]; + optional int32 new_channels = 61 [default = 0]; + optional int32 new_height = 62 [default = 0]; + optional int32 new_width = 63 [default = 0]; + + // Whether or not ImageLayer should shuffle the list of files at every epoch. + // It will also resize images if new_height or new_width are not zero. + optional bool shuffle_images = 64 [default = false]; + + // For ConcatLayer, one needs to specify the dimension for concatenation, and + // the other dimensions must be the same for all the bottom blobs. + // By default it will concatenate blobs along the channels dimension. + optional uint32 concat_dim = 65 [default = 1]; + + optional HDF5OutputParameter hdf5_output_param = 1001; +} + +message PReLUParameter { + // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers: + // Surpassing Human-Level Performance on ImageNet Classification, 2015. + + // Initial value of a_i. Default is a_i=0.25 for all i. + optional FillerParameter filler = 1; + // Whether or not slope parameters are shared across channels. + optional bool channel_shared = 2 [default = false]; +} + +// Message that stores parameters used by DetectionOutputLayer +//message DetectionOutputParameter { +// optional int32 num_classes = 1 [default = 21]; +// optional float nms_threshold = 2 [default = 0.3]; +// optional int32 top_k = 3; +// optional float confidence_threshold = 4 [default = 0.8]; +//} + +// Message that store parameters used by PriorBoxLayer +message PriorBoxParameter { + // Encode/decode type. + enum CodeType { + CORNER = 1; + CENTER_SIZE = 2; + CORNER_SIZE = 3; + } + // Minimum box size (in pixels). Required! + repeated float min_size = 1; + // Maximum box size (in pixels). Required! + repeated float max_size = 2; + // Various of aspect ratios. Duplicate ratios will be ignored. + // If none is provided, we use default ratio 1. + repeated float aspect_ratio = 3; + // If true, will flip each aspect ratio. + // For example, if there is aspect ratio "r", + // we will generate aspect ratio "1.0/r" as well. + optional bool flip = 4 [default = true]; + // If true, will clip the prior so that it is within [0, 1] + optional bool clip = 5 [default = false]; + // Variance for adjusting the prior bboxes. + repeated float variance = 6; + // By default, we calculate img_height, img_width, step_x, step_y based on + // bottom[0] (feat) and bottom[1] (img). Unless these values are explicitely + // provided. + // Explicitly provide the img_size. + optional uint32 img_size = 7; + // Either img_size or img_h/img_w should be specified; not both. + optional uint32 img_h = 8; + optional uint32 img_w = 9; + + // Explicitly provide the step size. + optional float step = 10; + // Either step or step_h/step_w should be specified; not both. + optional float step_h = 11; + optional float step_w = 12; + + // Offset to the top left corner of each cell. + optional float offset = 13 [default = 0.5]; +} + +// Message that stores parameters used by PermutetLayer +message PermuteParameter { + // The new orders of the axes of data. Notice it should be with + // in the same range as the input data, and it starts from 0. + // Do not provide repeated order. + repeated uint32 order = 1; +} + +message NormalizeParameter { + optional bool across_spatial = 1 [default = true]; + // Initial value of scale. Default is 1.0 for all + optional FillerParameter scale_filler = 2; + // Whether or not scale parameters are shared across channels. + optional bool channel_shared = 3 [default = true]; + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 4 [default = 1e-10]; +} + +// needed by ssd +message SaveOutputParameter { + // Output directory. If not empty, we will save the results. + optional string output_directory = 1; + // Output name prefix. + optional string output_name_prefix = 2; + // Output format. + // VOC - PASCAL VOC output format. + // COCO - MS COCO output format. + optional string output_format = 3; + // If you want to output results, must also provide the following two files. + // Otherwise, we will ignore saving results. + // label map file. + optional string label_map_file = 4; + // A file which contains a list of names and sizes with same order + // of the input DB. The file is in the following format: + // name height width + // ... + optional string name_size_file = 5; + // Number of test images. It can be less than the lines specified in + // name_size_file. For example, when we only want to evaluate on part + // of the test images. + optional uint32 num_test_image = 6; + // The resize parameter used in saving the data. + // optional ResizeParameter resize_param = 7; +} + +message NonMaximumSuppressionParameter { + // Threshold to be used in nms. + optional float nms_threshold = 1 [default = 0.3]; + // Maximum number of results to be kept. + optional int32 top_k = 2; + // Parameter for adaptive nms. + optional float eta = 3 [default = 1.0]; +} + +message GeneralNmsParameter { + optional int32 post_top_k = 1 ; + optional float nms_threshold = 2 [default = 0]; + optional float iou_threshold_decay = 3 [default = 1.0]; + optional float coor_scale_factor = 4 [default = 1.0]; +} + +// Message that store parameters used by DetectionOutputLayer, ssd/fasterRcnn +message DetectionOutputParameter { + optional int32 num_classes = 1; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional NonMaximumSuppressionParameter nms_param = 4; + optional SaveOutputParameter save_output_param = 5; + optional PriorBoxParameter.CodeType code_type = 6 [default = CENTER_SIZE]; + optional bool variance_encoded_in_target = 8 [default = true]; + optional int32 keep_top_k = 7; + optional float confidence_threshold = 9; + optional float nms_threshold = 13; + optional int32 top_k = 14; + optional int32 boxes = 15 [default = 1]; + optional bool relative = 17 [default = true]; + optional float objectness_threshold = 18 [default = 0.5]; + optional float class_threshold = 19 [default = 0.5]; + repeated float biases = 20; + optional GeneralNmsParameter general_nms_param = 21; + optional float objectness_score = 22; +} +message PSROIPoolingParameter { + required float spatial_scale = 1; + required int32 output_dim = 2; // output channel number + required int32 group_size = 3; // number of groups to encode position-sensitive score maps +} +// Message that stores parameters used by FreespaceExtractLayer +message FreespaceExtractParameter { + optional float org_height = 1; +} + +// Message that stores parameters used by DetectpostprocessLayer +message PostprocessParameter { + optional float nms_thresh = 1 [default = 0.3]; + optional float conf_thresh = 2 [default = 0.5]; + optional uint32 post_nms_topn = 3 [default = 100]; + optional uint32 cls_num = 4 [default = 12]; + repeated float bbox_reg_weights = 5; +} + +// Message that stores parameters used by SpatialTransformLayer +message SpatialTransformParameter { + optional uint32 output_h = 1 [default = 0]; + optional uint32 output_w = 2 [default = 0]; + optional float border_value = 3 [default = 0]; + repeated float affine_transform = 4; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; +} +message ROIAlignParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; + optional int32 sampling_ratio = 4 [default = -1]; + optional int32 roi_end_mode = 5 [default = 0]; +} + +message RegionParameter { + optional uint32 classes = 1 [default = 20]; // Category of classification + optional uint32 coords = 2 [default = 4]; // Coordinates of box + optional uint32 boxes = 3 [default = 1]; // Number of boxes predicted per grid + optional uint32 softmax = 4 [default = 0]; + optional string softmax_tree = 5 [default = ""]; + optional uint32 background = 6 [default = 0]; +} +message ReorgParameter{ + optional uint32 stride = 2 [default = 2]; + optional bool reverse = 1 [default = false]; +} +message ReverseParameter{ + repeated int32 axis = 1; +} +message InterpParameter{ + optional int32 height = 1 [default = 0];//Height of output + optional int32 width = 2 [default = 0];//Width of output + optional int32 zoom_factor = 3 [default = 1];//zoom factor + optional int32 shrink_factor = 4 [default = 1];//shrink factor + optional int32 pad_beg = 5 [default = 0];//padding at begin of input + optional int32 pad_end = 6 [default = 0];//padding at end of input +} +message ShuffleChannelParameter{ + optional uint32 group = 1[default = 1]; // The number of group +} +message UpsampleParameter{ + optional float scale = 1[default = 1]; + optional int32 stride = 2[default = 2]; + optional int32 stride_h = 3[default = 2]; + optional int32 stride_w = 4[default=2]; +} +message ROIPoolingParameter { + required int32 pooled_h = 1; + required int32 pooled_w = 2; + optional float spatial_scale = 3 [default=0.0625]; + optional float spatial_scale_h = 4; + optional float spatial_scale_w = 5; +} + +message YoloParameter { + optional int32 boxes = 1 [default = 3]; + optional int32 coords = 2 [default = 4]; + optional int32 classes = 3 [default = 80]; + optional string yolo_version = 4 [default = "V3"]; + optional bool softmax = 5 [default = false]; + optional bool background = 6 [default = false]; + optional bool softmaxtree = 7 [default = false]; +} + +message YoloV3DetectionOutputParameter { + optional int32 boxes = 1 [default = 3]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases_high = 9; + repeated float biases_mid = 10; + repeated float biases_low = 11; + optional int32 coords = 12 [default = 4]; + repeated float biases = 13; + optional bool resize_origin_img_to_net = 14 [default = false]; +} + +message YoloV3DetectionOutputV2Parameter { + optional int32 boxes = 1 [default = 3]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases_high = 9; + repeated float biases_mid = 10; + repeated float biases_low = 11; + optional int32 coords = 12 [default = 4]; + repeated float biases = 13; + optional bool resize_origin_img_to_net = 14 [default = false]; + optional int32 out_box_dim = 15 [default = 3]; +} + +message ProposalParameter { + optional float feat_stride = 1 [default = 16]; + optional float base_size = 2 [default = 16]; + optional float min_size = 3 [default = 16]; + repeated float ratio = 4; + repeated float scale = 5; + optional int32 pre_nms_topn = 6 [default = 3000]; + optional int32 post_nms_topn = 7 [default = 304]; + optional float iou_threshold = 8 [default = 0.7]; + optional bool output_actual_rois_num = 9 [default = false]; +} + +message FSRDetectionOutputParameter { + required int32 num_classes = 1; + required float score_threshold = 2; + required float iou_threshold = 3; + optional int32 batch_rois = 4 [default = 1]; +} + +message SSDDetectionOutputParameter { + required int32 num_classes= 1 [default = 2]; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional float iou_threshold = 4 [default = 0.3]; + optional int32 top_k = 5 [default = 200]; + optional float eta = 6 [default = 1.0]; + optional bool variance_encoded_in_target = 7 [default = false]; + optional int32 code_type = 8 [default = 1]; + optional int32 keep_top_k = 9 [default = -1]; + optional float confidence_threshold = 10 [default = 0.0]; +} +message YoloV2DetectionOutputParameter { + optional int32 boxes = 1 [default = 5]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases = 9; + optional int32 coords = 10 [default = 4]; + optional bool resize_origin_img_to_net = 11 [default = false]; +} + +message QuantParameter { + optional float scale = 2; + optional bytes offset = 3; +} + +message BatchMatMulParameter{ + optional bool adj_x1 = 1 [default = false]; + optional bool adj_x2 = 2 [default = false]; +} + +message CondTakeParameter { + required string mode = 1; + required float val = 2; + optional float eps = 3 [default = 1e-06]; +} + +message MatrixInverseParameter { + optional bool adjoint = 1 [default = false]; +} + +message WarpPerspectiveParameter { + required int32 out_height = 1; + required int32 out_width = 2; + optional float constant = 3; + optional string border_type = 4 [default = 'BORDER_CONSTANT']; +} + +message SpatialTransformerParameter { + // How to use the parameter passed by localisation network + optional string transform_type = 1 [default = "affine"]; + // What is the sampling technique + optional string sampler_type = 2 [default = "bilinear"]; + + // If not set,stay same with the input dimension H and W + optional int32 output_H = 3; + optional int32 output_W = 4; + // If false, only compute dTheta, DO NOT compute dU + optional bool to_compute_dU = 5 [default = true]; + + // The default value for some parameters + optional double theta_1_1 = 6; + optional double theta_1_2 = 7; + optional double theta_1_3 = 8; + optional double theta_2_1 = 9; + optional double theta_2_2 = 10; + optional double theta_2_3 = 11; +} diff --git a/metadef/inc/register/proto/dump_task.proto b/metadef/inc/register/proto/dump_task.proto new file mode 100644 index 00000000..b1e346cd --- /dev/null +++ b/metadef/inc/register/proto/dump_task.proto @@ -0,0 +1,111 @@ +syntax = "proto3"; +package toolkit.dumpdata; + +enum OutputDataType { + DT_UNDEFINED = 0; + DT_FLOAT = 1; + DT_FLOAT16 = 2; + DT_INT8 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_UINT16 = 6; + DT_INT32 = 7; + DT_INT64 = 8; + DT_UINT32 = 9; + DT_UINT64 = 10; + DT_BOOL = 11; + DT_DOUBLE = 12; + DT_STRING = 13; + DT_DUAL_SUB_INT8 = 14; + DT_DUAL_SUB_UINT8 = 15; + DT_COMPLEX64 = 16; + DT_COMPLEX128 = 17; + DT_QINT8 = 18; + DT_QINT16 = 19; + DT_QINT32 = 20; + DT_QUINT8 = 21; + DT_QUINT16 = 22; + DT_RESOURCE = 23; + DT_STRING_REF = 24; + DT_DUAL = 25; +} + +enum OutputFormat { + FORMAT_NCHW = 0; + FORMAT_NHWC = 1; + FORMAT_ND = 2; + FORMAT_NC1HWC0 = 3; + FORMAT_FRACTAL_Z = 4; + FORMAT_NC1C0HWPAD = 5; + FORMAT_NHWC1C0 = 6; + FORMAT_FSR_NCHW = 7; + FORMAT_FRACTAL_DECONV = 8; + FORMAT_C1HWNC0 = 9; + FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; + FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; + FORMAT_NC1HWC0_C04 = 12; + FORMAT_FRACTAL_Z_C04 = 13; + FORMAT_CHWN = 14; + FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; + FORMAT_HWCN = 16; + FORMAT_NC1KHKWHWC0 = 17; + FORMAT_BN_WEIGHT = 18; + FORMAT_FILTER_HWCK = 19; + FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; + FORMAT_HASHTABLE_LOOKUP_KEYS = 21; + FORMAT_HASHTABLE_LOOKUP_VALUE = 22; + FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; + FORMAT_HASHTABLE_LOOKUP_HITS=24; + FORMAT_C1HWNCoC0 = 25; + FORMAT_MD = 26; + FORMAT_NDHWC = 27; + FORMAT_FRACTAL_ZZ = 28; + FORMAT_FRACTAL_NZ = 29; + FORMAT_RESERVED = 30; +} + +message OriginalOp { + string name = 1; + uint32 output_index = 2; + OutputDataType data_type = 3; + OutputFormat format = 4; +} + +message Shape { + repeated uint64 dim = 1; +} + +message OpOutput { + OutputDataType data_type = 1; + OutputFormat format = 2; + Shape shape = 3; + OriginalOp original_op = 4; // the original op corresponding to the output + bytes data = 5; + uint64 size = 6; +} + +message OpInput { + OutputDataType data_type = 1; + OutputFormat format = 2; + Shape shape = 3; + bytes data = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + bytes data = 2; + uint64 size = 3; +} + +message DumpData{ + string version = 1; + uint64 dump_time = 2; + repeated OpOutput output = 3; + repeated OpInput input = 4; + repeated OpBuffer buffer = 5; +} diff --git a/metadef/inc/register/proto/fusion_model.proto b/metadef/inc/register/proto/fusion_model.proto new file mode 100644 index 00000000..c92c5581 --- /dev/null +++ b/metadef/inc/register/proto/fusion_model.proto @@ -0,0 +1,21 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +import "om.proto"; + +package domi; + +message FusionModelDef { + string version = 1; + repeated OpDef fusion_op = 2; +} \ No newline at end of file diff --git a/metadef/inc/register/proto/fwk_adapter.proto b/metadef/inc/register/proto/fwk_adapter.proto new file mode 100644 index 00000000..9335c926 --- /dev/null +++ b/metadef/inc/register/proto/fwk_adapter.proto @@ -0,0 +1,37 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package aicpu.FWKAdapter; +option cc_enable_arenas = true; + + +// Defines an struct for input and output. +message TensorDataInfo { + + // value DataType + uint32 dtype = 1; + + // shape dim + repeated int64 dim = 2; + + // data point addr + int64 data_addr = 3; +} + +message KernelRunParam { + // input + repeated TensorDataInfo input = 1; + // output + repeated TensorDataInfo output = 2; +} + diff --git a/metadef/inc/register/proto/ge_ir.proto b/metadef/inc/register/proto/ge_ir.proto new file mode 100644 index 00000000..e7bfe0cb --- /dev/null +++ b/metadef/inc/register/proto/ge_ir.proto @@ -0,0 +1,190 @@ +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. + DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs +{ + string name = 1; + map attr = 2; +} + +// Shape / dimension description, using row-major order +message ShapeDef +{ + repeated int64 dim = 1; // Size of each dimension +} + +// Multidimensional data description +message TensorDescriptor +{ + string name = 1; // Optional parameter, tensor name + + DataType dtype = 2; // tensor datatype + ShapeDef shape = 3; // Shape / dimension + string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" + + bool has_out_attr = 9; + int64 size = 10; + int64 weight_size = 11; + bool reuse_input = 12; + bool output_tensor = 13; + string device_type = 14; + bool input_tensor =15; + int64 real_dim_cnt = 16; + int64 reuse_input_index = 17; + int64 data_offset = 18; + int64 cmps_size = 19; + string cmps_tab = 20; + int64 cmps_tab_offset = 21; + + map attr = 5; // Set of extra parameter fields +} + +// GeTensor definition +message TensorDef +{ + TensorDescriptor desc = 1; // Tensor description + bytes data = 2; // Tensor data +} + + +// Operator description +message OpDef +{ + string name = 1; // name + string type = 2; // type + + repeated string input = 5; // input original op name + outgoing index. op_name:index + + map attr = 10; // Set of operator parameter fields + + bool has_out_attr = 20; + int64 id = 21; + int64 stream_id =22; + repeated string input_name = 23; + repeated string src_name = 24; + repeated int64 src_index = 25; + repeated string dst_name = 26; + repeated int64 dst_index = 27; + repeated int64 input_i = 28; + repeated int64 output_i = 29; + repeated int64 workspace = 30; + repeated int64 workspace_bytes = 31; + repeated bool is_input_const = 32; + repeated TensorDescriptor input_desc = 33; + repeated TensorDescriptor output_desc = 34; + repeated string subgraph_name = 35; +} + +// Graph definition +message GraphDef +{ + string name = 1; // name + + repeated string input = 4; // Graph input + repeated string output = 5; // Graph output + + repeated OpDef op = 6; // List of operators + + map attr = 11; // Extended field +} + +// model definition +message ModelDef +{ + string name = 1; // name + uint32 version = 2; // IR Proto verion + string custom_version = 3; // User model version number, passed in by user + + repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef + + map attr = 11; // Extended field +} + diff --git a/metadef/inc/register/proto/insert_op.proto b/metadef/inc/register/proto/insert_op.proto new file mode 100644 index 00000000..bf918b20 --- /dev/null +++ b/metadef/inc/register/proto/insert_op.proto @@ -0,0 +1,139 @@ +syntax = "proto3"; + +package domi; + +message InsertNewOps { + repeated AippOpParams aipp_op = 1; + repeated MultiShapeOpParams multi_shape_op = 2; +} + +message AippOpParams { + enum InputFormat { + UNDEFINED = 0; + YUV420SP_U8 = 1; + XRGB8888_U8 = 2; + RGB888_U8 = 3; + YUV400_U8 = 4; + NC1HWC0DI_FP16 = 5; + NC1HWC0DI_S8 = 6; + ARGB8888_U8 = 7; + YUYV_U8 = 8; + YUV422SP_U8 = 9; + AYUV444_U8 = 10; + RAW10 = 11; + RAW12 = 12; + RAW16 = 13; + RAW24 = 14; + RGB16 = 15; + RGB20 = 16; + RGB24 = 17; + RGB8_IR = 18; + RGB16_IR = 19; + RGB24_IR = 20; + } + + enum AippMode { + undefined = 0; + static = 1; + dynamic = 2; + } + + // AIPPģʽ־̬AIPPͶ̬AIPP + AippMode aipp_mode = 1; + + // related_input_rankΪΪͣ÷Χ>=0, <=DataӵĸĬֵΪ0 + // ʶģ͵ĵڼAIPPģ룬ҪԵ2AIPPrelated_input_rankΪ1 + uint32 related_input_rank = 2; + + // related_input_name is optional and the top name of data node which inserts aipp + string related_input_name = 6; + + // input_edge_idxΪѡΪͣ÷ΧΪ>=0 + // øòãڶDataӲͬͬAIPPòûãĬ϶related_input_rankָģAIPP + // ֵ <= Dataߵĸ + repeated uint32 input_edge_idx = 3; + + // [Begin] ̬AIPPþ̬AIPPʱЧ + uint32 max_src_image_size = 4; + + // Ƿ֧תĬϲ֧֣֧תʱжĿռʧ + bool support_rotation = 5; + + // [End] ̬AIPP + + + // [Begin] ̬AIPPö̬AIPPʱЧ + InputFormat input_format = 51; + bool csc_switch = 52; + float cpadding_value = 53; + bool rbuv_swap_switch = 54; + bool ax_swap_switch = 55; + bool single_line_mode = 56; + + int32 src_image_size_w = 57; + int32 src_image_size_h = 58; + + bool crop = 59; + int32 load_start_pos_w = 60; + int32 load_start_pos_h = 61; + int32 crop_size_w = 62; + int32 crop_size_h = 63; + + bool resize = 64; + int32 resize_output_w = 65; + int32 resize_output_h = 66; + + bool padding = 67; + int32 left_padding_size = 68; + int32 right_padding_size = 69; + int32 top_padding_size = 70; + int32 bottom_padding_size = 71; + + int32 mean_chn_0 = 10; + int32 mean_chn_1 = 11; + int32 mean_chn_2 = 12; + int32 mean_chn_3 = 19; + float min_chn_0 = 13; + float min_chn_1 = 14; + float min_chn_2 = 15; + float min_chn_3 = 20; + repeated float var_reci_chn_0 = 16; + repeated float var_reci_chn_1 = 17; + repeated float var_reci_chn_2 = 18; + repeated float var_reci_chn_3 = 21; + + repeated int32 matrix_r0c0 = 30; + repeated int32 matrix_r0c1 = 31; + repeated int32 matrix_r0c2 = 32; + repeated int32 matrix_r1c0 = 33; + repeated int32 matrix_r1c1 = 34; + repeated int32 matrix_r1c2 = 35; + repeated int32 matrix_r2c0 = 36; + repeated int32 matrix_r2c1 = 37; + repeated int32 matrix_r2c2 = 38; + repeated int32 output_bias_0 = 39; + repeated int32 output_bias_1 = 40; + repeated int32 output_bias_2 = 41; + repeated int32 input_bias_0 = 42; + repeated int32 input_bias_1 = 43; + repeated int32 input_bias_2 = 44; + + // [End] ̬AIPP + + // The n number that is used for raw/rgbir data into f16 transformation. + // The transformation equation is x/(2^n). If set to 0, no transform is performed. + uint32 raw_rgbir_to_f16_n = 45; +} + +message MultiShapeOpParams { + enum MultiShapeMode { + batch = 0; //̬batch + resolution = 1; //ֱ̬ʣչ + } + + MultiShapeMode mode = 1; //ģʽ + uint32 related_input_rank = 2; //Ӳ뵽ĸ + + + repeated uint32 batch_list = 11; //batch_listֵbatch_listĸ28֮ +} diff --git a/metadef/inc/register/proto/om.proto b/metadef/inc/register/proto/om.proto new file mode 100644 index 00000000..e15e5f80 --- /dev/null +++ b/metadef/inc/register/proto/om.proto @@ -0,0 +1,396 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. + CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/metadef/inc/register/proto/onnx/ge_onnx.proto b/metadef/inc/register/proto/onnx/ge_onnx.proto new file mode 100644 index 00000000..4cd77f3a --- /dev/null +++ b/metadef/inc/register/proto/onnx/ge_onnx.proto @@ -0,0 +1,563 @@ +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto3"; + +package ge.onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION = 0x0000000000000006; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value= 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} diff --git a/metadef/inc/register/proto/op_mapping_info.proto b/metadef/inc/register/proto/op_mapping_info.proto new file mode 100644 index 00000000..e23b7ebe --- /dev/null +++ b/metadef/inc/register/proto/op_mapping_info.proto @@ -0,0 +1,73 @@ +syntax = "proto3"; +package aicpu.dump; + +message Shape { + repeated uint64 dim = 1; +} + +message Output { + int32 data_type = 1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + string original_name = 5; + int32 original_output_index = 6; + int32 original_output_data_type = 7; + int32 original_output_format = 8; + uint64 size = 9; +} + +message Input { + int32 data_type =1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + uint64 address = 2; + uint64 size = 3; +} + +message Op { + string op_name = 1; + string op_type = 2; +} + +message Task { + uint32 task_id = 1; + uint32 stream_id = 2; + Op op = 3; + repeated Output output = 4; + bool end_graph = 5; + repeated Input input = 6; + repeated OpBuffer buffer = 7; +} + +message OpMappingInfo { + string dump_path = 1; + oneof model_name_param { + string model_name = 2; + } + oneof model_id_param { + uint32 model_id = 3; + } + oneof step_id { + uint64 step_id_addr = 4; + } + oneof iterations_per_loop { + uint64 iterations_per_loop_addr = 5; + } + oneof loop_cond { + uint64 loop_cond_addr = 6; + } + uint32 flag = 7; // 0x01 load, 0x00 unload + repeated Task task = 8; + string dump_step = 9; +} \ No newline at end of file diff --git a/metadef/inc/register/proto/proto_inner/ge_onnx.proto b/metadef/inc/register/proto/proto_inner/ge_onnx.proto new file mode 100644 index 00000000..4cd77f3a --- /dev/null +++ b/metadef/inc/register/proto/proto_inner/ge_onnx.proto @@ -0,0 +1,563 @@ +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto3"; + +package ge.onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION = 0x0000000000000006; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value= 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} diff --git a/metadef/inc/register/proto/task.proto b/metadef/inc/register/proto/task.proto new file mode 100644 index 00000000..d0c09840 --- /dev/null +++ b/metadef/inc/register/proto/task.proto @@ -0,0 +1,165 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/metadef/inc/register/proto/tensorflow/attr_value.proto b/metadef/inc/register/proto/tensorflow/attr_value.proto new file mode 100644 index 00000000..1cc67d62 --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/metadef/inc/register/proto/tensorflow/function.proto b/metadef/inc/register/proto/tensorflow/function.proto new file mode 100644 index 00000000..075897c6 --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/function.proto @@ -0,0 +1,100 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. + reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. +} diff --git a/metadef/inc/register/proto/tensorflow/graph.proto b/metadef/inc/register/proto/tensorflow/graph.proto new file mode 100644 index 00000000..d639a7d6 --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "node_def.proto"; +import "function.proto"; +import "versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/metadef/inc/register/proto/tensorflow/graph_library.proto b/metadef/inc/register/proto/tensorflow/graph_library.proto new file mode 100644 index 00000000..e393d38d --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/graph_library.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package domi.tensorflow; + +import "graph.proto"; + +message GeGraphDef { + string name = 1; + GraphDef graph = 2; +} + +message GraphDefLibrary { + repeated GeGraphDef graph_def = 1; +}; \ No newline at end of file diff --git a/metadef/inc/register/proto/tensorflow/node_def.proto b/metadef/inc/register/proto/tensorflow/node_def.proto new file mode 100644 index 00000000..b9bc97ee --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // Add some examples here showing best practices. + map attr = 5; +}; diff --git a/metadef/inc/register/proto/tensorflow/op_def.proto b/metadef/inc/register/proto/tensorflow/op_def.proto new file mode 100644 index 00000000..3485d045 --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/op_def.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/metadef/inc/register/proto/tensorflow/resource_handle.proto b/metadef/inc/register/proto/tensorflow/resource_handle.proto new file mode 100644 index 00000000..a3452351 --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/metadef/inc/register/proto/tensorflow/tensor.proto b/metadef/inc/register/proto/tensorflow/tensor.proto new file mode 100644 index 00000000..d0a4d024 --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/tensor.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/metadef/inc/register/proto/tensorflow/tensor_shape.proto b/metadef/inc/register/proto/tensorflow/tensor_shape.proto new file mode 100644 index 00000000..4225a2e3 --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package domi.tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/metadef/inc/register/proto/tensorflow/types.proto b/metadef/inc/register/proto/tensorflow/types.proto new file mode 100644 index 00000000..ba7a72b3 --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/types.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/metadef/inc/register/proto/tensorflow/versions.proto b/metadef/inc/register/proto/tensorflow/versions.proto new file mode 100644 index 00000000..48061218 --- /dev/null +++ b/metadef/inc/register/proto/tensorflow/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +}; diff --git a/metadef/inc/register/register.h b/metadef/inc/register/register.h new file mode 100644 index 00000000..72e9924d --- /dev/null +++ b/metadef/inc/register/register.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_REGISTRY_H_ +#define INC_REGISTER_REGISTRY_H_ + +#include "external/register/register.h" +#include "external/ge/ge_api_error_codes.h" + +namespace ge { +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { + public: + HostCpuOp() = default; + virtual ~HostCpuOp() = default; + + virtual graphStatus Compute(Operator &op, + const std::map &inputs, + std::map &outputs) = 0; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { + public: + HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); + ~HostCpuOpRegistrar() = default; +}; + +#define REGISTER_HOST_CPU_OP_BUILDER(name, op) \ + REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) + +#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) \ + REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) + +#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ + static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr \ + __attribute__((unused)) = \ + ::ge::HostCpuOpRegistrar(name, []()->::ge::HostCpuOp* { \ + return new (std::nothrow) op(); \ + }) +} // namespace ge + +#endif //INC_REGISTER_REGISTRY_H_ diff --git a/metadef/inc/register/register_format_transfer.h b/metadef/inc/register/register_format_transfer.h new file mode 100644 index 00000000..5cbf4ab4 --- /dev/null +++ b/metadef/inc/register/register_format_transfer.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_REGISTER_FORMAT_TRANSFER_H_ +#define INC_REGISTER_REGISTER_FORMAT_TRANSFER_H_ + +#include +#include +#include + +#include "external/graph/types.h" +#include "ge/ge_api_error_codes.h" + +namespace ge { +namespace formats { +struct TransArgs { + const uint8_t *data; + Format src_format; + Format dst_format; + // For scenes that need to supplement the shape, for example, 5D to 4D + // It is not possible to convert the format normally if you only get the src_shape, + // and must get the shape before you mend the shape. + // So the parameters here need to be passed in both src_shape and dst_shape + std::vector src_shape; + std::vector dst_shape; + DataType src_data_type; +}; + +struct TransResult { + std::shared_ptr data; + // data length in bytes + size_t length; +}; + +class FormatTransfer { + public: + virtual ~FormatTransfer() = default; + virtual Status TransFormat(const TransArgs &args, TransResult &result) = 0; + virtual Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, + Format dst_format, std::vector &dst_shape) = 0; +}; + +using FormatTransferBuilder = std::function()>; + +class FormatTransferRegister { + public: + FormatTransferRegister(FormatTransferBuilder builder, Format src, Format dst); + ~FormatTransferRegister() = default; +}; + +#define REGISTER_FORMAT_TRANSFER(TransferClass, format1, format2) \ + namespace { \ + FormatTransferRegister format_transfer_register_##TransferClass##format1##format2( \ + []() { return std::make_shared(); }, format1, format2); \ + } + +/// Build a formattransfer according to 'args' +/// @param args +/// @param result +/// @return +std::shared_ptr BuildFormatTransfer(const TransArgs &args); + +bool FormatTransferExists(const TransArgs &args); +} // namespace formats +} // namespace ge +#endif // INC_REGISTER_REGISTER_FORMAT_TRANSFER_H_ \ No newline at end of file diff --git a/metadef/inc/register/scope/scope_graph_impl.h b/metadef/inc/register/scope/scope_graph_impl.h new file mode 100644 index 00000000..ffc176e0 --- /dev/null +++ b/metadef/inc/register/scope/scope_graph_impl.h @@ -0,0 +1,194 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef REGISTER_SCOPE_SCOPE_GRAPH_IMPL_H_ +#define REGISTER_SCOPE_SCOPE_GRAPH_IMPL_H_ + +#include "external/register/scope/scope_fusion_pass_register.h" +#include "graph/operator_factory.h" +#include "proto/tensorflow/graph.pb.h" +#include "proto/tensorflow/node_def.pb.h" +#include "graph/utils/type_utils.h" + +namespace ge { +using FusionInnerNodesInfo = std::vector>, // inputs + std::vector>, // outputs + const ge::Operator *>>; // operator + +class Scope::ScopeImpl { + public: + ScopeImpl() : father_scope_(nullptr) {} + Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); + ~ScopeImpl(); + + const std::string &Name() const { return name_; } + const std::string &SubType() const { return sub_type_; } + void SetSubType(const std::string &sub_type) { sub_type_ = sub_type; } + void ClearTypeAndSubType(); + void AddNode(ge::OperatorPtr &node_def); + const std::vector &Nodes() const { return nodes_; } + const std::unordered_map &AllNodesMap(); + void AddSubScope(Scope *scope) { sub_scopes_[scope->Name()] = scope; } + Scope *GetSubScope(const std::string &scope_name) const; + const std::unordered_map &GetSubScopes() const { return sub_scopes_; } + const std::vector &GetAllSubScopes(); + int32_t GetOpTypeNum(const std::string &op_type) const; + void OpsNumInc(const std::string &op_type); + const std::string LastName() const; + const Scope *GetFatherScope() const { return father_scope_; } + // trim scope_index + static std::string TrimScopeIndex(const std::string &scope_name); + + private: + std::string name_; + std::string sub_type_; + Scope *father_scope_; + std::unordered_map op_nums_; + std::unordered_map sub_scopes_; + std::vector nodes_; + std::unordered_map all_nodes_map_; + std::vector all_sub_scopes_; +}; + +class FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl { + public: + explicit InnerNodeInfoImpl(const std::string &fusion_node_name) : fusion_node_name_(fusion_node_name) {} + InnerNodeInfoImpl(const std::string &fusion_node_name, const std::string &name, const std::string &type) + : fusion_node_name_(fusion_node_name), name_(name), type_(type) { + SetName(name); + } + ~InnerNodeInfoImpl(); + std::string GetFullNodeName(const std::string &relative_name); + void SetName(const std::string &name) { name_ = GetFullNodeName(name); } + void SetType(const std::string &type) { type_ = type; } + void InsertInput(const std::string &input_node, int32_t peer_out_idx); + void InsertOutput(const std::string &output_node, int32_t peer_in_idx); + ge::graphStatus BuildOperator(); + ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format) ; + ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format); + ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format); + ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format); + std::string GetName() const { return name_; } + std::string GetType() const { return type_; } + std::vector> GetInputs() const { return inner_node_inputs_; } + std::vector> GetOutputs() const { return inner_node_outputs_; } + ge::Operator *MutableOperator() { return &operator_; } + + public: + ge::Operator operator_; + private: + std::string fusion_node_name_; + std::string name_; + std::string type_; + std::vector> inner_node_inputs_; + std::vector> inner_node_outputs_; +}; + +class FusionScopesResult::FusionScopesResultImpl { + public: + FusionScopesResultImpl() {} + ~FusionScopesResultImpl(){}; + void SetName(const std::string &name) { name_ = name; } + void SetType(const std::string &type) { type_ = type; } + void SetDescription(const std::string &description) { description_ = description; } + const std::string &Name() const { return name_; } + const std::string &Type() const { return type_; } + const std::string &Description() const { return description_; } + void AddNodes(std::vector nodes); + const std::vector &Nodes() const { return nodes_; } + void AddScopes(const std::vector &scopes) { scopes_.insert(scopes_.end(), scopes.begin(), scopes.end()); } + const std::vector &Scopes() const { return scopes_; } + const std::unordered_map> &GetInputs() const { return inputs_; } + const std::unordered_map> &GetOutputs() const { return outputs_; } + void InsertInputs(const std::string &inner_op_name, const std::vector &index_map); + void InsertOutputs(const std::string &inner_op_name, const std::vector &index_map); + bool FindNodes(const std::string &node_name) const; + bool FindScopes(const std::string &scope_name) const; + + InnerNodeInfo *AddInnerNode(const string &name, const string &type); + InnerNodeInfo *MutableRecentInnerNode(); + InnerNodeInfo *MutableInnerNode(uint32_t index); + FusionInnerNodesInfo GetInnerNodesInfo(); + ge::graphStatus CheckInnerNodesInfo(); + + private: + std::string name_; + std::string type_; + std::string description_; + std::vector scopes_; + std::vector nodes_; + std::unordered_map> inputs_; + std::unordered_map> outputs_; + std::vector inner_node_infos_; +}; + +class ScopeTree::ScopeTreeImpl { + public: + ScopeTreeImpl() : root_(nullptr) {} + ScopeTreeImpl(const ScopeTreeImpl &) = delete; + ScopeTreeImpl &operator=(const ScopeTreeImpl &) = delete; + Status Init(); + ~ScopeTreeImpl(); + + void AddNodeToScope(ge::OperatorPtr &node_def); + const std::vector &GetAllScopes() const { return scopes_; } + const Scope *Root() const { return root_; } + + private: + std::vector SplitNodeName(const std::string &node_name, char delim) const; + Scope *root_; + std::vector scopes_; +}; + +struct ScopeFusionOpInfo { + std::string node_name; + std::string fusion_node_name; + std::string fusion_op_type; + std::string description; + bool scope_pass = true; +}; + +class ScopeGraph::ScopeGraphImpl { + public: + ScopeGraphImpl() : scope_tree_(nullptr) {} + ScopeGraphImpl(const ScopeGraphImpl &) = delete; + ScopeGraphImpl &operator=(const ScopeGraphImpl &) = delete; + Status Init(); + ~ScopeGraphImpl(); + + const ScopeTree *GetScopeTree() const { return scope_tree_; } + void BuildScopeGraph(domi::tensorflow::GraphDef *graph_def); + void AddFusionScopesResult(FusionScopesResult *result); + const std::unordered_map &FusionScopesResults() const { return fusion_results_; } + FusionScopesResult *GetFusionScopesResults(const domi::tensorflow::NodeDef *node_def) const; + FusionScopesResult *GetFusionScopesResults(const string &node_name) const; + const std::unordered_map &GetNodesMap() const { return nodes_map_; } + bool IsFusionOpChild(const std::string &node_name, std::vector &info_list); + bool FusionOpChildIgnore(const ScopeFusionOpInfo &info); + bool IsFusionOp(const domi::tensorflow::NodeDef *node_def); + Status GetInputOrOutputIndex(const ScopeFusionOpInfo &info, int32_t old_index, bool input, int32_t &new_index); + + private: + std::vector GetFusionResultInputOrOutput(const ScopeFusionOpInfo &info, + bool input); // input:true,output:false + void CheckScopesResult(FusionScopesResult *fusion_node); + std::unordered_map fusion_results_; + std::unordered_map nodes_map_; + ScopeTree *scope_tree_; +}; +} // namespace ge +#endif // REGISTER_SCOPE_SCOPE_GRAPH_IMPL_H_ \ No newline at end of file diff --git a/metadef/inc/register/scope/scope_pass_impl.h b/metadef/inc/register/scope/scope_pass_impl.h new file mode 100644 index 00000000..ef2d97c6 --- /dev/null +++ b/metadef/inc/register/scope/scope_pass_impl.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef REGISTER_SCOPE_SCOPE_PASS_IMPL_H_ +#define REGISTER_SCOPE_SCOPE_PASS_IMPL_H_ + +#include "external/register/scope/scope_fusion_pass_register.h" + +namespace ge { +class ScopesResult::ScopesResultImpl { + public: + void SetScopes(const std::vector &scopes) { scopes_ = scopes; } + const std::vector &GetScopes() const { return scopes_; } + void SetNodes(const std::vector &nodes) { nodes_ = nodes; } + const std::vector &GetNodes() const { return nodes_; } + + private: + std::vector scopes_; // multiple scopes + std::vector nodes_; // op outside of scope +}; + +class ScopeBasePass::ScopeBasePassImpl { + public: + ScopeBasePassImpl(ScopeBasePass *parent) : parent_(parent) {} + virtual ~ScopeBasePassImpl(); + + Status Run(std::shared_ptr &scope_graph); + + private: + Status AddFusionScopesResultToScopeGraph(std::shared_ptr &scope_graph, + std::vector &scope_results); + // Match rules one by one, support multiple sets of matching rules, and finally output a single scope + // Note: This function does not have to be rewritten. + // In order to match the fusion rules designed by you better, + // you can implement your specific versions separately. + bool MatchAllBatches(const ScopeTree *scope_tree, std::vector &results); + + bool MatchOneBatch(const ScopeTree *scope_tree, const std::vector &patternlist, + std::vector &results); + bool MatchOneScope(const ScopePattern *pattern, Scope *scope, std::vector &results); + Status PrintFusionScopeInfo(std::shared_ptr &scope_graph); + + private: + std::vector patterns_; + ScopeBasePass *parent_; +}; +} // namespace ge +#endif // REGISTER_SCOPE_SCOPE_PASS_IMPL_H_ \ No newline at end of file diff --git a/metadef/inc/register/scope/scope_pass_registry_impl.h b/metadef/inc/register/scope/scope_pass_registry_impl.h new file mode 100644 index 00000000..9e68dba0 --- /dev/null +++ b/metadef/inc/register/scope/scope_pass_registry_impl.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef REGISTER_SCOPE_SCOPE_REGISTRY_IMPL_H_ +#define REGISTER_SCOPE_SCOPE_REGISTRY_IMPL_H_ + +#include "external/register/scope/scope_fusion_pass_register.h" +#include + +namespace ge { +struct CreatePassFnPack; +class ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl { + public: + void RegisterScopeFusionPass(const std::string &pass_name, ScopeFusionPassRegistry::CreateFn create_fn, + bool is_general); + ScopeFusionPassRegistry::CreateFn GetCreateFn(const std::string &pass_name); + std::unique_ptr CreateScopeFusionPass(const std::string &pass_name); + std::vector GetAllRegisteredPasses(); + bool SetPassEnableFlag(const std::string pass_name, const bool flag); + + private: + std::mutex mu_; + std::vector pass_names_; // In the order of user registration + std::map create_fn_packs_; +}; +} // namespace ge +#endif // REGISTER_SCOPE_SCOPE_REGISTRY_IMPL_H_ \ No newline at end of file diff --git a/metadef/inc/register/scope/scope_pattern_impl.h b/metadef/inc/register/scope/scope_pattern_impl.h new file mode 100644 index 00000000..7f0445ef --- /dev/null +++ b/metadef/inc/register/scope/scope_pattern_impl.h @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H_ +#define REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H_ + +#include "external/register/scope/scope_fusion_pass_register.h" + +namespace ge { +class ScopeAttrValue::ScopeAttrValueImpl { + public: + ScopeAttrValueImpl() : int_value_(0), float_value_(0.0), string_value_(""), bool_value_(false) {} + ~ScopeAttrValueImpl() {} + + void SetIntValue(const int64_t &value) { int_value_ = value; } + void SetFloatValue(const float &value) { float_value_ = value; } + void SetStringValue(const std::string &value) { string_value_ = value; } + void SetBoolValue(const bool &value) { bool_value_ = value; } + const int64_t &GetIntValue() const { return int_value_; } + const float &GetFloatValue() const { return float_value_; } + const std::string &GetStrValue() const { return string_value_; } + const bool &GetBoolValue() const { return bool_value_; } + + private: + int64_t int_value_; + float float_value_; + std::string string_value_; + bool bool_value_; +}; + +class NodeOpTypeFeature::NodeOpTypeFeatureImpl : ScopeBaseFeature { + public: + NodeOpTypeFeatureImpl(std::string nodeType, int num, int step = 0) + : node_type_(nodeType), num_(num), step_(step) {} + ~NodeOpTypeFeatureImpl() {} + bool Match(const Scope *scope) override; + + public: + std::string node_type_; // Node type + int num_; // Node number + int step_; // step +}; + +class NodeAttrFeature::NodeAttrFeatureImpl : ScopeBaseFeature { + public: + NodeAttrFeatureImpl(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue &attr_value) + : node_type_(nodeType), attr_name_(attr_name), datatype_(datatype), attr_value_(attr_value) {} + ~NodeAttrFeatureImpl() {} + bool Match(const Scope *scope) override; + + public: + std::string node_type_; // Node type + std::string attr_name_; // attribute name + ge::DataType datatype_; // datatype + ScopeAttrValue attr_value_; // AttrValue +}; + +class ScopeFeature::ScopeFeatureImpl : ScopeBaseFeature { + public: + ScopeFeatureImpl(std::string sub_type, int32_t num, std::string suffix = "", + std::string sub_scope_mask = "", int step = 0) + : sub_type_(sub_type), num_(num), suffix_(suffix), sub_scope_mask_(sub_scope_mask), step_(step) {} + ~ScopeFeatureImpl() {} + bool Match(const Scope *scope) override; + bool SubScopesMatch(const std::vector &scopes); + + public: + std::string sub_type_; + int32_t num_; + std::string suffix_; + std::string sub_scope_mask_; + int step_; +}; + +class ScopePattern::ScopePatternImpl { + public: + ScopePatternImpl() {} + ~ScopePatternImpl() {} + bool Match(const Scope *scope) const; + void SetSubType(const std::string &sub_type); + const std::string &SubType() const { return sub_type_; } + void AddNodeOpTypeFeature(NodeOpTypeFeature &feature); + void AddNodeAttrFeature(NodeAttrFeature &feature); + void AddScopeFeature(ScopeFeature &feature); + + private: + std::string sub_type_; // get Scope sub type + std::vector node_optype_features_; + std::vector node_attr_features_; + std::vector scopes_features_; +}; +} // namespace ge +#endif // REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H_ \ No newline at end of file diff --git a/metadef/inc/register/tensor_assign.h b/metadef/inc/register/tensor_assign.h new file mode 100644 index 00000000..57a37f6c --- /dev/null +++ b/metadef/inc/register/tensor_assign.h @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TENSOR_ASSIGN_H_ +#define TENSOR_ASSIGN_H_ + +#include "graph/ge_tensor.h" +#include "proto/tensorflow/tensor.pb.h" + +namespace domi { +using GeTensorPtr = std::shared_ptr; +using Status = uint32_t; +using domi::tensorflow::TensorProto; +using google::protobuf::int32; +using google::protobuf::int64; + +class TensorAssign { + public: + static Status SetGeTensor(const TensorProto &tensor, GeTensorPtr &weight); + + static Status SetGeTensorDataType(int64_t dataType, GeTensorPtr &weight); + + static ge::DataType ConvertTensorflowDataType(uint32_t tf_data_type); + + private: + static bool CheckBoolVal(tensorflow::DataType data_type); + + static bool CheckHalfVal(tensorflow::DataType data_type); + + static bool CheckFloatVal(tensorflow::DataType data_type); + + static bool CheckDoubleVal(tensorflow::DataType data_type); + + static bool CheckComplex64Val(tensorflow::DataType data_type); + + static bool CheckComplex128Val(tensorflow::DataType data_type); + + static bool CheckStringVal(tensorflow::DataType data_type); + + static bool CheckByte(tensorflow::DataType data_type); + + static bool CheckDoubleByte(tensorflow::DataType data_type); + + static bool CheckSignedFourByte(tensorflow::DataType data_type); + + static bool CheckUnsignedFourByte(tensorflow::DataType data_type); + + static bool CheckSignedEightByte(tensorflow::DataType data_type); + + static bool CheckUnsignedEightByte(tensorflow::DataType data_type); + + static Status GetDoubleByteVal(int32_t val_size, const google::protobuf::RepeatedField &val_vector, int count, + GeTensorPtr &weight); + static Status GetByteVal(int32_t val_size, const google::protobuf::RepeatedField &val_vector, int count, + GeTensorPtr &weight); + + static Status GetStringVal(int32_t val_size, const google::protobuf::RepeatedPtrField &val_vector, + int count, GeTensorPtr &weight); + + static void SetGeTensorWeightData(const TensorProto &tensor, int32_t val_size, int count, GeTensorPtr &weight); + + static void SetWeightData(tensorflow::DataType data_type, int count, const std::string &tensor_content, + GeTensorPtr &weight); + + template + static Status GetVal(int32_t val_size, const google::protobuf::RepeatedField &val_vector, int count, + GeTensorPtr &weight) { + bool zerosLike = (count != val_size && val_size == 1); + T *addr = new (std::nothrow) T[count](); + GE_CHECK_NOTNULL(addr); + int minCount = (count > val_size) ? val_size : count; + if (!zerosLike) { + for (int32_t i = 0; i < minCount; i++) { + *(addr + i) = val_vector.Get(i); + } + for (int32_t i = minCount; i < count; i++) { + *(addr + i) = val_vector.Get(minCount - 1); + } + } else { + for (int32_t i = 0; i < count; i++) { + *(addr + i) = val_vector.Get(0); + } + } + (void)weight->SetData(reinterpret_cast(addr), count * sizeof(T)); + GE_DELETE_NEW_ARRAY(addr); + return SUCCESS; + } +}; +} // namespace domi +#endif // TENSOR_ASSIGN_H_ diff --git a/metadef/ops/op_imp.cpp b/metadef/ops/op_imp.cpp new file mode 100644 index 00000000..4866a0a5 --- /dev/null +++ b/metadef/ops/op_imp.cpp @@ -0,0 +1,80 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "debug/ge_log.h" +#include "debug/ge_util.h" + +using namespace std; + +namespace ge { + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +BroadCastInfer(const function()>& get_in1_shape, const function()>& get_in2_shape, + const function& outShape)>& set_out_shape) { + auto x1_shape = get_in1_shape(); + auto x2_shape = get_in2_shape(); + vector y_shape; + + if (x1_shape.empty()) { + y_shape = x2_shape; + set_out_shape(y_shape); + return GRAPH_SUCCESS; + } + if (x2_shape.empty()) { + y_shape = x1_shape; + set_out_shape(y_shape); + return GRAPH_SUCCESS; + } + + int len_diff = static_cast(x1_shape.size() - x2_shape.size()); + if (len_diff >= 0) { + for (int i = 0; i < len_diff; i++) { + y_shape.push_back(x1_shape[i]); + } + int x2_shape_size = static_cast(x2_shape.size()); + for (int i = 0; i < x2_shape_size; i++) { + bool shapeFlag = + ((x1_shape[i + len_diff] != x2_shape[i]) && (std::min(x1_shape[i + len_diff], x2_shape[i]) != 1)); + if (shapeFlag) { + GE_LOGE("operands could not be broadcast together"); + return GRAPH_FAILED; + } + y_shape.push_back(std::max(x1_shape[i + len_diff], x2_shape[i])); + } + } else { + for (int i = 0; i < -len_diff; i++) { + y_shape.push_back(x2_shape[i]); + } + int x1_shape_size = static_cast(x1_shape.size()); + for (int i = 0; i < x1_shape_size; i++) { + bool shapeFlag = + ((x1_shape[i] != x2_shape[i - len_diff]) && (std::min(x1_shape[i], x2_shape[i - len_diff]) != 1)); + if (shapeFlag) { + GE_LOGE("operands could not be broadcast together"); + return GRAPH_FAILED; + } + y_shape.push_back(std::max(x1_shape[i], x2_shape[i - len_diff])); + } + } + set_out_shape(y_shape); + return GRAPH_SUCCESS; +} + +} // namespace ge diff --git a/metadef/proto/caffe/caffe.proto b/metadef/proto/caffe/caffe.proto new file mode 100644 index 00000000..3f45aae2 --- /dev/null +++ b/metadef/proto/caffe/caffe.proto @@ -0,0 +1,1821 @@ +syntax = "proto2"; + +package domi.caffe; + +// Specifies the shape (dimensions) of a Blob. +message BlobShape { + repeated int64 dim = 1 [packed = true]; +} + +message BlobProto { + optional BlobShape shape = 7; + repeated float data = 5 [packed = true]; + repeated float diff = 6 [packed = true]; + repeated double double_data = 8 [packed = true]; + repeated double double_diff = 9 [packed = true]; + optional bytes int8_data = 10; + repeated int32 int32_data = 11 [packed = true]; + repeated uint64 uint64_data = 12 [packed = true]; + // 4D dimensions -- deprecated. Use "shape" instead. + optional int32 num = 1 [default = 0]; + optional int32 channels = 2 [default = 0]; + optional int32 height = 3 [default = 0]; + optional int32 width = 4 [default = 0]; +} + +// The BlobProtoVector is simply a way to pass multiple blobproto instances +// around. +message BlobProtoVector { + repeated BlobProto blobs = 1; +} + +message Datum { + optional int32 channels = 1; + optional int32 height = 2; + optional int32 width = 3; + // the actual image data, in bytes + optional bytes data = 4; + optional int32 label = 5; + // Optionally, the datum could also hold float data. + repeated float float_data = 6; + // If true data contains an encoded image that need to be decoded + optional bool encoded = 7 [default = false]; +} + +message FillerParameter { + // The filler type. + optional string type = 1 [default = 'constant']; + optional float value = 2 [default = 0]; // the value in constant filler + optional float min = 3 [default = 0]; // the min value in uniform filler + optional float max = 4 [default = 1]; // the max value in uniform filler + optional float mean = 5 [default = 0]; // the mean value in Gaussian filler + optional float std = 6 [default = 1]; // the std value in Gaussian filler + // The expected number of non-zero output weights for a given input in + // Gaussian filler -- the default -1 means don't perform sparsification. + optional int32 sparse = 7 [default = -1]; + // Normalize the filler variance by fan_in, fan_out, or their average. + // Applies to 'xavier' and 'msra' fillers. + enum VarianceNorm { + FAN_IN = 0; + FAN_OUT = 1; + AVERAGE = 2; + } + optional VarianceNorm variance_norm = 8 [default = FAN_IN]; +} + +message NetParameter { + optional string name = 1; // consider giving the network a name + // DEPRECATED. See InputParameter. The input blobs to the network. + repeated string input = 3; + // DEPRECATED. See InputParameter. The shape of the input blobs. + repeated BlobShape input_shape = 8; + + // 4D input dimensions -- deprecated. Use "input_shape" instead. + // If specified, for each input blob there should be four + // values specifying the num, channels, height and width of the input blob. + // Thus, there should be a total of (4 * #input) numbers. + repeated int32 input_dim = 4; + + // Whether the network will force every layer to carry out backward operation. + // If set False, then whether to carry out backward is determined + // automatically according to the net structure and learning rates. + optional bool force_backward = 5 [default = false]; + // The current "state" of the network, including the phase, level, and stage. + // Some layers may be included/excluded depending on this state and the states + // specified in the layers' include and exclude fields. + optional NetState state = 6; + + // Print debugging information about results while running Net::Forward, + // Net::Backward, and Net::Update. + optional bool debug_info = 7 [default = false]; + + // The layers that make up the net. Each of their configurations, including + // connectivity and behavior, is specified as a LayerParameter. + repeated LayerParameter layer = 100; // ID 100 so layers are printed last. + + // DEPRECATED: use 'layer' instead. + repeated V1LayerParameter layers = 2; +} + +// NOTE +// Update the next available ID when you add a new SolverParameter field. +// +// SolverParameter next available ID: 42 (last added: layer_wise_reduce) +message SolverParameter { + ////////////////////////////////////////////////////////////////////////////// + // Specifying the train and test networks + // + // Exactly one train net must be specified using one of the following fields: + // train_net_param, train_net, net_param, net + // One or more test nets may be specified using any of the following fields: + // test_net_param, test_net, net_param, net + // If more than one test net field is specified (e.g., both net and + // test_net are specified), they will be evaluated in the field order given + // above: (1) test_net_param, (2) test_net, (3) net_param/net. + // A test_iter must be specified for each test_net. + // A test_level and/or a test_stage may also be specified for each test_net. + ////////////////////////////////////////////////////////////////////////////// + + // Proto filename for the train net, possibly combined with one or more + // test nets. + optional string net = 24; + // Inline train net param, possibly combined with one or more test nets. + optional NetParameter net_param = 25; + + optional string train_net = 1; // Proto filename for the train net. + repeated string test_net = 2; // Proto filenames for the test nets. + optional NetParameter train_net_param = 21; // Inline train net params. + repeated NetParameter test_net_param = 22; // Inline test net params. + + // The states for the train/test nets. Must be unspecified or + // specified once per net. + // + // By default, all states will have solver = true; + // train_state will have phase = TRAIN, + // and all test_state's will have phase = TEST. + // Other defaults are set according to the NetState defaults. + optional NetState train_state = 26; + repeated NetState test_state = 27; + + // The number of iterations for each test net. + repeated int32 test_iter = 3; + + // The number of iterations between two testing phases. + optional int32 test_interval = 4 [default = 0]; + optional bool test_compute_loss = 19 [default = false]; + // If true, run an initial test pass before the first iteration, + // ensuring memory availability and printing the starting value of the loss. + optional bool test_initialization = 32 [default = true]; + optional float base_lr = 5; // The base learning rate + // the number of iterations between displaying info. If display = 0, no info + // will be displayed. + optional int32 display = 6; + // Display the loss averaged over the last average_loss iterations + optional int32 average_loss = 33 [default = 1]; + optional int32 max_iter = 7; // the maximum number of iterations + // accumulate gradients over `iter_size` x `batch_size` instances + optional int32 iter_size = 36 [default = 1]; + + // The learning rate decay policy. The currently implemented learning rate + // policies are as follows: + // - fixed: always return base_lr. + // - step: return base_lr * gamma ^ (floor(iter / step)) + // - exp: return base_lr * gamma ^ iter + // - inv: return base_lr * (1 + gamma * iter) ^ (- power) + // - multistep: similar to step but it allows non uniform steps defined by + // stepvalue + // - poly: the effective learning rate follows a polynomial decay, to be + // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) + // - sigmoid: the effective learning rate follows a sigmod decay + // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) + // + // where base_lr, max_iter, gamma, step, stepvalue and power are defined + // in the solver parameter protocol buffer, and iter is the current iteration. + optional string lr_policy = 8; + optional float gamma = 9; // The parameter to compute the learning rate. + optional float power = 10; // The parameter to compute the learning rate. + optional float momentum = 11; // The momentum value. + optional float weight_decay = 12; // The weight decay. + // regularization types supported: L1 and L2 + // controlled by weight_decay + optional string regularization_type = 29 [default = "L2"]; + // the stepsize for learning rate policy "step" + optional int32 stepsize = 13; + // the stepsize for learning rate policy "multistep" + repeated int32 stepvalue = 34; + + // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, + // whenever their actual L2 norm is larger. + optional float clip_gradients = 35 [default = -1]; + + optional int32 snapshot = 14 [default = 0]; // The snapshot interval + optional string snapshot_prefix = 15; // The prefix for the snapshot. + // whether to snapshot diff in the results or not. Snapshotting diff will help + // debugging but the final protocol buffer size will be much larger. + optional bool snapshot_diff = 16 [default = false]; + enum SnapshotFormat { + HDF5 = 0; + BINARYPROTO = 1; + } + optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; + // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. + enum SolverMode { + CPU = 0; + GPU = 1; + } + optional SolverMode solver_mode = 17 [default = GPU]; + // the device_id will that be used in GPU mode. Use device_id = 0 in default. + optional int32 device_id = 18 [default = 0]; + // If non-negative, the seed with which the Solver will initialize the Caffe + // random number generator -- useful for reproducible results. Otherwise, + // (and by default) initialize using a seed derived from the system clock. + optional int64 random_seed = 20 [default = -1]; + + // type of the solver + optional string type = 40 [default = "SGD"]; + + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam + optional float delta = 31 [default = 1e-8]; + // parameters for the Adam solver + optional float momentum2 = 39 [default = 0.999]; + + // RMSProp decay value + // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) + optional float rms_decay = 38 [default = 0.99]; + + // If true, print information about the state of the net that may help with + // debugging learning problems. + optional bool debug_info = 23 [default = false]; + + // If false, don't save a snapshot after training finishes. + optional bool snapshot_after_train = 28 [default = true]; + + // DEPRECATED: old solver enum types, use string instead + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + RMSPROP = 3; + ADADELTA = 4; + ADAM = 5; + } + // DEPRECATED: use type instead of solver_type + optional SolverType solver_type = 30 [default = SGD]; + + // Overlap compute and communication for data parallel training + optional bool layer_wise_reduce = 41 [default = true]; +} + +// A message that stores the solver snapshots +message SolverState { + optional int32 iter = 1; // The current iteration + optional string learned_net = 2; // The file that stores the learned net. + repeated BlobProto history = 3; // The history for sgd solvers + optional int32 current_step = 4 [default = 0]; // The current step for learning rate +} + +enum Phase { + TRAIN = 0; + TEST = 1; +} + +message NetState { + optional Phase phase = 1 [default = TEST]; + optional int32 level = 2 [default = 0]; + repeated string stage = 3; +} + +message NetStateRule { + // Set phase to require the NetState have a particular phase (TRAIN or TEST) + // to meet this rule. + optional Phase phase = 1; + + // Set the minimum and/or maximum levels in which the layer should be used. + // Leave undefined to meet the rule regardless of level. + optional int32 min_level = 2; + optional int32 max_level = 3; + + // Customizable sets of stages to include or exclude. + // The net must have ALL of the specified stages and NONE of the specified + // "not_stage"s to meet the rule. + // (Use multiple NetStateRules to specify conjunctions of stages.) + repeated string stage = 4; + repeated string not_stage = 5; +} + +// Specifies training parameters (multipliers on global learning constants, +// and the name and other settings used for weight sharing). +message ParamSpec { + // The names of the parameter blobs -- useful for sharing parameters among + // layers, but never required otherwise. To share a parameter between two + // layers, give it a (non-empty) name. + optional string name = 1; + + // Whether to require shared weights to have the same shape, or just the same + // count -- defaults to STRICT if unspecified. + optional DimCheckMode share_mode = 2; + enum DimCheckMode { + // STRICT (default) requires that num, channels, height, width each match. + STRICT = 0; + // PERMISSIVE requires only the count (num*channels*height*width) to match. + PERMISSIVE = 1; + } + + // The multiplier on the global learning rate for this parameter. + optional float lr_mult = 3 [default = 1.0]; + + // The multiplier on the global weight decay for this parameter. + optional float decay_mult = 4 [default = 1.0]; +} + +// NOTE +// Update the next available ID when you add a new LayerParameter field. +// +// LayerParameter next available layer-specific ID: 151 (last added: smooth_l1_loss_param) +message LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the layer type + repeated string bottom = 3; // the name of each bottom blob + repeated string top = 4; // the name of each top blob + + // The train / test phase for computation. + optional Phase phase = 10; + + // The amount of weight to assign each top blob in the objective. + // Each layer assigns a default value, usually of either 0 or 1, + // to each top blob. + repeated float loss_weight = 5; + + // Specifies training parameters (multipliers on global learning constants, + // and the name and other settings used for weight sharing). + repeated ParamSpec param = 6; + + // The blobs containing the numeric parameters of the layer. + repeated BlobProto blobs = 7; + + // Specifies whether to backpropagate to each bottom. If unspecified, + // Caffe will automatically infer whether each input needs backpropagation + // to compute parameter gradients. If set to true for some inputs, + // backpropagation to those inputs is forced; if set false for some inputs, + // backpropagation to those inputs is skipped. + // + // The size must be either 0 or equal to the number of bottoms. + repeated bool propagate_down = 11; + + // Rules controlling whether and when a layer is included in the network, + // based on the current NetState. You may specify a non-zero number of rules + // to include OR exclude, but not both. If no include or exclude rules are + // specified, the layer is always included. If the current NetState meets + // ANY (i.e., one or more) of the specified rules, the layer is + // included/excluded. + repeated NetStateRule include = 8; + repeated NetStateRule exclude = 9; + + // Parameters for data pre-processing. + optional TransformationParameter transform_param = 100; + + // Parameters shared by loss layers. + optional LossParameter loss_param = 101; + + // Layer type-specific parameters. + // + // Note: certain layers may have more than one computational engine + // for their implementation. These layers include an Engine type and + // engine parameter for selecting the implementation. + // The default for the engine is set by the ENGINE switch at compile-time. + optional AccuracyParameter accuracy_param = 102; + optional ArgMaxParameter argmax_param = 103; + optional BatchNormParameter batch_norm_param = 139; + optional BiasParameter bias_param = 141; + optional ConcatParameter concat_param = 104; + optional ContrastiveLossParameter contrastive_loss_param = 105; + optional ConvolutionParameter convolution_param = 106; + optional CropParameter crop_param = 144; + optional DataParameter data_param = 107; + optional DetectionOutputParameter detection_output_param = 150; + optional DropoutParameter dropout_param = 108; + optional DummyDataParameter dummy_data_param = 109; + optional EltwiseParameter eltwise_param = 110; + optional ELUParameter elu_param = 140; + optional EmbedParameter embed_param = 137; + optional ExpParameter exp_param = 111; + optional FlattenParameter flatten_param = 135; + optional HDF5DataParameter hdf5_data_param = 112; + optional HDF5OutputParameter hdf5_output_param = 113; + optional HingeLossParameter hinge_loss_param = 114; + optional ImageDataParameter image_data_param = 115; + optional InfogainLossParameter infogain_loss_param = 116; + optional InnerProductParameter inner_product_param = 117; + optional InputParameter input_param = 143; + optional LogParameter log_param = 134; + optional LRNParameter lrn_param = 118; + optional MemoryDataParameter memory_data_param = 119; + optional MVNParameter mvn_param = 120; + optional ParameterParameter parameter_param = 145; + optional PoolingParameter pooling_param = 121; + optional PowerParameter power_param = 122; + optional PReLUParameter prelu_param = 131; + optional PythonParameter python_param = 130; + optional RecurrentParameter recurrent_param = 146; + optional ReductionParameter reduction_param = 136; + optional ReLUParameter relu_param = 123; + optional ReshapeParameter reshape_param = 133; + optional ScaleParameter scale_param = 142; + optional SigmoidParameter sigmoid_param = 124; + optional SmoothL1LossParameter smooth_l1_loss_param = 148; + optional SoftmaxParameter softmax_param = 125; + optional SPPParameter spp_param = 132; + optional SliceParameter slice_param = 126; + optional TanHParameter tanh_param = 127; + optional ThresholdParameter threshold_param = 128; + optional TileParameter tile_param = 138; + optional WindowDataParameter window_data_param = 129; + optional PermuteParameter permute_param = 202; + optional PriorBoxParameter prior_box_param = 203; + optional NormalizeParameter norm_param = 206; + optional PSROIPoolingParameter psroi_pooling_param = 207; + optional FreespaceExtractParameter freespace_extract_param = 151; + optional PostprocessParameter postprocess_param = 152; + optional SpatialTransformParameter spatial_transform_param = 153; + optional ROIAlignParameter roi_align_param = 154; + optional ReorgParameter reorg_param = 155; + optional RegionParameter region_param = 156; + optional ReverseParameter reverse_param = 157; + optional InterpParameter interp_param = 158; + optional ShuffleChannelParameter shuffle_channel_param = 159; + optional UpsampleParameter upsample_param = 160; + optional ROIPoolingParameter roi_pooling_param = 161; + optional YoloParameter yolo_param = 199; + optional YoloV3DetectionOutputParameter yolov3_detection_output_param = 200; + optional ProposalParameter proposal_param = 201; + optional FSRDetectionOutputParameter fsrdetectionoutput_param = 222; + optional SSDDetectionOutputParameter ssddetectionoutput_param = 232; + optional YoloV2DetectionOutputParameter yolov2_detection_output_param = 204; + optional QuantParameter quant_param = 208; + optional CondTakeParameter condtake_param = 233; + optional MatrixInverseParameter matrix_inverse_param = 210; + optional WarpPerspectiveParameter warp_perspective_param = 234; + optional BatchMatMulParameter batch_matmul_param = 235; + optional SpatialTransformerParameter st_param = 5000; + optional YoloV3DetectionOutputV2Parameter yolov3_detection_output_v2_param = 5001; +} + +// Message that stores parameters used to apply transformation +// to the data layer's data +message TransformationParameter { + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 1 [default = 1]; + // Specify if we want to randomly mirror data. + optional bool mirror = 2 [default = false]; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 3 [default = 0]; + // mean_file and mean_value cannot be specified at the same time + optional string mean_file = 4; + // if specified can be repeated once (would substract it from all the channels) + // or can be repeated the same number of times as channels + // (would subtract them from the corresponding channel) + repeated float mean_value = 5; + // Force the decoded image to have 3 color channels. + optional bool force_color = 6 [default = false]; + // Force the decoded image to have 1 color channels. + optional bool force_gray = 7 [default = false]; +} + +// Message that stores parameters shared by loss layers +message LossParameter { + // If specified, ignore instances with the given label. + optional int32 ignore_label = 1; + // How to normalize the loss for loss layers that aggregate across batches, + // spatial dimensions, or other dimensions. Currently only implemented in + // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers. + enum NormalizationMode { + // Divide by the number of examples in the batch times spatial dimensions. + // Outputs that receive the ignore label will NOT be ignored in computing + // the normalization factor. + FULL = 0; + // Divide by the total number of output locations that do not take the + // ignore_label. If ignore_label is not set, this behaves like FULL. + VALID = 1; + // Divide by the batch size. + BATCH_SIZE = 2; + // Do not normalize the loss. + NONE = 3; + } + // For historical reasons, the default normalization for + // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID. + optional NormalizationMode normalization = 3 [default = VALID]; + // Deprecated. Ignored if normalization is specified. If normalization + // is not specified, then setting this to false will be equivalent to + // normalization = BATCH_SIZE to be consistent with previous behavior. + optional bool normalize = 2; +} + +// Messages that store parameters used by individual layer types follow, in +// alphabetical order. + +message AccuracyParameter { + // When computing accuracy, count as correct by comparing the true label to + // the top k scoring classes. By default, only compare to the top scoring + // class (i.e. argmax). + optional uint32 top_k = 1 [default = 1]; + + // The "label" axis of the prediction blob, whose argmax corresponds to the + // predicted label -- may be negative to index from the end (e.g., -1 for the + // last axis). For example, if axis == 1 and the predictions are + // (N x C x H x W), the label blob is expected to contain N*H*W ground truth + // labels with integer values in {0, 1, ..., C-1}. + optional int32 axis = 2 [default = 1]; + + // If specified, ignore instances with the given label. + optional int32 ignore_label = 3; +} + +message ArgMaxParameter { + // If true produce pairs (argmax, maxval) + optional bool out_max_val = 1 [default = false]; + optional uint32 top_k = 2 [default = 1]; + // The axis along which to maximise -- may be negative to index from the + // end (e.g., -1 for the last axis). + // By default ArgMaxLayer maximizes over the flattened trailing dimensions + // for each index of the first / num dimension. + optional int32 axis = 3; +} + +message ConcatParameter { + // The axis along which to concatenate -- may be negative to index from the + // end (e.g., -1 for the last axis). Other axes must have the + // same dimension for all the bottom blobs. + // By default, ConcatLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 2 [default = 1]; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 concat_dim = 1 [default = 1]; +} + +message BatchNormParameter { + // If false, normalization is performed over the current mini-batch + // and global statistics are accumulated (but not yet used) by a moving + // average. + // If true, those accumulated mean and variance values are used for the + // normalization. + // By default, it is set to false when the network is in the training + // phase and true when the network is in the testing phase. + optional bool use_global_stats = 1; + // What fraction of the moving average remains each iteration? + // Smaller values make the moving average decay faster, giving more + // weight to the recent values. + // Each iteration updates the moving average @f$S_{t-1}@f$ with the + // current mean @f$ Y_t @f$ by + // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ + // is the moving_average_fraction parameter. + optional float moving_average_fraction = 2 [default = .999]; + // Small value to add to the variance estimate so that we don't divide by + // zero. + optional float eps = 3 [default = 1e-5]; +} + +message BiasParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar bias. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the bias + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to add a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer.) + // The initialization for the learned bias parameter. + // Default is the zero (0) initialization, resulting in the BiasLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + optional bool bias_from_blob = 4 [default = true]; +} + +message ContrastiveLossParameter { + // margin for dissimilar pair + optional float margin = 1 [default = 1.0]; + // The first implementation of this cost did not exactly match the cost of + // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. + // legacy_version = false (the default) uses (margin - d)^2 as proposed in the + // Hadsell paper. New models should probably use this version. + // legacy_version = true uses (margin - d^2). This is kept to support / + // reproduce existing models and results + optional bool legacy_version = 2 [default = false]; +} + +message ConvolutionParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in all spatial dimensions, or once per spatial dimension. + repeated uint32 pad = 3; // The padding size; defaults to 0 + repeated uint32 kernel_size = 4; // The kernel size + repeated uint32 stride = 6; // The stride; defaults to 1 + // Factor used to dilate the kernel, (implicitly) zero-filling the resulting + // holes. (Kernel dilation is sometimes referred to by its use in the + // algorithme à trous from Holschneider et al. 1987.) + repeated uint32 dilation = 18; // The dilation; defaults to 1 + + // For 2D convolution only, the *_h and *_w versions may also be used to + // specify both spatial dimensions. + optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) + optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) + optional uint32 kernel_h = 11; // The kernel height (2D only) + optional uint32 kernel_w = 12; // The kernel width (2D only) + optional uint32 stride_h = 13; // The stride height (2D only) + optional uint32 stride_w = 14; // The stride width (2D only) + + optional uint32 group = 5 [default = 1]; // The group size for group conv + + optional FillerParameter weight_filler = 7; // The filler for the weight + optional FillerParameter bias_filler = 8; // The filler for the bias + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; + + // The axis to interpret as "channels" when performing convolution. + // Preceding dimensions are treated as independent inputs; + // succeeding dimensions are treated as "spatial". + // With (N, C, H, W) inputs, and axis == 1 (the default), we perform + // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for + // groups g>1) filters across the spatial axes (H, W) of the input. + // With (N, C, D, H, W) inputs, and axis == 1, we perform + // N independent 3D convolutions, sliding (C/g)-channels + // filters across the spatial axes (D, H, W) of the input. + optional int32 axis = 16 [default = 1]; + + // Whether to force use of the general ND convolution, even if a specific + // implementation for blobs of the appropriate number of spatial dimensions + // is available. (Currently, there is only a 2D-specific convolution + // implementation; for input blobs with num_axes != 2, this option is + // ignored and the ND implementation will be used.) + optional bool force_nd_im2col = 17 [default = false]; +} + +message CropParameter { + // To crop, elements of the first bottom are selected to fit the dimensions + // of the second, reference bottom. The crop is configured by + // - the crop `axis` to pick the dimensions for cropping + // - the crop `offset` to set the shift for all/each dimension + // to align the cropped bottom with the reference bottom. + // All dimensions up to but excluding `axis` are preserved, while + // the dimensions including and trailing `axis` are cropped. + // If only one `offset` is set, then all dimensions are offset by this amount. + // Otherwise, the number of offsets must equal the number of cropped axes to + // shift the crop in each dimension accordingly. + // Note: standard dimensions are N,C,H,W so the default is a spatial crop, + // and `axis` may be negative to index from the end (e.g., -1 for the last + // axis). + optional int32 axis = 1 [default = 2]; + repeated uint32 offset = 2; +} + +message DataParameter { + enum DB { + LEVELDB = 0; + LMDB = 1; + } + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + // DEPRECATED. Each solver accesses a different subset of the database. + optional uint32 rand_skip = 7 [default = 0]; + optional DB backend = 8 [default = LEVELDB]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + // Force the encoded image to have 3 color channels + optional bool force_encoded_color = 9 [default = false]; + // Prefetch queue (Increase if data feeding bandwidth varies, within the + // limit of device memory for GPU training) + optional uint32 prefetch = 10 [default = 4]; +} + +message DropoutParameter { + optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio + optional bool scale_train = 2 [default = true]; // scale train or test phase +} + +// DummyDataLayer fills any number of arbitrarily shaped blobs with random +// (or constant) data generated by "Fillers" (see "message FillerParameter"). +message DummyDataParameter { + // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N + // shape fields, and 0, 1 or N data_fillers. + // + // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. + // If 1 data_filler is specified, it is applied to all top blobs. If N are + // specified, the ith is applied to the ith top blob. + repeated FillerParameter data_filler = 1; + repeated BlobShape shape = 6; + + // 4D dimensions -- deprecated. Use "shape" instead. + repeated uint32 num = 2; + repeated uint32 channels = 3; + repeated uint32 height = 4; + repeated uint32 width = 5; +} + +message EltwiseParameter { + enum EltwiseOp { + PROD = 0; + SUM = 1; + MAX = 2; + } + optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation + repeated float coeff = 2; // blob-wise coefficient for SUM operation + + // Whether to use an asymptotically slower (for >2 inputs) but stabler method + // of computing the gradient for the PROD operation. (No effect for SUM op.) + optional bool stable_prod_grad = 3 [default = true]; +} + +// Message that stores parameters used by ELULayer +message ELUParameter { + // Described in: + // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate + // Deep Network Learning by Exponential Linear Units (ELUs). arXiv + optional float alpha = 1 [default = 1]; +} + +// Message that stores parameters used by EmbedLayer +message EmbedParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + // The input is given as integers to be interpreted as one-hot + // vector indices with dimension num_input. Hence num_input should be + // 1 greater than the maximum possible input value. + optional uint32 input_dim = 2; + + optional bool bias_term = 3 [default = true]; // Whether to use a bias term + optional FillerParameter weight_filler = 4; // The filler for the weight + optional FillerParameter bias_filler = 5; // The filler for the bias + +} + +// Message that stores parameters used by ExpLayer +message ExpParameter { + // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = exp(shift + scale * x). + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +/// Message that stores parameters used by FlattenLayer +message FlattenParameter { + // The first axis to flatten: all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 1 [default = 1]; + + // The last axis to flatten: all following axes are retained in the output. + // May be negative to index from the end (e.g., the default -1 for the last + // axis). + optional int32 end_axis = 2 [default = -1]; +} + +// Message that stores parameters used by HDF5DataLayer +message HDF5DataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 2; + + // Specify whether to shuffle the data. + // If shuffle == true, the ordering of the HDF5 files is shuffled, + // and the ordering of data within any given HDF5 file is shuffled, + // but data between different files are not interleaved; all of a file's + // data are output (in a random order) before moving onto another file. + optional bool shuffle = 3 [default = false]; +} + +message HDF5OutputParameter { + optional string file_name = 1; +} + +message HingeLossParameter { + enum Norm { + L1 = 1; + L2 = 2; + } + // Specify the Norm to use L1 or L2 + optional Norm norm = 1 [default = L1]; +} + +message ImageDataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4 [default = 1]; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 7 [default = 0]; + // Whether or not ImageLayer should shuffle the list of files at every epoch. + optional bool shuffle = 8 [default = false]; + // It will also resize images if new_height or new_width are not zero. + optional uint32 new_height = 9 [default = 0]; + optional uint32 new_width = 10 [default = 0]; + // Specify if the images are color or gray + optional bool is_color = 11 [default = true]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + optional string root_folder = 12 [default = ""]; +} + +message InfogainLossParameter { + // Specify the infogain matrix source. + optional string source = 1; + optional int32 axis = 2 [default = 1]; // axis of prob +} + +message InnerProductParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 3; // The filler for the weight + optional FillerParameter bias_filler = 4; // The filler for the bias + + // The first axis to be lumped into a single inner product computation; + // all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 5 [default = 1]; + // Specify whether to transpose the weight matrix or not. + // If transpose == true, any operations will be performed on the transpose + // of the weight matrix. The weight matrix itself is not going to be transposed + // but rather the transfer flag of operations will be toggled accordingly. + optional bool transpose = 6 [default = false]; +} + +message InputParameter { + // This layer produces N >= 1 top blob(s) to be assigned manually. + // Define N shapes to set a shape for each top. + // Define 1 shape to set the same shape for every top. + // Define no shape to defer to reshaping manually. + repeated BlobShape shape = 1; +} + +// Message that stores parameters used by LogLayer +message LogParameter { + // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = ln(shift + scale * x) = log_e(shift + scale * x) + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by LRNLayer +message LRNParameter { + optional uint32 local_size = 1 [default = 5]; + optional float alpha = 2 [default = 1.]; + optional float beta = 3 [default = 0.75]; + enum NormRegion { + ACROSS_CHANNELS = 0; + WITHIN_CHANNEL = 1; + } + optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; + optional float k = 5 [default = 1.]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +message MemoryDataParameter { + optional uint32 batch_size = 1; + optional uint32 channels = 2; + optional uint32 height = 3; + optional uint32 width = 4; +} + +message MVNParameter { + // This parameter can be set to false to normalize mean only + optional bool normalize_variance = 1 [default = true]; + + // This parameter can be set to true to perform DNN-like MVN + optional bool across_channels = 2 [default = false]; + + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 3 [default = 1e-9]; +} + +message ParameterParameter { + optional BlobShape shape = 1; +} + +message PoolingParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 1 [default = MAX]; // The pooling method + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 2; // The kernel size (square) + optional uint32 kernel_h = 5; // The kernel height + optional uint32 kernel_w = 6; // The kernel width + optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 7; // The stride height + optional uint32 stride_w = 8; // The stride width + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 11 [default = DEFAULT]; + // If global_pooling then it will pool over the size of the bottom by doing + // kernel_h = bottom->height and kernel_w = bottom->width + optional bool global_pooling = 12 [default = false]; + optional bool ceil_mode = 13 [default = true]; + // How to calculate the output size - using ceil (default) or floor rounding. + enum RoundMode { + CEIL = 0; + FLOOR = 1; + } + optional RoundMode round_mode = 14 [default = CEIL]; +} + +message PowerParameter { + // PowerLayer computes outputs y = (shift + scale * x) ^ power. + optional float power = 1 [default = 1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +message PythonParameter { + optional string module = 1; + optional string layer = 2; + // This value is set to the attribute `param_str` of the `PythonLayer` object + // in Python before calling the `setup()` method. This could be a number, + // string, dictionary in Python dict format, JSON, etc. You may parse this + // string in `setup` method and use it in `forward` and `backward`. + optional string param_str = 3 [default = '']; + // Whether this PythonLayer is shared among worker solvers during data parallelism. + // If true, each worker solver sequentially run forward from this layer. + // This value should be set true if you are using it as a data layer. + optional bool share_in_parallel = 4 [default = false]; +} + +// Message that stores parameters used by RecurrentLayer +message RecurrentParameter { + // The dimension of the output (and usually hidden state) representation -- + // must be explicitly set to non-zero. + optional uint32 num_output = 1 [default = 0]; + + optional FillerParameter weight_filler = 2; // The filler for the weight + optional FillerParameter bias_filler = 3; // The filler for the bias + + // Whether to enable displaying debug_info in the unrolled recurrent net. + optional bool debug_info = 4 [default = false]; + + // Whether to add as additional inputs (bottoms) the initial hidden state + // blobs, and add as additional outputs (tops) the final timestep hidden state + // blobs. The number of additional bottom/top blobs required depends on the + // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. + optional bool expose_hidden = 5 [default = false]; +} + +// Message that stores parameters used by ReductionLayer +message ReductionParameter { + enum ReductionOp { + SUM = 1; + ASUM = 2; + SUMSQ = 3; + MEAN = 4; + } + + optional ReductionOp operation = 1 [default = SUM]; // reduction operation + + // The first axis to reduce to a scalar -- may be negative to index from the + // end (e.g., -1 for the last axis). + // (Currently, only reduction along ALL "tail" axes is supported; reduction + // of axis M through N, where N < num_axes - 1, is unsupported.) + // Suppose we have an n-axis bottom Blob with shape: + // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). + // If axis == m, the output Blob will have shape + // (d0, d1, d2, ..., d(m-1)), + // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) + // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. + // If axis == 0 (the default), the output Blob always has the empty shape + // (count 1), performing reduction across the entire input -- + // often useful for creating new loss functions. + optional int32 axis = 2 [default = 0]; + + optional float coeff = 3 [default = 1.0]; // coefficient for output +} + +// Message that stores parameters used by ReLULayer +message ReLUParameter { + // Allow non-zero slope for negative inputs to speed up optimization + // Described in: + // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities + // improve neural network acoustic models. In ICML Workshop on Deep Learning + // for Audio, Speech, and Language Processing. + optional float negative_slope = 1 [default = 0]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 2 [default = DEFAULT]; +} + +message ReshapeParameter { + // Specify the output dimensions. If some of the dimensions are set to 0, + // the corresponding dimension from the bottom layer is used (unchanged). + // Exactly one dimension may be set to -1, in which case its value is + // inferred from the count of the bottom blob and the remaining dimensions. + // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: + // + // layer { + // type: "Reshape" bottom: "input" top: "output" + // reshape_param { ... } + // } + // + // If "input" is 2D with shape 2 x 8, then the following reshape_param + // specifications are all equivalent, producing a 3D blob "output" with shape + // 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } + // reshape_param { shape { dim: 0 dim:-1 dim: 4 } } + // + optional BlobShape shape = 1; + + // axis and num_axes control the portion of the bottom blob's shape that are + // replaced by (included in) the reshape. By default (axis == 0 and + // num_axes == -1), the entire bottom blob shape is included in the reshape, + // and hence the shape field must specify the entire output shape. + // + // axis may be non-zero to retain some portion of the beginning of the input + // shape (and may be negative to index from the end; e.g., -1 to begin the + // reshape after the last axis, including nothing in the reshape, + // -2 to include only the last axis, etc.). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are all equivalent, + // producing a blob "output" with shape 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } + // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } + // + // num_axes specifies the extent of the reshape. + // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on + // input axes in the range [axis, axis+num_axes]. + // num_axes may also be -1, the default, to include all remaining axes + // (starting from axis). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are equivalent, + // producing a blob "output" with shape 1 x 2 x 8. + // + // reshape_param { shape { dim: 1 dim: 2 dim: 8 } } + // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 } + // reshape_param { shape { dim: 1 } num_axes: 0 } + // + // On the other hand, these would produce output blob shape 2 x 1 x 8: + // + // reshape_param { shape { dim: 2 dim: 1 dim: 8 } } + // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 } + // + optional int32 axis = 2 [default = 0]; + optional int32 num_axes = 3 [default = -1]; +} + + +message ScaleParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar multiplier. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the scale + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer.) + // The initialization for the learned scale parameter. + // Default is the unit (1) initialization, resulting in the ScaleLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + + // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but + // may be more efficient). Initialized with bias_filler (defaults to 0). + optional bool bias_term = 4 [default = false]; + optional FillerParameter bias_filler = 5; + optional bool scale_from_blob = 6 [default = true]; +} + +message SigmoidParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +message SliceParameter { + // The axis along which to slice -- may be negative to index from the end + // (e.g., -1 for the last axis). + // By default, SliceLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 3 [default = 1]; + repeated uint32 slice_point = 2; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 slice_dim = 1 [default = 1]; +} + +message SmoothL1LossParameter { + // SmoothL1Loss(x) = + // 0.5 * (sigma * x) ** 2 -- if x < 1.0 / sigma / sigma + // |x| - 0.5 / sigma / sigma -- otherwise + optional float sigma = 1 [default = 1]; +} + +// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer +message SoftmaxParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; + + // The axis along which to perform the softmax -- may be negative to index + // from the end (e.g., -1 for the last axis). + // Any other axes will be evaluated as independent softmaxes. + optional int32 axis = 2 [default = 1]; +} + +message TanHParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by TileLayer +message TileParameter { + // The index of the axis to tile. + optional int32 axis = 1 [default = 1]; + + // The number of copies (tiles) of the blob to output. + optional int32 tiles = 2; +} + +// Message that stores parameters used by ThresholdLayer +message ThresholdParameter { + optional float threshold = 1 [default = 0]; // Strictly positive values +} + +message WindowDataParameter { + // Specify the data source. + optional string source = 1; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // Specify the batch size. + optional uint32 batch_size = 4; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 5 [default = 0]; + // Specify if we want to randomly mirror data. + optional bool mirror = 6 [default = false]; + // Foreground (object) overlap threshold + optional float fg_threshold = 7 [default = 0.5]; + // Background (non-object) overlap threshold + optional float bg_threshold = 8 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float fg_fraction = 9 [default = 0.25]; + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 context_pad = 10 [default = 0]; + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string crop_mode = 11 [default = "warp"]; + // cache_images: will load all images in memory for faster access + optional bool cache_images = 12 [default = false]; + // append root_folder to locate images + optional string root_folder = 13 [default = ""]; +} + +message SPPParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional uint32 pyramid_height = 1; + optional PoolMethod pool = 2 [default = MAX]; // The pooling method + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +// DEPRECATED: use LayerParameter. +message V1LayerParameter { + repeated string bottom = 2; + repeated string top = 3; + optional string name = 4; + repeated NetStateRule include = 32; + repeated NetStateRule exclude = 33; + enum LayerType { + NONE = 0; + ABSVAL = 35; + ACCURACY = 1; + ARGMAX = 30; + BNLL = 2; + CONCAT = 3; + CONTRASTIVE_LOSS = 37; + CONVOLUTION = 4; + DATA = 5; + DECONVOLUTION = 39; + DROPOUT = 6; + DUMMY_DATA = 32; + EUCLIDEAN_LOSS = 7; + ELTWISE = 25; + EXP = 38; + FLATTEN = 8; + HDF5_DATA = 9; + HDF5_OUTPUT = 10; + HINGE_LOSS = 28; + IM2COL = 11; + IMAGE_DATA = 12; + INFOGAIN_LOSS = 13; + INNER_PRODUCT = 14; + LRN = 15; + MEMORY_DATA = 29; + MULTINOMIAL_LOGISTIC_LOSS = 16; + MVN = 34; + POOLING = 17; + POWER = 26; + RELU = 18; + SIGMOID = 19; + SIGMOID_CROSS_ENTROPY_LOSS = 27; + SILENCE = 36; + SOFTMAX = 20; + SOFTMAX_LOSS = 21; + SPLIT = 22; + SLICE = 33; + TANH = 23; + WINDOW_DATA = 24; + THRESHOLD = 31; + QUANT = 208; + DEQUANT = 209; + } + optional LayerType type = 5; + repeated BlobProto blobs = 6; + repeated string param = 1001; + repeated DimCheckMode blob_share_mode = 1002; + enum DimCheckMode { + STRICT = 0; + PERMISSIVE = 1; + } + repeated float blobs_lr = 7; + repeated float weight_decay = 8; + repeated float loss_weight = 35; + optional AccuracyParameter accuracy_param = 27; + optional ArgMaxParameter argmax_param = 23; + optional ConcatParameter concat_param = 9; + optional ContrastiveLossParameter contrastive_loss_param = 40; + optional ConvolutionParameter convolution_param = 10; + optional DataParameter data_param = 11; + optional DropoutParameter dropout_param = 12; + optional DummyDataParameter dummy_data_param = 26; + optional EltwiseParameter eltwise_param = 24; + optional ExpParameter exp_param = 41; + optional HDF5DataParameter hdf5_data_param = 13; + optional HDF5OutputParameter hdf5_output_param = 14; + optional HingeLossParameter hinge_loss_param = 29; + optional ImageDataParameter image_data_param = 15; + optional InfogainLossParameter infogain_loss_param = 16; + optional InnerProductParameter inner_product_param = 17; + optional LRNParameter lrn_param = 18; + optional MemoryDataParameter memory_data_param = 22; + optional MVNParameter mvn_param = 34; + optional PoolingParameter pooling_param = 19; + optional PowerParameter power_param = 21; + optional ReLUParameter relu_param = 30; + optional SigmoidParameter sigmoid_param = 38; + optional SoftmaxParameter softmax_param = 39; + optional SliceParameter slice_param = 31; + optional TanHParameter tanh_param = 37; + optional ThresholdParameter threshold_param = 25; + optional WindowDataParameter window_data_param = 20; + optional TransformationParameter transform_param = 36; + optional LossParameter loss_param = 42; + optional V0LayerParameter layer = 1; +} + +// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters +// in Caffe. We keep this message type around for legacy support. +message V0LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the string to specify the layer type + + // Parameters to specify layers with inner products. + optional uint32 num_output = 3; // The number of outputs for the layer + optional bool biasterm = 4 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 5; // The filler for the weight + optional FillerParameter bias_filler = 6; // The filler for the bias + + optional uint32 pad = 7 [default = 0]; // The padding size + optional uint32 kernelsize = 8; // The kernel size + optional uint32 group = 9 [default = 1]; // The group size for group conv + optional uint32 stride = 10 [default = 1]; // The stride + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 11 [default = MAX]; // The pooling method + optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio + + optional uint32 local_size = 13 [default = 5]; // for local response norm + optional float alpha = 14 [default = 1.]; // for local response norm + optional float beta = 15 [default = 0.75]; // for local response norm + optional float k = 22 [default = 1.]; + + // For data layers, specify the data source + optional string source = 16; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 17 [default = 1]; + optional string meanfile = 18; + // For data layers, specify the batch size. + optional uint32 batchsize = 19; + // For data layers, specify if we would like to randomly crop an image. + optional uint32 cropsize = 20 [default = 0]; + // For data layers, specify if we want to randomly mirror data. + optional bool mirror = 21 [default = false]; + + // The blobs containing the numeric parameters of the layer + repeated BlobProto blobs = 50; + // The ratio that is multiplied on the global learning rate. If you want to + // set the learning ratio for one blob, you need to set it for all blobs. + repeated float blobs_lr = 51; + // The weight decay that is multiplied on the global weight decay. + repeated float weight_decay = 52; + + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 53 [default = 0]; + + // Fields related to detection (det_*) + // foreground (object) overlap threshold + optional float det_fg_threshold = 54 [default = 0.5]; + // background (non-object) overlap threshold + optional float det_bg_threshold = 55 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float det_fg_fraction = 56 [default = 0.25]; + + // optional bool OBSOLETE_can_clobber = 57 [default = true]; + + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 det_context_pad = 58 [default = 0]; + + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string det_crop_mode = 59 [default = "warp"]; + + // For ReshapeLayer, one needs to specify the new dimensions. + optional int32 new_num = 60 [default = 0]; + optional int32 new_channels = 61 [default = 0]; + optional int32 new_height = 62 [default = 0]; + optional int32 new_width = 63 [default = 0]; + + // Whether or not ImageLayer should shuffle the list of files at every epoch. + // It will also resize images if new_height or new_width are not zero. + optional bool shuffle_images = 64 [default = false]; + + // For ConcatLayer, one needs to specify the dimension for concatenation, and + // the other dimensions must be the same for all the bottom blobs. + // By default it will concatenate blobs along the channels dimension. + optional uint32 concat_dim = 65 [default = 1]; + + optional HDF5OutputParameter hdf5_output_param = 1001; +} + +message PReLUParameter { + // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers: + // Surpassing Human-Level Performance on ImageNet Classification, 2015. + + // Initial value of a_i. Default is a_i=0.25 for all i. + optional FillerParameter filler = 1; + // Whether or not slope parameters are shared across channels. + optional bool channel_shared = 2 [default = false]; +} + +// Message that stores parameters used by DetectionOutputLayer +//message DetectionOutputParameter { +// optional int32 num_classes = 1 [default = 21]; +// optional float nms_threshold = 2 [default = 0.3]; +// optional int32 top_k = 3; +// optional float confidence_threshold = 4 [default = 0.8]; +//} + +// Message that store parameters used by PriorBoxLayer +message PriorBoxParameter { + // Encode/decode type. + enum CodeType { + CORNER = 1; + CENTER_SIZE = 2; + CORNER_SIZE = 3; + } + // Minimum box size (in pixels). Required! + repeated float min_size = 1; + // Maximum box size (in pixels). Required! + repeated float max_size = 2; + // Various of aspect ratios. Duplicate ratios will be ignored. + // If none is provided, we use default ratio 1. + repeated float aspect_ratio = 3; + // If true, will flip each aspect ratio. + // For example, if there is aspect ratio "r", + // we will generate aspect ratio "1.0/r" as well. + optional bool flip = 4 [default = true]; + // If true, will clip the prior so that it is within [0, 1] + optional bool clip = 5 [default = false]; + // Variance for adjusting the prior bboxes. + repeated float variance = 6; + // By default, we calculate img_height, img_width, step_x, step_y based on + // bottom[0] (feat) and bottom[1] (img). Unless these values are explicitely + // provided. + // Explicitly provide the img_size. + optional uint32 img_size = 7; + // Either img_size or img_h/img_w should be specified; not both. + optional uint32 img_h = 8; + optional uint32 img_w = 9; + + // Explicitly provide the step size. + optional float step = 10; + // Either step or step_h/step_w should be specified; not both. + optional float step_h = 11; + optional float step_w = 12; + + // Offset to the top left corner of each cell. + optional float offset = 13 [default = 0.5]; +} + +// Message that stores parameters used by PermutetLayer +message PermuteParameter { + // The new orders of the axes of data. Notice it should be with + // in the same range as the input data, and it starts from 0. + // Do not provide repeated order. + repeated uint32 order = 1; +} + +message NormalizeParameter { + optional bool across_spatial = 1 [default = true]; + // Initial value of scale. Default is 1.0 for all + optional FillerParameter scale_filler = 2; + // Whether or not scale parameters are shared across channels. + optional bool channel_shared = 3 [default = true]; + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 4 [default = 1e-10]; +} + +// needed by ssd +message SaveOutputParameter { + // Output directory. If not empty, we will save the results. + optional string output_directory = 1; + // Output name prefix. + optional string output_name_prefix = 2; + // Output format. + // VOC - PASCAL VOC output format. + // COCO - MS COCO output format. + optional string output_format = 3; + // If you want to output results, must also provide the following two files. + // Otherwise, we will ignore saving results. + // label map file. + optional string label_map_file = 4; + // A file which contains a list of names and sizes with same order + // of the input DB. The file is in the following format: + // name height width + // ... + optional string name_size_file = 5; + // Number of test images. It can be less than the lines specified in + // name_size_file. For example, when we only want to evaluate on part + // of the test images. + optional uint32 num_test_image = 6; + // The resize parameter used in saving the data. + // optional ResizeParameter resize_param = 7; +} + +message NonMaximumSuppressionParameter { + // Threshold to be used in nms. + optional float nms_threshold = 1 [default = 0.3]; + // Maximum number of results to be kept. + optional int32 top_k = 2; + // Parameter for adaptive nms. + optional float eta = 3 [default = 1.0]; +} + +message GeneralNmsParameter { + optional int32 post_top_k = 1 ; + optional float nms_threshold = 2 [default = 0]; + optional float iou_threshold_decay = 3 [default = 1.0]; + optional float coor_scale_factor = 4 [default = 1.0]; +} + +// Message that store parameters used by DetectionOutputLayer, ssd/fasterRcnn +message DetectionOutputParameter { + optional int32 num_classes = 1; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional NonMaximumSuppressionParameter nms_param = 4; + optional SaveOutputParameter save_output_param = 5; + optional PriorBoxParameter.CodeType code_type = 6 [default = CENTER_SIZE]; + optional bool variance_encoded_in_target = 8 [default = true]; + optional int32 keep_top_k = 7; + optional float confidence_threshold = 9; + optional float nms_threshold = 13; + optional int32 top_k = 14; + optional int32 boxes = 15 [default = 1]; + optional bool relative = 17 [default = true]; + optional float objectness_threshold = 18 [default = 0.5]; + optional float class_threshold = 19 [default = 0.5]; + repeated float biases = 20; + optional GeneralNmsParameter general_nms_param = 21; + optional float objectness_score = 22; +} +message PSROIPoolingParameter { + required float spatial_scale = 1; + required int32 output_dim = 2; // output channel number + required int32 group_size = 3; // number of groups to encode position-sensitive score maps +} +// Message that stores parameters used by FreespaceExtractLayer +message FreespaceExtractParameter { + optional float org_height = 1; +} + +// Message that stores parameters used by DetectpostprocessLayer +message PostprocessParameter { + optional float nms_thresh = 1 [default = 0.3]; + optional float conf_thresh = 2 [default = 0.5]; + optional uint32 post_nms_topn = 3 [default = 100]; + optional uint32 cls_num = 4 [default = 12]; + repeated float bbox_reg_weights = 5; +} + +// Message that stores parameters used by SpatialTransformLayer +message SpatialTransformParameter { + optional uint32 output_h = 1 [default = 0]; + optional uint32 output_w = 2 [default = 0]; + optional float border_value = 3 [default = 0]; + repeated float affine_transform = 4; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; +} +message ROIAlignParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; + optional int32 sampling_ratio = 4 [default = -1]; + optional int32 roi_end_mode = 5 [default = 0]; +} + +message RegionParameter { + optional uint32 classes = 1 [default = 20]; // Category of classification + optional uint32 coords = 2 [default = 4]; // Coordinates of box + optional uint32 boxes = 3 [default = 1]; // Number of boxes predicted per grid + optional uint32 softmax = 4 [default = 0]; + optional string softmax_tree = 5 [default = ""]; + optional uint32 background = 6 [default = 0]; +} +message ReorgParameter{ + optional uint32 stride = 2 [default = 2]; + optional bool reverse = 1 [default = false]; +} +message ReverseParameter{ + repeated int32 axis = 1; +} +message InterpParameter{ + optional int32 height = 1 [default = 0];//Height of output + optional int32 width = 2 [default = 0];//Width of output + optional int32 zoom_factor = 3 [default = 1];//zoom factor + optional int32 shrink_factor = 4 [default = 1];//shrink factor + optional int32 pad_beg = 5 [default = 0];//padding at begin of input + optional int32 pad_end = 6 [default = 0];//padding at end of input +} +message ShuffleChannelParameter{ + optional uint32 group = 1[default = 1]; // The number of group +} +message UpsampleParameter{ + optional float scale = 1[default = 1]; + optional int32 stride = 2[default = 2]; + optional int32 stride_h = 3[default = 2]; + optional int32 stride_w = 4[default=2]; +} +message ROIPoolingParameter { + required int32 pooled_h = 1; + required int32 pooled_w = 2; + optional float spatial_scale = 3 [default=0.0625]; + optional float spatial_scale_h = 4; + optional float spatial_scale_w = 5; +} + +message YoloParameter { + optional int32 boxes = 1 [default = 3]; + optional int32 coords = 2 [default = 4]; + optional int32 classes = 3 [default = 80]; + optional string yolo_version = 4 [default = "V3"]; + optional bool softmax = 5 [default = false]; + optional bool background = 6 [default = false]; + optional bool softmaxtree = 7 [default = false]; +} + +message YoloV3DetectionOutputParameter { + optional int32 boxes = 1 [default = 3]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases_high = 9; + repeated float biases_mid = 10; + repeated float biases_low = 11; + optional int32 coords = 12 [default = 4]; + repeated float biases = 13; + optional bool resize_origin_img_to_net = 14 [default = false]; +} + +message YoloV3DetectionOutputV2Parameter { + optional int32 boxes = 1 [default = 3]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases_high = 9; + repeated float biases_mid = 10; + repeated float biases_low = 11; + optional int32 coords = 12 [default = 4]; + repeated float biases = 13; + optional bool resize_origin_img_to_net = 14 [default = false]; + optional int32 out_box_dim = 15 [default = 3]; +} + +message ProposalParameter { + optional float feat_stride = 1 [default = 16]; + optional float base_size = 2 [default = 16]; + optional float min_size = 3 [default = 16]; + repeated float ratio = 4; + repeated float scale = 5; + optional int32 pre_nms_topn = 6 [default = 3000]; + optional int32 post_nms_topn = 7 [default = 304]; + optional float iou_threshold = 8 [default = 0.7]; + optional bool output_actual_rois_num = 9 [default = false]; +} + +message FSRDetectionOutputParameter { + required int32 num_classes = 1; + required float score_threshold = 2; + required float iou_threshold = 3; + optional int32 batch_rois = 4 [default = 1]; +} + +message SSDDetectionOutputParameter { + required int32 num_classes= 1 [default = 2]; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional float iou_threshold = 4 [default = 0.3]; + optional int32 top_k = 5 [default = 200]; + optional float eta = 6 [default = 1.0]; + optional bool variance_encoded_in_target = 7 [default = false]; + optional int32 code_type = 8 [default = 1]; + optional int32 keep_top_k = 9 [default = -1]; + optional float confidence_threshold = 10 [default = 0.0]; +} +message YoloV2DetectionOutputParameter { + optional int32 boxes = 1 [default = 5]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases = 9; + optional int32 coords = 10 [default = 4]; + optional bool resize_origin_img_to_net = 11 [default = false]; +} + +message QuantParameter { + optional float scale = 2; + optional bytes offset = 3; +} + +message BatchMatMulParameter{ + optional bool adj_x1 = 1 [default = false]; + optional bool adj_x2 = 2 [default = false]; +} + +message CondTakeParameter { + required string mode = 1; + required float val = 2; + optional float eps = 3 [default = 1e-06]; +} + +message MatrixInverseParameter { + optional bool adjoint = 1 [default = false]; +} + +message WarpPerspectiveParameter { + required int32 out_height = 1; + required int32 out_width = 2; + optional float constant = 3; + optional string border_type = 4 [default = 'BORDER_CONSTANT']; +} + +message SpatialTransformerParameter { + // How to use the parameter passed by localisation network + optional string transform_type = 1 [default = "affine"]; + // What is the sampling technique + optional string sampler_type = 2 [default = "bilinear"]; + + // If not set,stay same with the input dimension H and W + optional int32 output_H = 3; + optional int32 output_W = 4; + // If false, only compute dTheta, DO NOT compute dU + optional bool to_compute_dU = 5 [default = true]; + + // The default value for some parameters + optional double theta_1_1 = 6; + optional double theta_1_2 = 7; + optional double theta_1_3 = 8; + optional double theta_2_1 = 9; + optional double theta_2_2 = 10; + optional double theta_2_3 = 11; +} diff --git a/metadef/proto/dump_task.proto b/metadef/proto/dump_task.proto new file mode 100644 index 00000000..b1e346cd --- /dev/null +++ b/metadef/proto/dump_task.proto @@ -0,0 +1,111 @@ +syntax = "proto3"; +package toolkit.dumpdata; + +enum OutputDataType { + DT_UNDEFINED = 0; + DT_FLOAT = 1; + DT_FLOAT16 = 2; + DT_INT8 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_UINT16 = 6; + DT_INT32 = 7; + DT_INT64 = 8; + DT_UINT32 = 9; + DT_UINT64 = 10; + DT_BOOL = 11; + DT_DOUBLE = 12; + DT_STRING = 13; + DT_DUAL_SUB_INT8 = 14; + DT_DUAL_SUB_UINT8 = 15; + DT_COMPLEX64 = 16; + DT_COMPLEX128 = 17; + DT_QINT8 = 18; + DT_QINT16 = 19; + DT_QINT32 = 20; + DT_QUINT8 = 21; + DT_QUINT16 = 22; + DT_RESOURCE = 23; + DT_STRING_REF = 24; + DT_DUAL = 25; +} + +enum OutputFormat { + FORMAT_NCHW = 0; + FORMAT_NHWC = 1; + FORMAT_ND = 2; + FORMAT_NC1HWC0 = 3; + FORMAT_FRACTAL_Z = 4; + FORMAT_NC1C0HWPAD = 5; + FORMAT_NHWC1C0 = 6; + FORMAT_FSR_NCHW = 7; + FORMAT_FRACTAL_DECONV = 8; + FORMAT_C1HWNC0 = 9; + FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; + FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; + FORMAT_NC1HWC0_C04 = 12; + FORMAT_FRACTAL_Z_C04 = 13; + FORMAT_CHWN = 14; + FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; + FORMAT_HWCN = 16; + FORMAT_NC1KHKWHWC0 = 17; + FORMAT_BN_WEIGHT = 18; + FORMAT_FILTER_HWCK = 19; + FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; + FORMAT_HASHTABLE_LOOKUP_KEYS = 21; + FORMAT_HASHTABLE_LOOKUP_VALUE = 22; + FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; + FORMAT_HASHTABLE_LOOKUP_HITS=24; + FORMAT_C1HWNCoC0 = 25; + FORMAT_MD = 26; + FORMAT_NDHWC = 27; + FORMAT_FRACTAL_ZZ = 28; + FORMAT_FRACTAL_NZ = 29; + FORMAT_RESERVED = 30; +} + +message OriginalOp { + string name = 1; + uint32 output_index = 2; + OutputDataType data_type = 3; + OutputFormat format = 4; +} + +message Shape { + repeated uint64 dim = 1; +} + +message OpOutput { + OutputDataType data_type = 1; + OutputFormat format = 2; + Shape shape = 3; + OriginalOp original_op = 4; // the original op corresponding to the output + bytes data = 5; + uint64 size = 6; +} + +message OpInput { + OutputDataType data_type = 1; + OutputFormat format = 2; + Shape shape = 3; + bytes data = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + bytes data = 2; + uint64 size = 3; +} + +message DumpData{ + string version = 1; + uint64 dump_time = 2; + repeated OpOutput output = 3; + repeated OpInput input = 4; + repeated OpBuffer buffer = 5; +} diff --git a/metadef/proto/fusion_model.proto b/metadef/proto/fusion_model.proto new file mode 100644 index 00000000..c92c5581 --- /dev/null +++ b/metadef/proto/fusion_model.proto @@ -0,0 +1,21 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +import "om.proto"; + +package domi; + +message FusionModelDef { + string version = 1; + repeated OpDef fusion_op = 2; +} \ No newline at end of file diff --git a/metadef/proto/fwk_adapter.proto b/metadef/proto/fwk_adapter.proto new file mode 100644 index 00000000..9335c926 --- /dev/null +++ b/metadef/proto/fwk_adapter.proto @@ -0,0 +1,37 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package aicpu.FWKAdapter; +option cc_enable_arenas = true; + + +// Defines an struct for input and output. +message TensorDataInfo { + + // value DataType + uint32 dtype = 1; + + // shape dim + repeated int64 dim = 2; + + // data point addr + int64 data_addr = 3; +} + +message KernelRunParam { + // input + repeated TensorDataInfo input = 1; + // output + repeated TensorDataInfo output = 2; +} + diff --git a/metadef/proto/ge_api.proto b/metadef/proto/ge_api.proto new file mode 100644 index 00000000..331c5aea --- /dev/null +++ b/metadef/proto/ge_api.proto @@ -0,0 +1,88 @@ +syntax = "proto3"; +package ge.api_pb; + +import "ge_ir.proto"; + +// GE initialize +message GEInitialize { + map options = 1; +}; + +// initialize response +message GEInitializeResponse { + uint32 status = 1; + uint32 clientId = 2; +}; + +// GE finalize +message GEFinalize { + bool final = 1; + uint32 clientId = 2; +}; + +message GEFinalizeResponse { + uint32 status = 1; +}; + +// GE Session +message CreateSession{ + map options = 1; +}; + +message CreateSessionResponse { + uint32 status = 1; + uint64 sessionId = 2; +}; + +//GE AddGraph +//model serialize :: serializegraph +message SessionAddGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; + ge.proto.GraphDef graph = 3; +}; + +message SessionAddGraphResponse { + uint32 status = 1; +}; + +//GE SessionRemoveGraph +message SessionRemoveGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; +}; + +message SessionRemoveGraphResponse { + uint32 status = 1; +}; + +message SessionRunGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; + repeated ge.proto.TensorDef tensor = 3; +}; + +message SessionBuildGraph{ + uint32 graphId = 1; + uint64 sessionId = 2; + repeated ge.proto.TensorDef tensor = 3; + string savePath = 4; +}; + +message SessionRunGraphResponse { + uint32 status = 1; + repeated ge.proto.TensorDef tensor = 2; +}; + +message SessionBuildGraphResponse { + uint32 status = 1; +}; + +message DestroySession{ + bool final = 1; + uint64 sessionId = 2; +}; + +message DestroySessionResponse { + uint32 status = 1; +}; diff --git a/metadef/proto/ge_ir.proto b/metadef/proto/ge_ir.proto new file mode 100644 index 00000000..e7bfe0cb --- /dev/null +++ b/metadef/proto/ge_ir.proto @@ -0,0 +1,190 @@ +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. + DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs +{ + string name = 1; + map attr = 2; +} + +// Shape / dimension description, using row-major order +message ShapeDef +{ + repeated int64 dim = 1; // Size of each dimension +} + +// Multidimensional data description +message TensorDescriptor +{ + string name = 1; // Optional parameter, tensor name + + DataType dtype = 2; // tensor datatype + ShapeDef shape = 3; // Shape / dimension + string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" + + bool has_out_attr = 9; + int64 size = 10; + int64 weight_size = 11; + bool reuse_input = 12; + bool output_tensor = 13; + string device_type = 14; + bool input_tensor =15; + int64 real_dim_cnt = 16; + int64 reuse_input_index = 17; + int64 data_offset = 18; + int64 cmps_size = 19; + string cmps_tab = 20; + int64 cmps_tab_offset = 21; + + map attr = 5; // Set of extra parameter fields +} + +// GeTensor definition +message TensorDef +{ + TensorDescriptor desc = 1; // Tensor description + bytes data = 2; // Tensor data +} + + +// Operator description +message OpDef +{ + string name = 1; // name + string type = 2; // type + + repeated string input = 5; // input original op name + outgoing index. op_name:index + + map attr = 10; // Set of operator parameter fields + + bool has_out_attr = 20; + int64 id = 21; + int64 stream_id =22; + repeated string input_name = 23; + repeated string src_name = 24; + repeated int64 src_index = 25; + repeated string dst_name = 26; + repeated int64 dst_index = 27; + repeated int64 input_i = 28; + repeated int64 output_i = 29; + repeated int64 workspace = 30; + repeated int64 workspace_bytes = 31; + repeated bool is_input_const = 32; + repeated TensorDescriptor input_desc = 33; + repeated TensorDescriptor output_desc = 34; + repeated string subgraph_name = 35; +} + +// Graph definition +message GraphDef +{ + string name = 1; // name + + repeated string input = 4; // Graph input + repeated string output = 5; // Graph output + + repeated OpDef op = 6; // List of operators + + map attr = 11; // Extended field +} + +// model definition +message ModelDef +{ + string name = 1; // name + uint32 version = 2; // IR Proto verion + string custom_version = 3; // User model version number, passed in by user + + repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef + + map attr = 11; // Extended field +} + diff --git a/metadef/proto/insert_op.proto b/metadef/proto/insert_op.proto new file mode 100644 index 00000000..bf918b20 --- /dev/null +++ b/metadef/proto/insert_op.proto @@ -0,0 +1,139 @@ +syntax = "proto3"; + +package domi; + +message InsertNewOps { + repeated AippOpParams aipp_op = 1; + repeated MultiShapeOpParams multi_shape_op = 2; +} + +message AippOpParams { + enum InputFormat { + UNDEFINED = 0; + YUV420SP_U8 = 1; + XRGB8888_U8 = 2; + RGB888_U8 = 3; + YUV400_U8 = 4; + NC1HWC0DI_FP16 = 5; + NC1HWC0DI_S8 = 6; + ARGB8888_U8 = 7; + YUYV_U8 = 8; + YUV422SP_U8 = 9; + AYUV444_U8 = 10; + RAW10 = 11; + RAW12 = 12; + RAW16 = 13; + RAW24 = 14; + RGB16 = 15; + RGB20 = 16; + RGB24 = 17; + RGB8_IR = 18; + RGB16_IR = 19; + RGB24_IR = 20; + } + + enum AippMode { + undefined = 0; + static = 1; + dynamic = 2; + } + + // AIPPģʽ־̬AIPPͶ̬AIPP + AippMode aipp_mode = 1; + + // related_input_rankΪΪͣ÷Χ>=0, <=DataӵĸĬֵΪ0 + // ʶģ͵ĵڼAIPPģ룬ҪԵ2AIPPrelated_input_rankΪ1 + uint32 related_input_rank = 2; + + // related_input_name is optional and the top name of data node which inserts aipp + string related_input_name = 6; + + // input_edge_idxΪѡΪͣ÷ΧΪ>=0 + // øòãڶDataӲͬͬAIPPòûãĬ϶related_input_rankָģAIPP + // ֵ <= Dataߵĸ + repeated uint32 input_edge_idx = 3; + + // [Begin] ̬AIPPþ̬AIPPʱЧ + uint32 max_src_image_size = 4; + + // Ƿ֧תĬϲ֧֣֧תʱжĿռʧ + bool support_rotation = 5; + + // [End] ̬AIPP + + + // [Begin] ̬AIPPö̬AIPPʱЧ + InputFormat input_format = 51; + bool csc_switch = 52; + float cpadding_value = 53; + bool rbuv_swap_switch = 54; + bool ax_swap_switch = 55; + bool single_line_mode = 56; + + int32 src_image_size_w = 57; + int32 src_image_size_h = 58; + + bool crop = 59; + int32 load_start_pos_w = 60; + int32 load_start_pos_h = 61; + int32 crop_size_w = 62; + int32 crop_size_h = 63; + + bool resize = 64; + int32 resize_output_w = 65; + int32 resize_output_h = 66; + + bool padding = 67; + int32 left_padding_size = 68; + int32 right_padding_size = 69; + int32 top_padding_size = 70; + int32 bottom_padding_size = 71; + + int32 mean_chn_0 = 10; + int32 mean_chn_1 = 11; + int32 mean_chn_2 = 12; + int32 mean_chn_3 = 19; + float min_chn_0 = 13; + float min_chn_1 = 14; + float min_chn_2 = 15; + float min_chn_3 = 20; + repeated float var_reci_chn_0 = 16; + repeated float var_reci_chn_1 = 17; + repeated float var_reci_chn_2 = 18; + repeated float var_reci_chn_3 = 21; + + repeated int32 matrix_r0c0 = 30; + repeated int32 matrix_r0c1 = 31; + repeated int32 matrix_r0c2 = 32; + repeated int32 matrix_r1c0 = 33; + repeated int32 matrix_r1c1 = 34; + repeated int32 matrix_r1c2 = 35; + repeated int32 matrix_r2c0 = 36; + repeated int32 matrix_r2c1 = 37; + repeated int32 matrix_r2c2 = 38; + repeated int32 output_bias_0 = 39; + repeated int32 output_bias_1 = 40; + repeated int32 output_bias_2 = 41; + repeated int32 input_bias_0 = 42; + repeated int32 input_bias_1 = 43; + repeated int32 input_bias_2 = 44; + + // [End] ̬AIPP + + // The n number that is used for raw/rgbir data into f16 transformation. + // The transformation equation is x/(2^n). If set to 0, no transform is performed. + uint32 raw_rgbir_to_f16_n = 45; +} + +message MultiShapeOpParams { + enum MultiShapeMode { + batch = 0; //̬batch + resolution = 1; //ֱ̬ʣչ + } + + MultiShapeMode mode = 1; //ģʽ + uint32 related_input_rank = 2; //Ӳ뵽ĸ + + + repeated uint32 batch_list = 11; //batch_listֵbatch_listĸ28֮ +} diff --git a/metadef/proto/om.proto b/metadef/proto/om.proto new file mode 100644 index 00000000..e15e5f80 --- /dev/null +++ b/metadef/proto/om.proto @@ -0,0 +1,396 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. + CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/metadef/proto/onnx/ge_onnx.proto b/metadef/proto/onnx/ge_onnx.proto new file mode 100644 index 00000000..4cd77f3a --- /dev/null +++ b/metadef/proto/onnx/ge_onnx.proto @@ -0,0 +1,563 @@ +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto3"; + +package ge.onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION = 0x0000000000000006; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value= 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} diff --git a/metadef/proto/op_mapping_info.proto b/metadef/proto/op_mapping_info.proto new file mode 100644 index 00000000..e23b7ebe --- /dev/null +++ b/metadef/proto/op_mapping_info.proto @@ -0,0 +1,73 @@ +syntax = "proto3"; +package aicpu.dump; + +message Shape { + repeated uint64 dim = 1; +} + +message Output { + int32 data_type = 1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + string original_name = 5; + int32 original_output_index = 6; + int32 original_output_data_type = 7; + int32 original_output_format = 8; + uint64 size = 9; +} + +message Input { + int32 data_type =1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + uint64 size = 5; +} + +enum BufferType { + L1 = 0; +} + +message OpBuffer { + BufferType buffer_type = 1; + uint64 address = 2; + uint64 size = 3; +} + +message Op { + string op_name = 1; + string op_type = 2; +} + +message Task { + uint32 task_id = 1; + uint32 stream_id = 2; + Op op = 3; + repeated Output output = 4; + bool end_graph = 5; + repeated Input input = 6; + repeated OpBuffer buffer = 7; +} + +message OpMappingInfo { + string dump_path = 1; + oneof model_name_param { + string model_name = 2; + } + oneof model_id_param { + uint32 model_id = 3; + } + oneof step_id { + uint64 step_id_addr = 4; + } + oneof iterations_per_loop { + uint64 iterations_per_loop_addr = 5; + } + oneof loop_cond { + uint64 loop_cond_addr = 6; + } + uint32 flag = 7; // 0x01 load, 0x00 unload + repeated Task task = 8; + string dump_step = 9; +} \ No newline at end of file diff --git a/metadef/proto/optimizer_priority.proto b/metadef/proto/optimizer_priority.proto new file mode 100644 index 00000000..769619cf --- /dev/null +++ b/metadef/proto/optimizer_priority.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; +package ge.optimizers; + +// Default: GE>FE>AICPU +message Priority{ + repeated string optimizer = 1; +} \ No newline at end of file diff --git a/metadef/proto/proto_inner/ge_onnx.proto b/metadef/proto/proto_inner/ge_onnx.proto new file mode 100644 index 00000000..4cd77f3a --- /dev/null +++ b/metadef/proto/proto_inner/ge_onnx.proto @@ -0,0 +1,563 @@ +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto3"; + +package ge.onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION = 0x0000000000000006; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value= 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} diff --git a/metadef/proto/task.proto b/metadef/proto/task.proto new file mode 100644 index 00000000..d0c09840 --- /dev/null +++ b/metadef/proto/task.proto @@ -0,0 +1,165 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/metadef/proto/tensorflow/attr_value.proto b/metadef/proto/tensorflow/attr_value.proto new file mode 100644 index 00000000..1cc67d62 --- /dev/null +++ b/metadef/proto/tensorflow/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/metadef/proto/tensorflow/function.proto b/metadef/proto/tensorflow/function.proto new file mode 100644 index 00000000..075897c6 --- /dev/null +++ b/metadef/proto/tensorflow/function.proto @@ -0,0 +1,100 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. + reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. +} diff --git a/metadef/proto/tensorflow/graph.proto b/metadef/proto/tensorflow/graph.proto new file mode 100644 index 00000000..d639a7d6 --- /dev/null +++ b/metadef/proto/tensorflow/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "node_def.proto"; +import "function.proto"; +import "versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/metadef/proto/tensorflow/graph_library.proto b/metadef/proto/tensorflow/graph_library.proto new file mode 100644 index 00000000..e393d38d --- /dev/null +++ b/metadef/proto/tensorflow/graph_library.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package domi.tensorflow; + +import "graph.proto"; + +message GeGraphDef { + string name = 1; + GraphDef graph = 2; +} + +message GraphDefLibrary { + repeated GeGraphDef graph_def = 1; +}; \ No newline at end of file diff --git a/metadef/proto/tensorflow/node_def.proto b/metadef/proto/tensorflow/node_def.proto new file mode 100644 index 00000000..b9bc97ee --- /dev/null +++ b/metadef/proto/tensorflow/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // Add some examples here showing best practices. + map attr = 5; +}; diff --git a/metadef/proto/tensorflow/op_def.proto b/metadef/proto/tensorflow/op_def.proto new file mode 100644 index 00000000..3485d045 --- /dev/null +++ b/metadef/proto/tensorflow/op_def.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/metadef/proto/tensorflow/resource_handle.proto b/metadef/proto/tensorflow/resource_handle.proto new file mode 100644 index 00000000..a3452351 --- /dev/null +++ b/metadef/proto/tensorflow/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/metadef/proto/tensorflow/tensor.proto b/metadef/proto/tensorflow/tensor.proto new file mode 100644 index 00000000..d0a4d024 --- /dev/null +++ b/metadef/proto/tensorflow/tensor.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/metadef/proto/tensorflow/tensor_shape.proto b/metadef/proto/tensorflow/tensor_shape.proto new file mode 100644 index 00000000..4225a2e3 --- /dev/null +++ b/metadef/proto/tensorflow/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package domi.tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/metadef/proto/tensorflow/types.proto b/metadef/proto/tensorflow/types.proto new file mode 100644 index 00000000..ba7a72b3 --- /dev/null +++ b/metadef/proto/tensorflow/types.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/metadef/proto/tensorflow/versions.proto b/metadef/proto/tensorflow/versions.proto new file mode 100644 index 00000000..48061218 --- /dev/null +++ b/metadef/proto/tensorflow/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +}; diff --git a/metadef/register/CMakeLists.txt b/metadef/register/CMakeLists.txt new file mode 100644 index 00000000..5b5e1559 --- /dev/null +++ b/metadef/register/CMakeLists.txt @@ -0,0 +1,241 @@ +set(PROTO_LIST + "${METADEF_DIR}/proto/tensorflow/attr_value.proto" + "${METADEF_DIR}/proto/tensorflow/function.proto" + "${METADEF_DIR}/proto/tensorflow/graph.proto" + "${METADEF_DIR}/proto/tensorflow/node_def.proto" + "${METADEF_DIR}/proto/tensorflow/op_def.proto" + "${METADEF_DIR}/proto/tensorflow/resource_handle.proto" + "${METADEF_DIR}/proto/tensorflow/tensor.proto" + "${METADEF_DIR}/proto/tensorflow/tensor_shape.proto" + "${METADEF_DIR}/proto/tensorflow/types.proto" + "${METADEF_DIR}/proto/tensorflow/versions.proto" + "${METADEF_DIR}/proto/task.proto" + "${METADEF_DIR}/proto/om.proto" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +set(SRC_LIST + "register.cpp" + "ops_kernel_builder_registry.cc" + "graph_optimizer/graph_fusion/graph_fusion_pass_base.cc" + "graph_optimizer/graph_fusion/fusion_pass_registry.cc" + "graph_optimizer/graph_fusion/fusion_pattern.cc" + "graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc" + "graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc" + "graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h" + "graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc" + "graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc" + "graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc" + "graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc" + "register_format_transfer.cc" + "op_kernel_registry.cpp" + "auto_mapping_util.cpp" + "host_cpu_context.cc" + "tensor_assign.cpp" + "infer_data_slice_registry.cc" + "scope/scope_graph.cc" + "scope/scope_pass.cc" + "scope/scope_pattern.cc" + "scope/scope_util.cc" + "scope/scope_pass_registry.cc" +) + +############ libregister.so ############ +add_library(register SHARED ${SRC_LIST} ${PROTO_SRCS}) + +target_compile_options(register PRIVATE + $<$,$>: -Wno-deprecated-declarations> +) + +target_compile_definitions(register PRIVATE + google=ascend_private + $<$:ONLY_COMPILE_OPEN_SRC> +) + +target_include_directories(register PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${METADEF_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/.. + ${METADEF_DIR}/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${METADEF_DIR}/../inc + #### temp independent #### + ${METADEF_DIR}/../graphengine/inc + ${METADEF_DIR}/../graphengine/inc/framework + ${METADEF_DIR}/../graphengine/inc/external + #### temp in ge #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../inc/framework + ${METADEF_DIR}/../inc/external + #### temp in parser #### + ${METADEF_DIR}/../../graphengine/inc + ${METADEF_DIR}/../../graphengine/inc/framework + ${METADEF_DIR}/../../graphengine/inc/external + ${METADEF_DIR}/../../inc + #### blue zone #### + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + ${METADEF_DIR}/../third_party/fwkacllib/inc + #### blue independent compile ##### + ${METADEF_DIR}/third_party/graphengine/inc + ${METADEF_DIR}/third_party/graphengine/ge/inc + ${METADEF_DIR}/third_party/graphengine/inc/framework + ${METADEF_DIR}/third_party/graphengine/inc/external + ${METADEF_DIR}/third_party/fwkacllib/inc +) + +target_link_libraries(register PRIVATE + $ + -Wl,--whole-archive + op_tiling_o2 + -Wl,--no-whole-archive + -Wl,--no-as-needed + ascend_protobuf + c_sec + slog + graph + -Wl,--as-needed + json +) + +############ libregister.a ############ +add_library(register_static STATIC ${SRC_LIST} ${PROTO_SRCS} + "op_tiling.cpp" + "op_tiling_registry.cpp" +) + +target_compile_options(register_static PRIVATE + $<$,$>: -Wno-deprecated-declarations> + $<$:/utf-8> + $<$,$>:/MTd> + $<$,$>:/MT> +) + +target_compile_definitions(register_static PRIVATE + google=ascend_private + $<$:ONLY_COMPILE_OPEN_SRC> + $,OS_TYPE=WIN,OS_TYPE=0> + $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> + LOG_CPP +) +target_include_directories(register_static PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${METADEF_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/.. + ${METADEF_DIR}/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${METADEF_DIR}/../inc + #### temp independent #### + ${METADEF_DIR}/../graphengine/inc + ${METADEF_DIR}/../graphengine/inc/framework + ${METADEF_DIR}/../graphengine/inc/external + #### temp in ge #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../inc/framework + ${METADEF_DIR}/../inc/external + #### temp in parser #### + ${METADEF_DIR}/../../graphengine/inc + ${METADEF_DIR}/../../graphengine/inc/framework + ${METADEF_DIR}/../../graphengine/inc/external + ${METADEF_DIR}/../../inc + #### blue zone #### + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + #### blue independent compile ##### + ${METADEF_DIR}/../third_party/fwkacllib/inc + ${METADEF_DIR}/third_party/graphengine/inc + ${METADEF_DIR}/third_party/graphengine/inc/framework + ${METADEF_DIR}/third_party/graphengine/inc/external + ${METADEF_DIR}/third_party/fwkacllib/inc +) + +target_link_libraries(register_static PRIVATE + ascend_protobuf + c_sec + json + $ +) + +set_target_properties(register_static PROPERTIES + WINDOWS_EXPORT_ALL_SYMBOLS TRUE + OUTPUT_NAME $,libregister,register> +) + +############ libop_tiling_o2.a ############ +add_library(op_tiling_o2 STATIC + "op_tiling.cpp" + "op_tiling_registry.cpp" +) + +target_include_directories(op_tiling_o2 PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${METADEF_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${METADEF_DIR}/inc/graph + ${METADEF_DIR}/.. + ${METADEF_DIR}/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${METADEF_DIR}/../inc + #### temp independent #### + ${METADEF_DIR}/../graphengine/inc + ${METADEF_DIR}/../graphengine/inc/framework + ${METADEF_DIR}/../graphengine/inc/external + #### temp in ge #### + ${METADEF_DIR}/../inc + ${METADEF_DIR}/../inc/framework + ${METADEF_DIR}/../inc/external + #### temp in parser #### + ${METADEF_DIR}/../../graphengine/inc + ${METADEF_DIR}/../../graphengine/inc/framework + ${METADEF_DIR}/../../graphengine/inc/external + ${METADEF_DIR}/../../inc + #### blue zone #### + ${ASCEND_DIR}/driver/include + ${ASCEND_DIR}/fwkacllib/include + ${METADEF_DIR}/../third_party/fwkacllib/inc + #### blue independent compile #### + ${METADEF_DIR}/third_party/graphengine/inc + ${METADEF_DIR}/third_party/graphengine/inc/external + ${METADEF_DIR}/third_party/fwkacllib/inc +) + +target_compile_options(op_tiling_o2 PRIVATE + -O2 + $<$,$>: -Wno-deprecated-declarations> +) + +target_compile_definitions(op_tiling_o2 PRIVATE + LOG_CPP +) + +target_link_libraries(op_tiling_o2 PRIVATE + $ + json + c_sec +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS register register_static OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} + ARCHIVE DESTINATION ${INSTALL_LIBRARY_DIR} +) diff --git a/metadef/register/auto_mapping_util.cpp b/metadef/register/auto_mapping_util.cpp new file mode 100644 index 00000000..b2936a63 --- /dev/null +++ b/metadef/register/auto_mapping_util.cpp @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "register/auto_mapping_util.h" +#include "graph/debug/ge_util.h" + +namespace ge { + +// Convert tensorflow property to ge property +bool AutoMappingUtil::FindAttrValue(const domi::tensorflow::NodeDef *nodeDef, const string &attr_name, + domi::tensorflow::AttrValue &attr_value) { + GE_CHECK_NOTNULL(nodeDef); + const google::protobuf::Map &attr = nodeDef->attr(); + const google::protobuf::Map::const_iterator it = attr.find(attr_name); + if (it != attr.end()) { + attr_value = it->second; + return true; + } + return false; +} + +// Get the attribute shape of tensorflow +void AutoMappingUtil::ConvertShape(const domi::tensorflow::TensorShapeProto &shape, + vector& shape_dims) { + shape_dims.clear(); + if (!shape.unknown_rank()) { + for (auto &dim : shape.dim()) { + shape_dims.push_back(dim.size()); + } + } else { + shape_dims = ge::UNKNOWN_SHAPE; + } +} + +graphStatus AutoMappingUtil::ConvertTensor(const domi::tensorflow::TensorProto &tensor, ge::GeTensorPtr &weight) { + weight = ComGraphMakeShared(); + if (weight == nullptr) { + GE_LOGE("Weight is nullptr."); + return GRAPH_FAILED; + } + domi::tensorflow::DataType tf_data_type = tensor.dtype(); + ge::DataType ge_data_type = domi::TensorAssign::ConvertTensorflowDataType(tf_data_type); + if (domi::TensorAssign::SetGeTensorDataType(ge_data_type, weight) != domi::SUCCESS) { + GE_LOGE("Set Ge tensor data type failed."); + return GRAPH_FAILED; + } + if (domi::TensorAssign::SetGeTensor(tensor, weight) != domi::SUCCESS) { + GE_LOGE("Set Ge tensor failed."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +void AutoMappingUtil::ConvertTensorList(const domi::tensorflow::AttrValue_ListValue &list, + std::vector &vec) { + vec.clear(); + for (auto &tensor : list.tensor()) { + ge::GeTensorPtr ge_tensor = nullptr; + graphStatus ret = ConvertTensor(tensor, ge_tensor); + if (ret != GRAPH_SUCCESS) { + GE_LOGE("Convert tensor failed."); + return; + } + vec.push_back(ge_tensor); + } +} + +void AutoMappingUtil::ConvertFunc(const domi::tensorflow::NameAttrList& tf_func, + ge::GeAttrValue::NAMED_ATTRS& ge_func) { + ge_func.SetName(tf_func.name()); + auto& attrs = tf_func.attr(); + for (auto &item : attrs) { + ConvertValue(item.first, item.second, ge_func); + } +} + +void AutoMappingUtil::ConvertDataTypeList(const domi::tensorflow::AttrValue_ListValue &list, + std::vector &vec) { + vec.clear(); + for (auto &e : list.type()) { + ge::DataType ge_data_type = domi::TensorAssign::ConvertTensorflowDataType(static_cast(e)); + vec.push_back(ge_data_type); + } +} + +void AutoMappingUtil::ConvertShapeList(const domi::tensorflow::AttrValue_ListValue &list, + std::vector> &vec) { + vec.clear(); + for (const auto &e : list.shape()) { + vector shape_dims; + ConvertShape(e, shape_dims); + vec.push_back(shape_dims); + } +} + +void AutoMappingUtil::ConvertFuncList(const domi::tensorflow::AttrValue_ListValue &list, + std::vector &vec) { + vec.clear(); + for (const auto &e : list.func()) { + ge::GeAttrValue::NAMED_ATTRS func; + ConvertFunc(e, func); + vec.push_back(func); + } +} + +} // namespace domi diff --git a/metadef/register/auto_mapping_util.h b/metadef/register/auto_mapping_util.h new file mode 100644 index 00000000..978d6b44 --- /dev/null +++ b/metadef/register/auto_mapping_util.h @@ -0,0 +1,209 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef COMMON_AUTO_MAPPING_UTIL_H_ +#define COMMON_AUTO_MAPPING_UTIL_H_ + +#include "framework/common/debug/ge_log.h" +#include "proto/tensorflow/attr_value.pb.h" +#include "proto/tensorflow/node_def.pb.h" +#include "graph/ge_tensor.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/debug/ge_log.h" +#include "register/register_error_codes.h" +#include "register/tensor_assign.h" + +namespace ge { + +class AutoMappingUtil { +public: + static bool FindAttrValue(const domi::tensorflow::NodeDef *nodeDef, const string &attr_name, + domi::tensorflow::AttrValue &attr_value); + static void ConvertShape(const domi::tensorflow::TensorShapeProto &shape, vector& shape_dims); + static graphStatus ConvertTensor(const domi::tensorflow::TensorProto &tensor, ge::GeTensorPtr &weight); + static void ConvertFunc(const domi::tensorflow::NameAttrList& tf_func, ge::GeAttrValue::NAMED_ATTRS& ge_func); + + static void ConvertDataTypeList(const domi::tensorflow::AttrValue_ListValue &list, + std::vector &vec); + static void ConvertShapeList(const domi::tensorflow::AttrValue_ListValue &list, + std::vector> &vec); + static void ConvertTensorList(const domi::tensorflow::AttrValue_ListValue &list, + std::vector &vec); + static void ConvertFuncList(const domi::tensorflow::AttrValue_ListValue &list, + std::vector &vec); + + // Get the attribute list list of tensorflow and save it to obj according to the key + template + static void ConvertList(const std::string &key, const domi::tensorflow::AttrValue &value, T &obj) { + const domi::tensorflow::AttrValue_ListValue &list = value.list(); + if (list.s_size() > 0) { + vector vec; + for (auto e : list.s()) { + vec.push_back(e); + } + (void)ge::AttrUtils::SetListStr(obj, key, vec); + } else if (list.i_size() > 0) { + vector vec; + for (auto e : list.i()) { + vec.push_back(e); + } + (void)ge::AttrUtils::SetListInt(obj, key, vec); + } else if (list.f_size() > 0) { + vector vec; + for (auto e : list.f()) { + vec.push_back(e); + } + (void)ge::AttrUtils::SetListFloat(obj, key, vec); + } else if (list.b_size() > 0) { + vector vec; + for (auto e : list.b()) { + vec.push_back(e); + } + (void)ge::AttrUtils::SetListBool(obj, key, vec); + } else if (list.type_size() > 0) { + vector vec; + ConvertDataTypeList(list, vec); + (void)ge::AttrUtils::SetListDataType(obj, key, vec); + } else if (list.shape_size() > 0) { + vector> shape_dims_vec; + ConvertShapeList(list, shape_dims_vec); + (void)ge::AttrUtils::SetListListInt(obj, key, shape_dims_vec); + } else if (list.tensor_size() > 0) { + vector vec; + ConvertTensorList(list, vec); + (void)ge::AttrUtils::SetListTensor(obj, key, vec); + } else if (list.func_size() > 0) { + vector vec; + ConvertFuncList(list, vec); + (void)ge::AttrUtils::SetListNamedAttrs(obj, key, vec); + } else { + GELOGD("The list has no value, key is %s.", key.c_str()); + } + } + + // According to the property type of tensorflow, set it to the corresponding property of obj + template + static void ConvertValue(const std::string &key, const domi::tensorflow::AttrValue &value, T &obj) { + switch (value.value_case()) { + case domi::tensorflow::AttrValue::kS: + (void)ge::AttrUtils::SetStr(obj, key, value.s()); + break; + case domi::tensorflow::AttrValue::kI: + (void)ge::AttrUtils::SetInt(obj, key, static_cast(value.i())); + break; + case domi::tensorflow::AttrValue::kF: + (void)ge::AttrUtils::SetFloat(obj, key, static_cast(value.f())); + break; + case domi::tensorflow::AttrValue::kB: + (void)ge::AttrUtils::SetBool(obj, key, static_cast(value.b())); + break; + case domi::tensorflow::AttrValue::kType: { + ge::DataType ge_data_type = domi::TensorAssign::ConvertTensorflowDataType(static_cast(value.type())); + (void)ge::AttrUtils::SetDataType(obj, key, ge_data_type); + break; + } + case domi::tensorflow::AttrValue::kList: + ConvertList(key, value, obj); + break; + case domi::tensorflow::AttrValue::kShape: { + vector shape_dims; + ConvertShape(value.shape(), shape_dims); + (void)ge::AttrUtils::SetListInt(obj, key, shape_dims); + break; + } + case domi::tensorflow::AttrValue::kTensor: { + ge::GeTensorPtr ge_tensor = nullptr; + graphStatus ret = ConvertTensor(value.tensor(), ge_tensor); + if (ret != GRAPH_SUCCESS) { + GE_LOGE("Convert ge tensor failed, key is %s.", key.c_str()); + return; + } + (void)ge::AttrUtils::SetTensor(obj, key, ge_tensor); + break; + } + case domi::tensorflow::AttrValue::kFunc: { + ge::GeAttrValue::NAMED_ATTRS func; + ConvertFunc(value.func(), func); + (void)ge::AttrUtils::SetNamedAttrs(obj, key, func); + break; + } + case domi::tensorflow::AttrValue::kPlaceholder: + (void)ge::AttrUtils::SetStr(obj, key, value.placeholder()); + break; + case domi::tensorflow::AttrValue::VALUE_NOT_SET: + GELOGD("the attr value of %s is not set.", key.c_str()); + break; + default: + GE_LOGE("the attr value type(%d) is invalid.", static_cast(value.value_case())); + break; + } + } + +template +static void CopyAttrValue(const std::string &key, const ge::GeAttrValue &value, T &obj_src, T &obj) { + GeAttrValue::ValueType value_type = value.GetValueType(); + bool is_one_type = value_type == GeAttrValue::VT_STRING || value_type == GeAttrValue::VT_INT || + value_type == GeAttrValue::VT_FLOAT || value_type == GeAttrValue::VT_BOOL || + value_type == GeAttrValue::VT_TENSOR || value_type == GeAttrValue::VT_NAMED_ATTRS || + value_type == GeAttrValue::VT_DATA_TYPE; + if (is_one_type) { + switch (value_type) { +#define CASE_ATTR_VALUE_TYPE(GeValueType, ValueType, FuncName) \ + case GeAttrValue::VT_##GeValueType: { \ + ValueType value; \ + (void) ge::AttrUtils::Get##FuncName(obj_src, key, value); \ + (void) ge::AttrUtils::Set##FuncName(obj, key, value); \ + break; \ + } + CASE_ATTR_VALUE_TYPE(STRING, string, Str); + CASE_ATTR_VALUE_TYPE(INT, int64_t, Int); + CASE_ATTR_VALUE_TYPE(FLOAT, float, Float); + CASE_ATTR_VALUE_TYPE(BOOL, bool, Bool); + CASE_ATTR_VALUE_TYPE(TENSOR, ConstGeTensorPtr, Tensor); + CASE_ATTR_VALUE_TYPE(NAMED_ATTRS, ge::GeAttrValue::NAMED_ATTRS, NamedAttrs); + CASE_ATTR_VALUE_TYPE(DATA_TYPE, ge::DataType, DataType); +#undef CASE_ATTR_VALUE_TYPE + default: + break; + } + } else { + switch (value_type) { +#define CASE_ATTR_VALUE_TYPE_LIST(GeValueType, ValueType, FuncName) \ + case GeAttrValue::VT_LIST_##GeValueType: { \ + vector value; \ + (void) ge::AttrUtils::GetList##FuncName(obj_src, key, value); \ + (void) ge::AttrUtils::SetList##FuncName(obj, key, value); \ + break; \ + } + CASE_ATTR_VALUE_TYPE_LIST(STRING, string, Str); + CASE_ATTR_VALUE_TYPE_LIST(INT, int64_t, Int); + CASE_ATTR_VALUE_TYPE_LIST(FLOAT, float, Float); + CASE_ATTR_VALUE_TYPE_LIST(BOOL, bool, Bool); + CASE_ATTR_VALUE_TYPE_LIST(TENSOR, ConstGeTensorPtr, Tensor); + CASE_ATTR_VALUE_TYPE_LIST(NAMED_ATTRS, ge::GeAttrValue::NAMED_ATTRS, NamedAttrs); + CASE_ATTR_VALUE_TYPE_LIST(DATA_TYPE, ge::DataType, DataType); + CASE_ATTR_VALUE_TYPE_LIST(LIST_INT, vector, ListInt); +#undef CASE_ATTR_VALUE_TYPE_LIST + default: + GELOGW("The ge attr value type(%d) is invalid.", static_cast(value_type)); + break; + } + } +} +}; +} // namespace domi +#endif // COMMON_AUTO_MAPPING_UTIL_H_ \ No newline at end of file diff --git a/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc b/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc new file mode 100644 index 00000000..621afd74 --- /dev/null +++ b/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h" +#include +#include +#include + +namespace fe { +BufferFusionPassBase::BufferFusionPassBase() {} + +BufferFusionPassBase::~BufferFusionPassBase() {} + +Status BufferFusionPassBase::GetFusionNodes(const BufferFusionMapping &mapping, + std::vector &fusion_nodes) { + fusion_nodes = GetMatchedNodes(mapping); + return SUCCESS; +} + +std::vector BufferFusionPassBase::GetMatchedNodes(const BufferFusionMapping &mapping) { + std::vector nodes; + for (const auto &item : mapping) { + for (const auto &node : item.second) { + nodes.push_back(node); + } + } + return nodes; +} + +std::vector BufferFusionPassBase::GetMatchedNodesByDescName(const std::string &desc_name, + const BufferFusionMapping &mapping) { + std::vector nodes; + for (const auto &item : mapping) { + const BufferFusionOpDesc *op_desc = item.first; + if (op_desc != nullptr && op_desc->desc_name == desc_name) { + for (const auto &node : item.second) { + nodes.push_back(node); + } + } + } + return nodes; +} + +ge::NodePtr BufferFusionPassBase::GetMatchedHeadNode(const std::vector &matched_nodes) { + for (const auto &node : matched_nodes) { + auto input_nodes = node->GetInDataNodes(); + bool find_flag = false; + for (const auto &in_node : input_nodes) { + // find the node from fuison sub graph + if (std::find(matched_nodes.begin(), matched_nodes.end(), in_node) != matched_nodes.end()) { + find_flag = true; + break; + } + } + if (find_flag == false) { + return node; + } + } + return nullptr; +} + +} // namespace fe diff --git a/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc b/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc new file mode 100644 index 00000000..5409cdd3 --- /dev/null +++ b/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h" +#include +#include +#include +#include +#include +#include "graph/debug/ge_log.h" + +namespace fe { +class BufferFusionPassRegistry::BufferFusionPassRegistryImpl { + public: + void RegisterPass(const BufferFusionPassType &pass_type, const std::string &pass_name, + BufferFusionPassRegistry::CreateFn create_fn) { + std::lock_guard lock(mu_); + auto iter = create_fns_.find(pass_type); + if (iter != create_fns_.end()) { + create_fns_[pass_type][pass_name] = create_fn; + GELOGI("UbFusionPass[type=%d,name=%s]: the pass type already exists.", pass_type, pass_name.c_str()); + return; + } + + std::map create_fn_map; + create_fn_map[pass_name] = create_fn; + create_fns_[pass_type] = create_fn_map; + GELOGI("UbFusionPass[type=%d,name=%s]: the pass type does not exists.", pass_type, pass_name.c_str()); + } + + std::map GetCreateFn(const BufferFusionPassType &pass_type) { + std::lock_guard lock(mu_); + std::map result; + auto iter = create_fns_.find(pass_type); + if (iter == create_fns_.end()) { + return result; + } + return iter->second; + } + + private: + std::mutex mu_; + std::map> create_fns_; +}; + +BufferFusionPassRegistry::BufferFusionPassRegistry() { + impl_ = std::unique_ptr(new (std::nothrow) BufferFusionPassRegistryImpl); +} + +BufferFusionPassRegistry::~BufferFusionPassRegistry() {} + +BufferFusionPassRegistry &BufferFusionPassRegistry::GetInstance() { + static BufferFusionPassRegistry instance; + return instance; +} + +void BufferFusionPassRegistry::RegisterPass(const BufferFusionPassType &pass_type, const std::string &pass_name, + CreateFn create_fn) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "UbFusionPass[type=%d,name=%s]: failed to register the ub fusion pass", pass_type, + pass_name.c_str()); + return; + } + impl_->RegisterPass(pass_type, pass_name, create_fn); +} + +std::map BufferFusionPassRegistry::GetCreateFnByType( + const BufferFusionPassType &pass_type) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "UbFusionPass[type=%d]: failed to create the ub fusion pass", pass_type); + return std::map{}; + } + return impl_->GetCreateFn(pass_type); +} + +BufferFusionPassRegistrar::BufferFusionPassRegistrar(const BufferFusionPassType &pass_type, + const std::string &pass_name, + BufferFusionPassBase *(*create_fn)()) { + if (pass_type < BUILT_IN_AI_CORE_BUFFER_FUSION_PASS || pass_type >= BUFFER_FUSION_PASS_TYPE_RESERVED) { + GELOGE(ge::PARAM_INVALID, "The pass_type[%d] is not supported.", pass_type); + return; + } + + if (pass_name.empty()) { + GELOGE(ge::PARAM_INVALID, "Failed to register the ub fusion pass, the pass name is empty."); + return; + } + + BufferFusionPassRegistry::GetInstance().RegisterPass(pass_type, pass_name, create_fn); +} + +} // namespace fe \ No newline at end of file diff --git a/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc b/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc new file mode 100644 index 00000000..cd22892a --- /dev/null +++ b/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc @@ -0,0 +1,302 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h" +#include +#include +#include "graph/debug/ge_log.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +using std::map; +using std::string; +using std::vector; + +namespace fe { +inline bool IsAddOverflow(int64_t a, int64_t b) { + return ((b > 0) && (a > ((int64_t)INT64_MAX - b))) || ((b < 0) && (a < ((int64_t)INT64_MIN - b))); +} + +BufferFusionPattern::BufferFusionPattern(string name, int64_t max_count) + : name_(name), op_max_count_(max_count), error_count_(0) {} + +BufferFusionPattern::~BufferFusionPattern() { + for (auto op : ops_) { + if (op == nullptr) { + continue; + } + delete (op); + } +} + +/* + * @brief: add op desc info + * @param [in] desc_name: node desc name + * @param [in] types: node desc type + * @param [in] repeate_min: the min count for fusion match, + * patter match failed if real count lower than the + * value + * @param [in] repeate_max: the max count for fusion match, + * the op will be ignored if current match count equla + * with the value + * @return BufferFusionPattern: pattern object + */ + +BufferFusionPattern &BufferFusionPattern::AddOpDesc(const std::string &desc_name, const std::vector &types, + int64_t repeate_min, int64_t repeate_max, int64_t group_id) { + if (desc_name.empty()) { + GELOGW("Desc_name cannot be empty."); + error_count_++; + return *this; + } + + if (repeate_min > repeate_max) { + GELOGW("Repeat_min can not lager than repeate_max, desc name is [%s], min is [%ld], max is [%ld]", + desc_name.c_str(), repeate_min, repeate_max); + error_count_++; + return *this; + } + + if (GetOpDesc(desc_name) != nullptr) { + GELOGW("Desc_name repeated. (desc_name:%s)", desc_name.c_str()); + error_count_++; + return *this; + } + + BufferFusionOpDesc *op = new (std::nothrow) BufferFusionOpDesc(); + if (op == nullptr) { + GELOGW("New an object failed."); + error_count_++; + return *this; + } + + op->desc_name = desc_name; + op->types = types; + op->repeate_min = repeate_min; + op->repeate_max = repeate_max; + op->repeate_curr = 0; + op->group_id = group_id; + op->match_status = false; + op->out_branch_type = TBE_OUTPUT_BRANCH_DEFAULT; + op->ignore_input_num = false; + op->ignore_output_num = false; + if (repeate_max > repeate_min) { + for (int64_t i = repeate_min; i < repeate_max; i++) { + op->multi_output_skip_status.insert(std::pair(i, SkipStatus::DISABLED)); + } + } + ops_.push_back(op); + op_map_[desc_name] = op; + + op->outputs.clear(); + return *this; +} + +/* + * @brief: set output desc info + * @param [in] desc_name: node desc name + * @param [in] output_ids: output desc + * @param [in] relation: output desc relation (1: serial, 2:parallel) + * @return BufferFusionPattern: pattern object + */ +BufferFusionPattern &BufferFusionPattern::SetOutputs(const string &desc_name, const std::vector &output_ids, + int64_t relation, bool ignore_input_num, bool ignore_output_num) { + if (desc_name.empty()) { + GELOGW("Desc_name cannot be empty."); + error_count_++; + return *this; + } + + BufferFusionOpDesc *op_desc = GetOpDesc(desc_name); + if (op_desc == nullptr) { + GELOGW("Desc_name not exist. (desc_name:%s)", desc_name.c_str()); + error_count_++; + return *this; + } + + op_desc->ignore_input_num = ignore_input_num; + op_desc->ignore_output_num = ignore_output_num; + if (op_desc->out_branch_type == TBE_OUTPUT_BRANCH_DEFAULT) { + op_desc->out_branch_type = relation; + } + + UpdateSkipStatus(op_desc); + + // support one multi output for one optype + for (const string &output_id : output_ids) { + BufferFusionOpDesc *output_op_desc = GetOpDesc(output_id); + if (output_op_desc == nullptr) { + GELOGW("Desc_name not exist. (desc_name:%s)", desc_name.c_str()); + if (IsAddOverflow(error_count_, 1) != SUCCESS) { + GELOGW("errorCount_++ overflow. (desc_name:%s)", desc_name.c_str()); + return *this; + } + error_count_++; + return *this; + } + if (op_desc == output_op_desc) { + continue; + } + + op_desc->outputs.push_back(output_op_desc); + output_op_desc->inputs.push_back(op_desc); + + if (op_desc->out_branch_type != relation) { + GELOGW("Failed to set outputs relation: curr is [%ld], new is [%ld].", op_desc->out_branch_type, relation); + return *this; + } + } + return *this; +} + +/* + * @brief: get output desc info + * @param [in] op_desc: current desc + * @param [out] outputs: candidate output desc set + * @return bool: get output desc ok or not + */ +bool BufferFusionPattern::GetOutputs(BufferFusionOpDesc *op_desc, std::vector &outputs, + bool ignore_repeat) { + if (op_desc == nullptr) { + GELOGW("failed to get outputs: op_desc is null."); + return false; + } + string desc_n = op_desc->desc_name; + + // add curr desc can be reused while repeate_curr < repeate_max + if (!ignore_repeat && op_desc->repeate_curr < op_desc->repeate_max) { + outputs.push_back(op_desc); + } + + // check candidate desc + for (auto desc : op_desc->outputs) { + if (desc == nullptr) { + GELOGD("desc[%s] has null output desc.", desc_n.c_str()); + continue; + } + // add out desc + outputs.push_back(desc); + + // add sub outdescs while repeate_min == 0 + if (desc->repeate_min == 0) { + std::vector sub_output; + if (GetOutputs(desc, sub_output, true)) { + for (const auto &sub_desc : sub_output) { + outputs.push_back(sub_desc); + } + } + } + } + + return true; +} + +/* + * @brief: set fusion pattern head + * @param [in] head_ids: node list + * @return bool: set head desc ok or not + */ +BufferFusionPattern &BufferFusionPattern::SetHead(const std::vector &head_ids) { + if (head_ids.empty()) { + GELOGW("input vector is empty."); + error_count_++; + return *this; + } + for (const string &head_id : head_ids) { + BufferFusionOpDesc *head_op_desc = GetOpDesc(head_id); + if (head_op_desc == nullptr) { + GELOGW("descName not exist. (desc_name:%s)", head_id.c_str()); + if (IsAddOverflow(error_count_, 1) != SUCCESS) { + GELOGW("errorCount_++ overflow. (desc_name:%s)", head_id.c_str()); + return *this; + } + error_count_++; + return *this; + } + // Head desc repeat number can not excceed 1 + // if must be excceed 1, it can be realized by several descs + if (head_op_desc->repeate_max > 1) { + GELOGW("Head desc repeat number can not excceed 1, head desc name is [%s], actual repeate_max is [%ld]", + head_id.c_str(), head_op_desc->repeate_max); + if (IsAddOverflow(error_count_, 1) != SUCCESS) { + GELOGW("errorCount_++ overflow. (desc_name:%s)", head_id.c_str()); + return *this; + } + error_count_++; + return *this; + } + head_.push_back(head_op_desc); + } + + // check head desc repeat min total value, it can not excceed 1 + int64_t desc_total_min = 0; + for (const auto &desc : head_) { + if (IsAddOverflow(desc_total_min, desc->repeate_min) != SUCCESS) { + GELOGW("desc_total_min + repeate_min overflow."); + return *this; + } + desc_total_min += desc->repeate_min; + } + + if (desc_total_min > 1) { + GELOGW("head desc repeat min total value can not be larger than 1, current is [%ld]", desc_total_min); + error_count_++; + return *this; + } + return *this; +} + +void BufferFusionPattern::UpdateSkipStatus(BufferFusionOpDesc *op_desc) { + if (op_desc->out_branch_type == TBE_OUTPUT_BRANCH_MULTI) { + for (auto &input_desc : op_desc->inputs) { + if (input_desc->types.size() != op_desc->types.size()) { + continue; + } + bool is_same_type = true; + for (size_t i = 0; i < input_desc->types.size(); i++) { + if (input_desc->types[i] != op_desc->types[i]) { + is_same_type = false; + break; + } + } + if (is_same_type && input_desc->ignore_output_num == true) { + for (int64_t i = input_desc->repeate_min; i < input_desc->repeate_max; i++) { + input_desc->multi_output_skip_status[i] = SkipStatus::AVAILABLE; + } + } + } + } +} + +/* + * @brief: get description ptr by name + * @param [in] desc_name: fusion pattern desc name + * @return BufferFusionOpDesc*: description ptr + */ +BufferFusionOpDesc *BufferFusionPattern::GetOpDesc(const string &desc_name) { + auto it = op_map_.find(desc_name); + if (it != op_map_.end()) return it->second; + + return nullptr; +} + +std::vector BufferFusionPattern::GetHead() { return head_; } + +std::string BufferFusionPattern::GetName() { return name_; } +int64_t BufferFusionPattern::GetOpMaxCount() { return op_max_count_; } +int64_t BufferFusionPattern::GetErrorCnt() { return error_count_; } + +std::vector BufferFusionPattern::GetOpDescs() { return ops_; } +} // namespace fe diff --git a/metadef/register/graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc b/metadef/register/graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc new file mode 100644 index 00000000..a95ec89a --- /dev/null +++ b/metadef/register/graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc @@ -0,0 +1,134 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/graph_optimizer/fusion_common/fusion_statistic_recorder.h" +#include "graph/debug/ge_log.h" + +namespace fe { + +FusionStatisticRecorder::FusionStatisticRecorder(){}; + +FusionStatisticRecorder::~FusionStatisticRecorder(){}; + +FusionStatisticRecorder &FusionStatisticRecorder::Instance() { + static FusionStatisticRecorder fusion_statistic_recoder; + return fusion_statistic_recoder; +} + +void FusionStatisticRecorder::UpdateGraphFusionMatchTimes(FusionInfo &fusion_info) { + std::lock_guard lock_guard(mutex_); + if (fusion_info.GetMatchTimes() != 0) { + std::string session_and_graph_id = std::to_string(fusion_info.GetSessionId()) + "_" + fusion_info.GetGraphId(); + graph_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].AddMatchTimes(fusion_info.GetMatchTimes()); + GELOGD("session %d graph %s pass %s match_times value: %d", fusion_info.GetSessionId(), + fusion_info.GetGraphId().c_str(), fusion_info.GetPassName().c_str(), + graph_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].GetMatchTimes()); + } +} + +void FusionStatisticRecorder::UpdateGraphFusionEffectTimes(FusionInfo &fusion_info) { + std::lock_guard lock_guard(mutex_); + if (fusion_info.GetEffectTimes() != 0) { + std::string session_and_graph_id = std::to_string(fusion_info.GetSessionId()) + "_" + fusion_info.GetGraphId(); + graph_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].AddEffectTimes( + fusion_info.GetEffectTimes()); + GELOGD("session %d graph %s pass %s effect_times value: %d", fusion_info.GetSessionId(), + fusion_info.GetGraphId().c_str(), fusion_info.GetPassName().c_str(), + graph_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].GetEffectTimes()); + } +} + +void FusionStatisticRecorder::UpdateBufferFusionMatchTimes(FusionInfo &fusion_info) { + std::lock_guard lock_guard(mutex_); + if (fusion_info.GetMatchTimes() != 0) { + std::string session_and_graph_id = std::to_string(fusion_info.GetSessionId()) + "_" + fusion_info.GetGraphId(); + buffer_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].AddMatchTimes(fusion_info.GetMatchTimes()); + GELOGD("ub session %d graph %s pass %s match_times value: %d", fusion_info.GetSessionId(), + fusion_info.GetGraphId().c_str(), fusion_info.GetPassName().c_str(), + buffer_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].GetMatchTimes()); + } +} + +void FusionStatisticRecorder::UpdateBufferFusionEffectTimes(FusionInfo &fusion_info) { + std::lock_guard lock_guard(mutex_); + if (fusion_info.GetEffectTimes() != 0) { + std::string session_and_graph_id = std::to_string(fusion_info.GetSessionId()) + "_" + fusion_info.GetGraphId(); + buffer_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].AddEffectTimes( + fusion_info.GetEffectTimes()); + GELOGD("ub session %d graph %s pass %s effect_times value: %d", fusion_info.GetSessionId(), + fusion_info.GetGraphId().c_str(), fusion_info.GetPassName().c_str(), + buffer_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].GetEffectTimes()); + } +} + +void FusionStatisticRecorder::GetAndClearFusionInfo(const std::string &session_graph_id, + std::map &graph_fusion_info_map, + std::map &buffer_fusion_info_map) { + std::lock_guard lock_guard(mutex_); + GELOGD("start to get graph map size %d", graph_fusion_info_map_.size()); + GELOGD("start to get ub graph map size %d", buffer_fusion_info_map_.size()); + GetFusionInfo(session_graph_id, graph_fusion_info_map, buffer_fusion_info_map); + ClearFusionInfo(session_graph_id); +} + +void FusionStatisticRecorder::GetFusionInfo(const std::string &session_graph_id, + std::map &graph_fusion_info_map, + std::map &buffer_fusion_info_map) { + if (graph_fusion_info_map_.find(session_graph_id) != graph_fusion_info_map_.end()) { + graph_fusion_info_map = graph_fusion_info_map_[session_graph_id]; + } + if (buffer_fusion_info_map_.find(session_graph_id) != buffer_fusion_info_map_.end()) { + buffer_fusion_info_map = buffer_fusion_info_map_[session_graph_id]; + } +} + +void FusionStatisticRecorder::ClearFusionInfo(std::string session_graph_id) { + if (graph_fusion_info_map_.find(session_graph_id) != graph_fusion_info_map_.end()) { + graph_fusion_info_map_.erase(session_graph_id); + } + if (buffer_fusion_info_map_.find(session_graph_id) != buffer_fusion_info_map_.end()) { + buffer_fusion_info_map_.erase(session_graph_id); + } +} + +FusionInfo::FusionInfo(uint64_t session_id, std::string graph_id, std::string pass_name, int32_t match_times, + int32_t effect_times) + : session_id_(session_id), + graph_id_(std::move(graph_id)), + pass_name_(std::move(pass_name)), + match_times_(match_times), + effect_times_(effect_times) {} + +FusionInfo::~FusionInfo() {} + +void FusionInfo::AddMatchTimes(int32_t match_times) { this->match_times_ += match_times; } + +void FusionInfo::AddEffectTimes(int32_t effect_times) { this->effect_times_ += effect_times; } + +int32_t FusionInfo::GetMatchTimes() { return match_times_; } + +int32_t FusionInfo::GetEffectTimes() { return effect_times_; } + +std::string FusionInfo::GetGraphId() { return graph_id_; } + +std::string FusionInfo::GetPassName() { return pass_name_; } + +uint64_t FusionInfo::GetSessionId() { return session_id_; } + +void FusionInfo::SetMatchTimes(int32_t match_times) { this->match_times_ = match_times; } + +void FusionInfo::SetEffectTimes(int32_t effect_times) { this->effect_times_ = effect_times; } +} diff --git a/metadef/register/graph_optimizer/graph_fusion/fusion_pass_registry.cc b/metadef/register/graph_optimizer/graph_fusion/fusion_pass_registry.cc new file mode 100644 index 00000000..e55a4070 --- /dev/null +++ b/metadef/register/graph_optimizer/graph_fusion/fusion_pass_registry.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h" +#include +#include +#include +#include +#include +#include "graph/debug/ge_log.h" + +namespace fe { +class FusionPassRegistry::FusionPassRegistryImpl { + public: + void RegisterPass(const GraphFusionPassType &pass_type, const std::string &pass_name, + FusionPassRegistry::CreateFn create_fn) { + std::lock_guard lock(mu_); + + auto iter = create_fns_.find(pass_type); + if (iter != create_fns_.end()) { + create_fns_[pass_type][pass_name] = create_fn; + GELOGD("GraphFusionPass[type=%d,name=%s]: the pass type already exists.", pass_type, pass_name.c_str()); + return; + } + + std::map create_fn_map; + create_fn_map[pass_name] = create_fn; + create_fns_[pass_type] = create_fn_map; + GELOGD("GraphFusionPass[type=%d,name=%s]: the pass type does not exist.", pass_type, pass_name.c_str()); + } + + std::map GetCreateFn(const GraphFusionPassType &pass_type) { + std::lock_guard lock(mu_); + auto iter = create_fns_.find(pass_type); + if (iter == create_fns_.end()) { + return std::map{}; + } + return iter->second; + } + + private: + std::mutex mu_; + std::map> create_fns_; +}; + +FusionPassRegistry::FusionPassRegistry() { + impl_ = std::unique_ptr(new (std::nothrow) FusionPassRegistryImpl); +} + +FusionPassRegistry::~FusionPassRegistry() {} + +FusionPassRegistry &FusionPassRegistry::GetInstance() { + static FusionPassRegistry instance; + return instance; +} + +void FusionPassRegistry::RegisterPass(const GraphFusionPassType &pass_type, const std::string &pass_name, + CreateFn create_fn) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "GraphFusionPass[type=%d,name=%s]: failed to register the graph fusion pass.", + pass_type, pass_name.c_str()); + return; + } + impl_->RegisterPass(pass_type, pass_name, create_fn); +} + +std::map FusionPassRegistry::GetCreateFnByType( + const GraphFusionPassType &pass_type) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "GraphFusionPass[type=%d]: failed to create the graph fusion pass.", pass_type); + return std::map{}; + } + return impl_->GetCreateFn(pass_type); +} + +FusionPassRegistrar::FusionPassRegistrar(const GraphFusionPassType &pass_type, const std::string &pass_name, + GraphPass *(*create_fn)()) { + if (pass_type < BUILT_IN_GRAPH_PASS || pass_type >= GRAPH_FUSION_PASS_TYPE_RESERVED) { + GELOGE(ge::PARAM_INVALID, "The pass_type[%d] is not supported.", pass_type); + return; + } + + if (pass_name.empty()) { + GELOGE(ge::PARAM_INVALID, "Failed to register the graph fusion pass, the pass name is empty."); + return; + } + FusionPassRegistry::GetInstance().RegisterPass(pass_type, pass_name, create_fn); +} + +} // namespace fe \ No newline at end of file diff --git a/metadef/register/graph_optimizer/graph_fusion/fusion_pattern.cc b/metadef/register/graph_optimizer/graph_fusion/fusion_pattern.cc new file mode 100644 index 00000000..9bd22630 --- /dev/null +++ b/metadef/register/graph_optimizer/graph_fusion/fusion_pattern.cc @@ -0,0 +1,218 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "graph/debug/ge_log.h" +#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" + +namespace fe { + +#define FE_PATTERN_ERROR_RETURN_IF(condition, ...) \ + do { \ + if (condition) { \ + SetError(); \ + GELOGW(__VA_ARGS__); \ + return *this; \ + } \ + } while (0) + +FusionPattern::FusionPattern(string name) : name_(name), output_(nullptr), has_error_(false) {} + +FusionPattern::~FusionPattern() {} + +/** + * @ingroup fe + * @brief set pattern name + */ +FusionPattern &FusionPattern::SetName(const string &name) { + name_ = name; + return *this; +} + +/** + * @ingroup fe + * @brief add Op description with unknown number of args + */ +FusionPattern &FusionPattern::AddOpDesc(const string &id, const initializer_list &types) { + return AddOpDesc(id, vector(types)); +} + +/** + * @ingroup fe + * @brief add Op description with vector + */ +FusionPattern &FusionPattern::AddOpDesc(const string &id, const vector &types) { + FE_PATTERN_ERROR_RETURN_IF(id.empty(), "ID cannot be empty."); + + FE_PATTERN_ERROR_RETURN_IF(GetOpDesc(id) != nullptr, "ID already exists. (id:%s)", id.c_str()); + + std::shared_ptr op(new (std::nothrow) OpDesc()); + FE_PATTERN_ERROR_RETURN_IF(op == nullptr, "new an object failed."); + + op->id = id; + op->types = types; + op->repeatable = false; + op->is_output = false; + ops_.push_back(op); + op_map_[id] = op; + + return *this; +} + +/** + * @ingroup fe + * @brief set input Ops with unknown number of args + */ +FusionPattern &FusionPattern::SetInputs(const string &id, const initializer_list &input_ids) { + return SetInputs(id, vector(input_ids)); +} + +/** + * @ingroup fe + * @brief set input Ops with vector + */ +FusionPattern &FusionPattern::SetInputs(const string &id, const vector &input_ids) { + FE_PATTERN_ERROR_RETURN_IF(id.empty(), "Id cannot be empty."); + std::shared_ptr op_desc = GetOpDesc(id); + FE_PATTERN_ERROR_RETURN_IF(op_desc == nullptr, "Id does not exist. (id:%s)", id.c_str()); + + op_desc->inputs.clear(); + + for (const string &input_id : input_ids) { + std::shared_ptr input_op_desc = GetOpDesc(input_id); + FE_PATTERN_ERROR_RETURN_IF(input_op_desc == nullptr, "Id does not exist. (id:%s)", input_id.c_str()); + op_desc->inputs.push_back(input_op_desc); + } + + return *this; +} + +/** + * @ingroup fe + * @brief set output Op + */ +FusionPattern &FusionPattern::SetOutput(const string &id) { + FE_PATTERN_ERROR_RETURN_IF(id.empty(), "Id cannot be empty."); + std::shared_ptr op_desc = GetOpDesc(id); + FE_PATTERN_ERROR_RETURN_IF(op_desc == nullptr, "Id does not exist. (id:%s)", id.c_str()); + + op_desc->is_output = true; + + return *this; +} + +/** + * @ingroup fe + * @brief build pattern and check if error exists + */ +bool FusionPattern::Build() { + if (has_error_) { + return false; + } + + // check whether output node already exists + for (const std::shared_ptr op : ops_) { + if (op->is_output) { + if (output_ != nullptr) { + SetError(); + GELOGW("Multiple outputs are not supported. (id:%s)", op->id.c_str()); + break; + } + output_ = op; + } + } + + if (output_ == nullptr) { + SetError(); + GELOGW("Output must be set value."); + } + + return !has_error_; +} + +/** + * @ingroup fe + * @brief get pattern name + */ +const string &FusionPattern::GetName() const { return name_; } +/** + * @ingroup fe + * @brief get the OpDesc of input Ops (const) + */ +const vector> *FusionPattern::GetInputs( + const std::shared_ptr op_desc) { + if (op_desc == nullptr) { + return nullptr; + } + return &(op_desc->inputs); +} + +/** + * @ingroup fe + * @brief get the OpDesc of output Op + */ +const std::shared_ptr FusionPattern::GetOutput() const { return output_; } + +/** + * @ingroup fe + * @brief print pattern + */ +void FusionPattern::Dump() const { + std::ostringstream oss; + oss << std::endl << "Pattern (" << name_ << "):" << std::endl; + for (const auto &op : ops_) { + oss << " " << op->id << ": {"; + for (const string &type : op->types) { + oss << type << ", "; + } + oss << "} {"; + for (const auto &input : op->inputs) { + oss << input->id << ", "; + } + oss << "}"; + + if (op->is_output) { + oss << " [output]"; + } + + oss << std::endl; + } + GELOGD("%s", oss.str().c_str()); +} + +/** + * @ingroup fe + * @brief get OpDesc based on ID, return nullptr if failed + */ +std::shared_ptr FusionPattern::GetOpDesc(const string &id) const { + auto it = op_map_.find(id); + if (it != op_map_.end()) { + return it->second; + } + return nullptr; +} + +/** + * @ingroup fe + * @brief record error + */ +void FusionPattern::SetError() { has_error_ = true; } +} diff --git a/metadef/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.cc b/metadef/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.cc new file mode 100644 index 00000000..585c36f8 --- /dev/null +++ b/metadef/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.cc @@ -0,0 +1,209 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" +#include +#include +#include +#include "graph/debug/ge_log.h" +#include "register/graph_optimizer/fusion_common/fusion_statistic_recorder.h" +#include "register/graph_optimizer/fusion_common/graph_pass_util.h" +#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" +#include "register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h" + +namespace fe { +GraphFusionPassBase::GraphFusionPassBase() { + pattern_fusion_base_pass_impl_ptr_ = std::make_shared(); +} + +GraphFusionPassBase::~GraphFusionPassBase() {} + +/** + * @ingroup fe + * @brief execute pass + */ +Status GraphFusionPassBase::Run(ge::ComputeGraph &graph) { + Mappings mappings; + bool is_patterns_ok = true; + // build Pattern + vector patterns; + pattern_fusion_base_pass_impl_ptr_->GetPatterns(patterns); + if (patterns.empty()) { + patterns = DefinePatterns(); + for (FusionPattern *pattern : patterns) { + if (pattern != nullptr) { + bool ok = pattern->Build(); + if (!ok) { + GELOGW("this pattern: %s build not success.", pattern->GetName().c_str()); + } + pattern->Dump(); + is_patterns_ok = is_patterns_ok && ok; + } + } + + pattern_fusion_base_pass_impl_ptr_->SetPatterns(patterns); + } + if (!is_patterns_ok) { + GELOGE(FAILED, "Patterns invalid."); + return FAILED; + } + + NodeMapInfoPtr node_map_info = nullptr; + if (GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph) == SUCCESS) { + node_map_info->run_count++; + } + // do matching and fusion for each pattern + bool final_changed = false; + for (const FusionPattern *pattern : patterns) { + if (pattern != nullptr) { + bool changed = false; + Status ret = RunOnePattern(graph, *pattern, changed); + if (ret != SUCCESS) { + GELOGW("run pattern %s not success, graph is not changed by it.", pattern->GetName().c_str()); + return ret; + } + final_changed = final_changed || changed; + } + } + return final_changed ? SUCCESS : NOT_CHANGED; +} + +/** + * @ingroup fe + * @brief do matching and fusion in graph based on the pattern + */ +Status GraphFusionPassBase::RunOnePattern(ge::ComputeGraph &graph, const FusionPattern &pattern, bool &changed) { + changed = false; + Mappings mappings; + int32_t effect_times = 0; + uint32_t graph_id = graph.GetGraphID(); + FusionInfo fusion_info(graph.GetSessionID(), to_string(graph_id), GetName(), static_cast(mappings.size()), + effect_times); + // match all patterns in graph, and save them to mappings + if (!MatchAll(graph, pattern, mappings)) { + GELOGD("GraphFusionPass[%s]: pattern=%s, matched_times=%zu, effected_times=%d.", GetName().c_str(), + pattern.GetName().c_str(), mappings.size(), effect_times); + return SUCCESS; + } + + GELOGD("This graph has been matched with pattern[%s]. The mappings are as follows.", pattern.GetName().c_str()); + + // print the results of matching + pattern_fusion_base_pass_impl_ptr_->DumpMappings(pattern, mappings); + NodeMapInfoPtr node_map_info = nullptr; + // get nodes by type from node + (void)GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph); + // do fusion for each mapping + for (Mapping &mapping : mappings) { + vector fus_nodes; + ge::NodePtr first_node = nullptr; + for (auto &item : mapping) { + if (!item.second.empty()) { + first_node = item.second[0]; + break; + } + } + + Status status = Fusion(graph, mapping, fus_nodes); + if (status != SUCCESS && status != NOT_CHANGED) { + GELOGE(status, "Fail to fuse the graph with pattern[%s].", pattern.GetName().c_str()); + return status; + } + + if (status == SUCCESS) { + effect_times++; + if (!fus_nodes.empty()) { + // add fusednode to node map info + for (ge::NodePtr &node : fus_nodes) { + GraphPassUtil::AddNodeFromOpTypeMap(node_map_info, node); + } + } + } + changed = changed || status == SUCCESS; + } + + // get match times and effect times + FusionStatisticRecorder &fusion_statistic_inst = FusionStatisticRecorder::Instance(); + fusion_info.SetMatchTimes(static_cast(mappings.size())); + fusion_info.SetEffectTimes(effect_times); + fusion_statistic_inst.UpdateGraphFusionMatchTimes(fusion_info); + fusion_statistic_inst.UpdateGraphFusionEffectTimes(fusion_info); + GELOGD("GraphId[%d], GraphFusionPass[%s]: pattern=%s, matched_times=%d, effected_times=%d.", graph_id, + GetName().c_str(), pattern.GetName().c_str(), static_cast(mappings.size()), effect_times); + return SUCCESS; +} + +/** + * @ingroup fe + * @brief match all nodes in graph according to pattern + */ +// match nodes in graph according to pattern, the algorithm is shown as +// following: +// 1. get output node from pattern +// 2. Search for candidate nodes in Graph (network Graph generated after +// parsing) according to Op Type and +// (optional), and add the candidate node to the list of candidates +// 3. For each Node in the candidate list, check whether the type and the number +// of precursors are consistent with the description of corresponding Op +// in pattern. If they are consistent, add the precursor Node to the +// candidate list, and add "PatternOp-GraphNode" to the mapping; otherwise, +// return an empty mapping +// 4. repeat step 3 until all the Ops in pattern are matched +// 5. if all the Ops in pattern are matched successfully, return the mapping of +// PatternOp and GraphNode +bool GraphFusionPassBase::MatchAll(ge::ComputeGraph &graph, const FusionPattern &pattern, Mappings &mappings) { + vector matched_output_nodes; + + // find all the output nodes of pattern in the graph based on Op type + std::shared_ptr output_op_desc = pattern.GetOutput(); + if (output_op_desc == nullptr) { + return false; + } + + if (!pattern_fusion_base_pass_impl_ptr_->GetMatchOutputNodes(graph, pattern, matched_output_nodes)) { + return false; + } + + // begin matching from every output node + for (ge::NodePtr &output_node : matched_output_nodes) { + Mapping mapping; + if (pattern_fusion_base_pass_impl_ptr_->MatchFromOutput(output_node, output_op_desc, mapping)) { + mappings.push_back(mapping); + } + } + // if matching is successful, return true; otherwise false + return !mappings.empty(); +} + +/** + * @ingroup fe + * @brief get an op from mapping according to ID + */ +ge::NodePtr GraphFusionPassBase::GetNodeFromMapping(const string &id, const Mapping &mapping) { + for (auto &item : mapping) { + std::shared_ptr op_desc = item.first; + if (op_desc != nullptr && op_desc->id == id) { + if (item.second.empty()) { + return nullptr; + } else { + return item.second[0]; + } + } + } + return nullptr; +} + +} // namespace fe diff --git a/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc b/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc new file mode 100644 index 00000000..700569e5 --- /dev/null +++ b/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc @@ -0,0 +1,364 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" +#include +#include +#include +#include +#include +#include "graph/debug/ge_log.h" +#include "graph/utils/graph_utils.h" +#include "register/graph_optimizer/fusion_common/fusion_statistic_recorder.h" +#include "register/graph_optimizer/fusion_common/graph_pass_util.h" +#include "register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h" + +namespace fe { +static const string STREAM_LABEL = "_stream_label"; +PatternFusionBasePass::PatternFusionBasePass() { + pattern_fusion_base_pass_impl_ptr_ = std::make_shared(); +} + +PatternFusionBasePass::~PatternFusionBasePass() {} + +Status PatternFusionBasePass::Run(ge::ComputeGraph &graph, OpsKernelInfoStorePtr ops_kernel_info_store_ptr) { + // save the opskernelstoreptr which will be uesd while checking op support + pattern_fusion_base_pass_impl_ptr_->SetOpsKernelInfoStore(ops_kernel_info_store_ptr); + + return Run(graph); +} +/** + * @ingroup fe + * @brief execute pass + */ +Status PatternFusionBasePass::Run(ge::ComputeGraph &graph) { + Mappings mappings; + bool is_patterns_ok = true; + // build Pattern + vector patterns; + pattern_fusion_base_pass_impl_ptr_->GetPatterns(patterns); + if (patterns.empty()) { + patterns = DefinePatterns(); + for (FusionPattern *pattern : patterns) { + if (pattern != nullptr) { + bool ok = pattern->Build(); + if (!ok) { + GELOGW("this pattern: %s build not success.", pattern->GetName().c_str()); + } + pattern->Dump(); + is_patterns_ok = is_patterns_ok && ok; + } + } + + pattern_fusion_base_pass_impl_ptr_->SetPatterns(patterns); + } + + if (!is_patterns_ok) { + return FAILED; + } + NodeMapInfoPtr node_map_info = nullptr; + if (GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph) == SUCCESS) { + node_map_info->run_count++; + } + // do matching and fusion for each pattern + bool final_changed = false; + for (const FusionPattern *pattern : patterns) { + if (pattern != nullptr) { + bool changed = false; + Status ret = RunOnePattern(graph, *pattern, changed); + if (ret != SUCCESS) { + GELOGW("run pattern %s not success, graph is not changed by it.", pattern->GetName().c_str()); + return ret; + } + + final_changed = final_changed || changed; + } + } + return final_changed ? SUCCESS : NOT_CHANGED; +} + +static bool CheckStreamLabel(vector &fused_nodes) { + string stream_label = ""; + for (auto &n : fused_nodes) { + string stream_label_tmp = ""; + if (!ge::AttrUtils::GetStr(n->GetOpDesc(), STREAM_LABEL, stream_label_tmp)) { + stream_label_tmp = "null"; + } + if (stream_label == "") { + stream_label = stream_label_tmp; + } else if (stream_label != "" && stream_label != stream_label_tmp) { + return false; + } + } + return true; +} + +static bool SetStreamLabelToFusedNodes(vector &fused_nodes, ge::NodePtr first_node) { + string stream_label = ""; + if (ge::AttrUtils::GetStr(first_node->GetOpDesc(), STREAM_LABEL, stream_label)) { + for (ge::NodePtr &node : fused_nodes) { + if (!ge::AttrUtils::SetStr(node->GetOpDesc(), STREAM_LABEL, stream_label)) { + GELOGW("newNode set _stream_label error, fusion failed."); + return false; + } + } + } + return true; +} +/** + * @ingroup fe + * @brief do matching and fusion in graph based on the pattern + */ +Status PatternFusionBasePass::RunOnePattern(ge::ComputeGraph &graph, const FusionPattern &pattern, bool &changed) { + changed = false; + Mappings mappings; + int32_t effect_times = 0; + uint32_t graph_id = graph.GetGraphID(); + FusionInfo fusion_info(graph.GetSessionID(), to_string(graph_id), GetName(), static_cast(mappings.size()), + effect_times); + origin_op_anchors_map_.clear(); + // match all patterns in graph, and save them to mappings + if (!MatchAll(graph, pattern, mappings)) { + GELOGD("GraphFusionPass[%s]: pattern=%s, matched_times=%zu, effected_times=%d.", GetName().c_str(), + pattern.GetName().c_str(), mappings.size(), effect_times); + return SUCCESS; + } + + GELOGD("This graph has been matched with pattern[%s]. The mappings are as follows.", pattern.GetName().c_str()); + + // print the results of matching + pattern_fusion_base_pass_impl_ptr_->DumpMappings(pattern, mappings); + NodeMapInfoPtr node_map_info = nullptr; + // get nodes by type from node + (void)GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph); + // do fusion for each mapping + for (Mapping &mapping : mappings) { + vector fus_nodes; + ge::NodePtr first_node = nullptr; + for (auto &item : mapping) { + if (!item.second.empty()) { + first_node = item.second[0]; + break; + } + } + + Status status = Fusion(graph, mapping, fus_nodes); + if (!SetStreamLabelToFusedNodes(fus_nodes, first_node)) { + return FAILED; + } + + if (status != SUCCESS && status != NOT_CHANGED) { + GELOGE(status, "Fail to fuse the graph with pattern[%s].", pattern.GetName().c_str()); + return status; + } + + if (status == SUCCESS) { + effect_times++; + std::vector original_nodes; + for (auto &item : mapping) { + if (!item.second.empty()) { + for (auto &node : item.second) { + original_nodes.push_back(node); + } + } + } + SetDataDumpAttr(original_nodes, fus_nodes); + if (!fus_nodes.empty()) { + // add fusednode to node map info + for (ge::NodePtr &node : fus_nodes) { + (void)GraphPassUtil::AddNodeFromOpTypeMap(node_map_info, node); + } + } + } + changed = (changed || status == SUCCESS); + } + + // get match times and effect times + FusionStatisticRecorder &fusion_statistic_inst = FusionStatisticRecorder::Instance(); + fusion_info.SetMatchTimes(static_cast(mappings.size())); + fusion_info.SetEffectTimes(effect_times); + fusion_statistic_inst.UpdateGraphFusionMatchTimes(fusion_info); + fusion_statistic_inst.UpdateGraphFusionEffectTimes(fusion_info); + GELOGD("GraphId[%d], GraphFusionPass[%s]: pattern=%s, matched_times=%zu, effected_times=%d.", graph_id, + GetName().c_str(), pattern.GetName().c_str(), mappings.size(), effect_times); + return SUCCESS; +} + +Status PatternFusionBasePass::SetDataDumpAttr(vector &original_nodes, vector &fus_nodes) { + for (auto &oriNode : original_nodes) { + auto itr = origin_op_anchors_map_.find(oriNode); + if (itr != origin_op_anchors_map_.end()) { + for (const auto &anchor_iter : itr->second) { + auto next_node_in_achor = anchor_iter.first; + auto fusion_node_out_data_anchor = next_node_in_achor->GetPeerOutAnchor(); + if (fusion_node_out_data_anchor == nullptr) { + GELOGW("fusionNodeOutDataAnchor is null"); + return FAILED; + } + + auto fusion_node = fusion_node_out_data_anchor->GetOwnerNode(); + if (fusion_node == nullptr) { + GELOGW("fusionNode is null"); + return FAILED; + } + + if (pattern_fusion_base_pass_impl_ptr_->IsNodesExist(fusion_node, fus_nodes)) { + auto origin_node_out_anchor = anchor_iter.second; + if (origin_node_out_anchor == nullptr) { + GELOGW("originNodeOutAnchor is null"); + return FAILED; + } + + auto origin_node = origin_node_out_anchor->GetOwnerNode(); + if (origin_node == nullptr) { + GELOGW("originNode is null"); + return FAILED; + } + + uint32_t origin_index = origin_node_out_anchor->GetIdx(); + uint32_t fusion_index = fusion_node_out_data_anchor->GetIdx(); + (void)GraphPassUtil::SetOutputDescAttr(origin_index, fusion_index, origin_node, fusion_node); + } + } + } + } + + for (auto &node : fus_nodes) { + GraphPassUtil::RecordOriginalNames(original_nodes, node); + } + if (fus_nodes.size() > 1) { + bool is_multi_op = true; + for (ge::NodePtr &node : fus_nodes) { + ge::AttrUtils::SetBool(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_IS_MULTIOP, is_multi_op); + } + } + + return SUCCESS; +} + +bool PatternFusionBasePass::CheckOpSupported(const ge::OpDescPtr &op_desc_ptr) { + return pattern_fusion_base_pass_impl_ptr_->CheckOpSupported(op_desc_ptr); +} + +/** + * @ingroup fe + * @brief match all nodes in graph according to pattern + */ +// match nodes in graph according to pattern, the algorithm is shown as +// following: +// 1. get output node from pattern +// 2. Search for candidate nodes in Graph (network Graph generated after +// parsing) according to Op Type and +// (optional), and add the candidate node to the list of candidates +// 3. For each Node in the candidate list, check whether the type and the number +// of precursors are consistent with the description of corresponding Op +// in pattern. If they are consistent, add the precursor Node to the +// candidate list, and add "PatternOp-GraphNode" to the mapping; otherwise, +// return an empty mapping +// 4. repeat step 3 until all the Ops in pattern are matched +// 5. if all the Ops in pattern are matched successfully, return the mapping of +// PatternOp and GraphNode +bool PatternFusionBasePass::MatchAll(ge::ComputeGraph &graph, const FusionPattern &pattern, Mappings &mappings) { + vector matched_output_nodes; + + // find all the output nodes of pattern in the graph based on Op type + std::shared_ptr output_op_desc = pattern.GetOutput(); + if (output_op_desc == nullptr) { + return false; + } + + if (!pattern_fusion_base_pass_impl_ptr_->GetMatchOutputNodes(graph, pattern, matched_output_nodes)) { + return false; + } + + // begin matching from every output node + for (ge::NodePtr &output_node : matched_output_nodes) { + Mapping mapping; + if (pattern_fusion_base_pass_impl_ptr_->MatchFromOutput(output_node, output_op_desc, mapping)) { + // node attr _stream_label must be equal + auto fusion_nodes = GetNodesFromMapping(mapping); + if (!CheckStreamLabel(fusion_nodes)) { + return false; + } + mappings.push_back(mapping); + + // Record output nodes anchor vs succeed node anchor map + RecordOutputAnchorMap(output_node); + } + } + // if matching is successful, return true; otherwise false + return !mappings.empty(); +} + +/* + * @brief: get all fusion nodes matched + * @param [in] mapping: fusion node group + * @return std::vector: all fusion nodes list + */ +vector PatternFusionBasePass::GetNodesFromMapping(const Mapping &mapping) { + std::vector nodes; + for (auto &item : mapping) { + for (const auto &node : item.second) { + nodes.push_back(node); + } + } + return nodes; +} + +/** + * @ingroup fe + * @brief get an op from mapping according to ID + */ +ge::NodePtr PatternFusionBasePass::GetNodeFromMapping(const string &id, const Mapping &mapping) { + for (auto &item : mapping) { + std::shared_ptr op_desc = item.first; + if (op_desc != nullptr && op_desc->id == id) { + if (item.second.empty()) { + return nullptr; + } else { + return item.second[0]; + } + } + } + return nullptr; +} + +void PatternFusionBasePass::RecordOutputAnchorMap(ge::NodePtr output_node) { + for (const auto &output_anchor : output_node->GetAllOutDataAnchors()) { + if (output_anchor == nullptr) { + continue; + } + + for (const auto &peer_in_anchor : output_anchor->GetPeerInDataAnchors()) { + if (peer_in_anchor == nullptr) { + continue; + } + + // Record anchor map + auto itr = origin_op_anchors_map_.find(output_node); + if (itr == origin_op_anchors_map_.end()) { + std::map anchorMap; + anchorMap[peer_in_anchor] = output_anchor; + origin_op_anchors_map_.emplace(make_pair(output_node, anchorMap)); + } else { + itr->second.emplace(make_pair(peer_in_anchor, output_anchor)); + } + } + } +} + +void PatternFusionBasePass::ClearOutputAnchorMap() { origin_op_anchors_map_.clear(); } +} // namespace fe diff --git a/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc b/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc new file mode 100644 index 00000000..ee9a2af2 --- /dev/null +++ b/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc @@ -0,0 +1,273 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h" +#include "graph/debug/ge_log.h" +#include "register/graph_optimizer/fusion_common/graph_pass_util.h" + +namespace fe { +PatternFusionBasePassImpl::PatternFusionBasePassImpl() {} + +PatternFusionBasePassImpl::~PatternFusionBasePassImpl() { + for (auto pattern : patterns_) { + if (pattern != nullptr) { + delete pattern; + pattern = nullptr; + } + } +} + +void PatternFusionBasePassImpl::GetPatterns(vector &patterns) { patterns = patterns_; } + +void PatternFusionBasePassImpl::SetPatterns(vector &patterns) { patterns_ = patterns; } + +void PatternFusionBasePassImpl::SetOpsKernelInfoStore(OpsKernelInfoStorePtr ops_kernel_info_store_ptr) { + ops_kernel_info_store_ptr_ = ops_kernel_info_store_ptr; +} + +bool PatternFusionBasePassImpl::CheckOpSupported(const ge::OpDescPtr &op_desc_ptr) { + std::string un_supported_reason; + + if (ops_kernel_info_store_ptr_ == nullptr) { + un_supported_reason = "opsKernelInfoStorePtr in PatternFusionBasePass is nullptr."; + return false; + } + + bool result; + result = ops_kernel_info_store_ptr_->CheckSupported(op_desc_ptr, un_supported_reason); + return result; +} + +bool PatternFusionBasePassImpl::IsNodesExist(ge::NodePtr current_node, std::vector &nodes) { + return find(nodes.begin(), nodes.end(), current_node) != nodes.end(); +} + +bool PatternFusionBasePassImpl::IsMatched(std::shared_ptr op_desc, const ge::NodePtr node, + const Mapping &mapping) { + if (op_desc == nullptr || node == nullptr) { + GELOGD("opDesc or node could not be null"); + return false; + } + + const auto iter = mapping.find(op_desc); + + // check op_desc does not exist in mapping + return iter != mapping.end() && (find(iter->second.begin(), iter->second.end(), node) != iter->second.end()); +} + +void PatternFusionBasePassImpl::DumpMappings(const FusionPattern &pattern, const Mappings &mappings) { + std::ostringstream oss; + oss << std::endl << "Mappings of pattern "; + oss << pattern.GetName() << ":" << std::endl; + for (size_t i = 0; i < mappings.size(); i++) { + const Mapping &mapping = mappings[i]; + oss << " Mapping " << (i + 1) << "/" << mappings.size() << ":" << std::endl; + for (const auto &item : mapping) { + std::shared_ptr op_desc = item.first; + const ge::NodePtr node = item.second[0]; + if (op_desc != nullptr && node != nullptr) { + oss << " " << op_desc->id << " -> " << node->GetName() << std::endl; + } + } + } + GELOGD("%s", oss.str().c_str()); +} + +bool PatternFusionBasePassImpl::IsOpTypeExist(const string &type, const vector &types) { + return find(types.begin(), types.end(), type) != types.end(); +} + +bool PatternFusionBasePassImpl::MatchFromOutput(ge::NodePtr output_node, std::shared_ptr output_op_desc, + Mapping &mapping) { + if (output_node == nullptr) { + GELOGW("outputNode is null, pattern matching failed"); + return false; + } + + if (output_op_desc == nullptr) { + GELOGW("outputOpDesc is null, pattern matching failed"); + return false; + } + + vector candidate_nodes = {output_node}; + vector> candidate_op_descs = {output_op_desc}; + + // store the nodes matched + mapping[output_op_desc].push_back(output_node); + + // match candidate node one by one + while (!candidate_nodes.empty() && !candidate_op_descs.empty()) { + // get the first candidate node + bool result = MatchFromOutput(candidate_nodes, candidate_op_descs, mapping); + if (!result) { + return false; + } + + // current op is matched successfully, thus remove it from candidate list + candidate_nodes.erase(candidate_nodes.begin()); + candidate_op_descs.erase(candidate_op_descs.begin()); + + // the sizes of candidate_nodes and candidate_op_descs should always keep the same + if (candidate_nodes.size() != candidate_op_descs.size()) { + GELOGW("candidateNodes size does not equal to candidate_op_descs size, pattern matching failed."); + return false; + } + } + + // if candidate_nodes(or candidate_op_descs) is empty, the matching is done + // successfully + return candidate_op_descs.empty(); +} + +bool PatternFusionBasePassImpl::MatchFromOutput(vector &candidate_nodes, + vector> &candidate_op_descs, Mapping &mapping) { + if (candidate_nodes.empty() || candidate_op_descs.empty()) { + GELOGW("candidateNodes or candidate_op_descs is empty, pattern matching failed."); + return false; + } + ge::NodePtr node = candidate_nodes.front(); + std::shared_ptr op_desc = candidate_op_descs.front(); + string op_id = op_desc->id; + // add the input nodes into candidate list + const vector> *inputs_desc = FusionPattern::GetInputs(op_desc); + if (inputs_desc == nullptr) { + GELOGW("Op[%s]: the inputs_desc is null, pattern matching failed.", op_id.c_str()); + return false; + } + + if (inputs_desc->empty()) { + return true; + } + + if (node->GetInDataNodes().empty()) { + GELOGW("Op[%s]: in data node or inputs_desc is empty, pattern matching failed.", op_id.c_str()); + return false; + } + + // set flag for edge using + const std::unique_ptr usage_flags(new (std::nothrow) bool[inputs_desc->size()]{}); + + // order the input edges, and the order should also be the rule of pattern + // setting + std::vector in_anchors; + GetInDataAnchors(node, in_anchors); + if (in_anchors.empty()) { + GELOGW("Op[%s]: in data anchors are empty, pattern matching failed.", op_id.c_str()); + return false; + } + + std::sort(in_anchors.begin(), in_anchors.end(), + [](ge::InDataAnchorPtr a, ge::InDataAnchorPtr b) { return a->GetIdx() < b->GetIdx(); }); + + for (const auto &in_anchor : in_anchors) { + ge::NodePtr input_node = in_anchor->GetPeerOutAnchor()->GetOwnerNode(); + for (uint32_t j = 0; j < inputs_desc->size(); j++) { + std::shared_ptr input_desc = inputs_desc->at(j); + if (input_desc == nullptr) { + GELOGW("Op[%s]: input_desc is null, pattern matching failed.", op_id.c_str()); + return false; + } + + bool condi = + (IsOpTypeExist(ge::NodeUtils::GetNodeType(*input_node), input_desc->types) || input_desc->types.empty()) && + (!usage_flags[j] || input_desc->repeatable); + if (!condi) { + continue; + } + // some nodes might be the input of multiple nodes, we use + // IsMatched() to avoid repeat + if (!IsMatched(input_desc, input_node, mapping)) { + candidate_nodes.push_back(input_node); + candidate_op_descs.push_back(input_desc); + // store the matched node + mapping[input_desc].push_back(input_node); + } + usage_flags[j] = true; + break; + } + } + + // return false if not all edges are matched + if (!MatchAllEdges(inputs_desc->size(), usage_flags)) { + GELOGW("Op[%s]: not all inputs are matched, pattern matching failed.", op_id.c_str()); + return false; + } + return true; +} + +bool PatternFusionBasePassImpl::MatchAllEdges(const size_t &input_size, const std::unique_ptr &usage_flags) { + for (size_t i = 0; i != input_size; i++) { + if (!usage_flags[i]) { + return false; + } + } + return true; +} + +void PatternFusionBasePassImpl::GetInDataAnchors(const ge::NodePtr &node, + std::vector &in_anchor_vec) { + for (auto in_anchor : node->GetAllInDataAnchors()) { + if (in_anchor == nullptr || in_anchor->GetPeerOutAnchor() == nullptr || + in_anchor->GetPeerOutAnchor()->GetOwnerNode() == nullptr) { + continue; + } + in_anchor_vec.push_back(in_anchor); + } +} + +bool PatternFusionBasePassImpl::GetMatchOutputNodes(ge::ComputeGraph &graph, const FusionPattern &pattern, + vector &matched_output_nodes) { + std::shared_ptr output_op_desc = pattern.GetOutput(); + if (output_op_desc == nullptr) { + GELOGW("outputOpDesc is null, pattern matching failed"); + return false; + } + + NodeMapInfoPtr node_map_info = nullptr; + // get nodes by type from node + if (GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph) == SUCCESS) { + for (auto &OutOpType : output_op_desc->types) { + auto iter = node_map_info->node_type_map->find(OutOpType); + if (iter != node_map_info->node_type_map->end()) { + for (auto &node_ptr : iter->second) { + if (node_ptr->GetInDataNodes().empty() && node_ptr->GetOutAllNodes().empty()) { + continue; + } + if (ge::NodeUtils::GetNodeType(*node_ptr) == OutOpType) { + matched_output_nodes.push_back(node_ptr); + } + } + } + } + } else { // for each graph to find type + for (ge::NodePtr &n : graph.GetDirectNode()) { + if (n == nullptr) { + GELOGW("node from graph is null, pattern matching failed"); + return false; + } + + if (IsOpTypeExist(ge::NodeUtils::GetNodeType(*n), output_op_desc->types)) { + matched_output_nodes.push_back(n); + } + } + } + + if (matched_output_nodes.empty()) { + return false; + } + return true; +} +} diff --git a/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h b/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h new file mode 100644 index 00000000..e3891838 --- /dev/null +++ b/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h @@ -0,0 +1,94 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef FE_PATTERN_FUSION_BASE_PASS_IMPL_H +#define FE_PATTERN_FUSION_BASE_PASS_IMPL_H + +#include +#include +#include +#include +#include +#include + +#include "common/opskernel/ops_kernel_info_store.h" +#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" + +using std::initializer_list; +using std::map; +using std::string; +using std::vector; + +using namespace std; + +namespace fe { + +using OpDesc = FusionPattern::OpDesc; +using Mapping = map, vector>; +using Mappings = std::vector; +using OpsKernelInfoStorePtr = std::shared_ptr; + +/** Base pattern impl + * @ingroup FUSION_PASS_GROUP + * @note New virtual methods should be append at the end of this class + */ +class PatternFusionBasePassImpl { + public: + PatternFusionBasePassImpl(); + + virtual ~PatternFusionBasePassImpl(); + + void GetPatterns(vector &patterns); + + void SetPatterns(vector &patterns); + + void SetOpsKernelInfoStore(OpsKernelInfoStorePtr ops_kernel_info_store_ptr); + + PatternFusionBasePassImpl &operator=(const PatternFusionBasePassImpl &) = delete; + + bool CheckOpSupported(const ge::OpDescPtr &op_desc_ptr); + + bool IsNodesExist(ge::NodePtr current_node, std::vector &nodes); + + bool IsMatched(std::shared_ptr op_desc, const ge::NodePtr node, const Mapping &mapping); + + void DumpMappings(const FusionPattern &pattern, const Mappings &mappings); + + bool IsOpTypeExist(const string &type, const vector &types); + + bool MatchFromOutput(ge::NodePtr output_node, std::shared_ptr output_op_desc, Mapping &mapping); + + std::string GetNodeType(ge::NodePtr node); + + bool GetMatchOutputNodes(ge::ComputeGraph &graph, const FusionPattern &pattern, + vector &matched_output_nodes); + + private: + vector patterns_; + + OpsKernelInfoStorePtr ops_kernel_info_store_ptr_; + + bool MatchFromOutput(vector &candidate_nodes, vector> &candidate_op_descs, + Mapping &mapping); + + bool MatchAllEdges(const size_t &input_size, const std::unique_ptr &usage_flags); + + void GetInDataAnchors(const ge::NodePtr &node, std::vector &in_anchor_vec); +}; + +} // namespace fe + +#endif // FE_PATTERN_FUSION_BASE_PASS_H diff --git a/ge/ge_runtime/proto/task.pb.h b/metadef/register/host_cpu_context.cc similarity index 75% rename from ge/ge_runtime/proto/task.pb.h rename to metadef/register/host_cpu_context.cc index 490289ac..e10dc792 100644 --- a/ge/ge_runtime/proto/task.pb.h +++ b/metadef/register/host_cpu_context.cc @@ -14,14 +14,12 @@ * limitations under the License. */ -// Generated by the protocol buffer compiler. DO NOT EDIT! -// source: task.proto +#include "register/host_cpu_context.h" -#ifndef STUB_TASK_PROTO_H -#define STUB_TASK_PROTO_H - -namespace domi { -class TaskDef; -} - -#endif // STUB_TASK_PROTO_H +namespace ge { +class HostCpuContext::Impl { + public: + Impl() = default; + ~Impl() = default; +}; +} // namespace ge \ No newline at end of file diff --git a/metadef/register/infer_data_slice_registry.cc b/metadef/register/infer_data_slice_registry.cc new file mode 100644 index 00000000..06423330 --- /dev/null +++ b/metadef/register/infer_data_slice_registry.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/operator_factory_impl.h" +#include "register/infer_data_slice_registry.h" + +namespace ge { +InferDataSliceFuncRegister::InferDataSliceFuncRegister(const char *operator_type, + const InferDataSliceFunc &infer_data_slice_func) { + (void)OperatorFactoryImpl::RegisterInferDataSliceFunc(operator_type, infer_data_slice_func); +} +} // namespace ge diff --git a/metadef/register/module.mk b/metadef/register/module.mk new file mode 100644 index 00000000..e853ce28 --- /dev/null +++ b/metadef/register/module.mk @@ -0,0 +1,213 @@ +LOCAL_PATH := $(call my-dir) + + +local_lib_src_files := register.cpp \ + ops_kernel_builder_registry.cc \ + graph_optimizer/graph_fusion/graph_fusion_pass_base.cc \ + graph_optimizer/graph_fusion/fusion_pass_registry.cc \ + graph_optimizer/graph_fusion/fusion_pattern.cc \ + graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc \ + graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc \ + graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h \ + graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc \ + graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc \ + graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc \ + graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc \ + register_format_transfer.cc \ + op_kernel_registry.cpp \ + auto_mapping_util.cpp \ + host_cpu_context.cc \ + tensor_assign.cpp \ + infer_data_slice_registry.cc \ + scope/scope_graph.cc \ + scope/scope_pass.cc \ + scope/scope_pattern.cc \ + scope/scope_util.cc \ + scope/scope_pass_registry.cc \ + ./proto/tensorflow/attr_value.proto \ + ./proto/tensorflow/function.proto \ + ./proto/tensorflow/graph.proto \ + ./proto/tensorflow/node_def.proto \ + ./proto/tensorflow/op_def.proto \ + ./proto/tensorflow/resource_handle.proto \ + ./proto/tensorflow/tensor.proto \ + ./proto/tensorflow/tensor_shape.proto \ + ./proto/tensorflow/types.proto \ + ./proto/tensorflow/versions.proto \ + ./proto/task.proto \ + ./proto/om.proto \ + +local_lib_inc_path := \ + inc \ + metadef/inc \ + graphengine/inc \ + inc/external \ + metadef/inc/external \ + graphengine/inc/external \ + metadef/inc/external/graph \ + metadef/inc/graph \ + metadef/inc/common \ + graphengine/inc/framework \ + metadef \ + metadef/graph \ + third_party/protobuf/include \ + libc_sec/include \ + third_party/json/include \ + +tiling_src_files := op_tiling.cpp \ + op_tiling_registry.cpp \ + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := libop_tiling_o2 + +LOCAL_CFLAGS += -std=c++11 -O2 -Wno-deprecated-declarations +LOCAL_LDFLAGS := + +LOCAL_STATIC_LIBRARIES := +LOCAL_SHARED_LIBRARIES := + +LOCAL_SRC_FILES := $(tiling_src_files) + +generated_sources_dir := $(call local-generated-sources-dir) +LOCAL_C_INCLUDES := $(local_lib_inc_path) + +include ${BUILD_HOST_STATIC_LIBRARY} + + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := libregister + +LOCAL_CFLAGS += -std=c++11 -Dgoogle=ascend_private -Wno-deprecated-declarations +LOCAL_LDFLAGS := + +LOCAL_WHOLE_STATIC_LIBRARIES := libop_tiling_o2 \ + +LOCAL_SHARED_LIBRARIES := libascend_protobuf \ + libc_sec \ + libslog \ + libgraph \ + +LOCAL_SRC_FILES := $(local_lib_src_files) + +generated_sources_dir := $(call local-generated-sources-dir) +LOCAL_EXPORT_C_INCLUDE_DIRS := $(generated_sources_dir)/proto/$(LOCAL_PATH) +LOCAL_C_INCLUDES := $(local_lib_inc_path) +LOCAL_C_INCLUDES += LOCAL_EXPORT_C_INCLUDE_DIRS + +include ${BUILD_HOST_SHARED_LIBRARY} + + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := libop_tiling_o2 + +LOCAL_CFLAGS += -std=c++11 -O2 -Wno-deprecated-declarations +LOCAL_LDFLAGS := + +LOCAL_STATIC_LIBRARIES := +LOCAL_SHARED_LIBRARIES := + +LOCAL_SRC_FILES := $(tiling_src_files) + +generated_sources_dir := $(call local-generated-sources-dir) +LOCAL_C_INCLUDES := $(local_lib_inc_path) +include ${BUILD_STATIC_LIBRARY} + + +include $(CLEAR_VARS) +LOCAL_MODULE := libregister + +LOCAL_CFLAGS += -std=c++11 -Dgoogle=ascend_private -Wno-deprecated-declarations +LOCAL_LDFLAGS := + +LOCAL_WHOLE_STATIC_LIBRARIES := libop_tiling_o2 + +LOCAL_STATIC_LIBRARIES := +LOCAL_SHARED_LIBRARIES := libascend_protobuf \ + libc_sec \ + libslog \ + libgraph \ + +LOCAL_SRC_FILES := $(local_lib_src_files) + +generated_sources_dir := $(call local-generated-sources-dir) +LOCAL_EXPORT_C_INCLUDE_DIRS := $(generated_sources_dir)/proto/$(LOCAL_PATH) +LOCAL_C_INCLUDES := $(local_lib_inc_path) +LOCAL_C_INCLUDES += LOCAL_EXPORT_C_INCLUDE_DIRS + +include ${BUILD_SHARED_LIBRARY} + +#compiler static libregister for host +include $(CLEAR_VARS) +LOCAL_MODULE := libregister + +LOCAL_CFLAGS += -std=c++11 -Dgoogle=ascend_private -Wno-deprecated-declarations +LOCAL_LDFLAGS := + +LOCAL_STATIC_LIBRARIES := \ + libgraph \ + libascend_protobuf \ + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libslog \ + +LOCAL_SRC_FILES := $(local_lib_src_files) $(tiling_src_files) + +generated_sources_dir := $(call local-generated-sources-dir) +LOCAL_EXPORT_C_INCLUDE_DIRS := $(generated_sources_dir)/proto/$(LOCAL_PATH) +LOCAL_C_INCLUDES := $(local_lib_inc_path) +LOCAL_C_INCLUDES += LOCAL_EXPORT_C_INCLUDE_DIRS + +LOCAL_UNINSTALLABLE_MODULE := false +include ${BUILD_HOST_STATIC_LIBRARY} + + +#compiler static libregister for device +include $(CLEAR_VARS) +LOCAL_MODULE := libregister + +LOCAL_CFLAGS += -std=c++11 -Dgoogle=ascend_private -Wno-deprecated-declarations +LOCAL_LDFLAGS := + +LOCAL_STATIC_LIBRARIES := \ + libgraph \ + libascend_protobuf \ + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libslog \ + +LOCAL_SRC_FILES := $(local_lib_src_files) $(tiling_src_files) + +generated_sources_dir := $(call local-generated-sources-dir) +LOCAL_EXPORT_C_INCLUDE_DIRS := $(generated_sources_dir)/proto/$(LOCAL_PATH) +LOCAL_C_INCLUDES := $(local_lib_inc_path) +LOCAL_C_INCLUDES += LOCAL_EXPORT_C_INCLUDE_DIRS + +LOCAL_UNINSTALLABLE_MODULE := false +include ${BUILD_STATIC_LIBRARY} + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := libregister + +LOCAL_CFLAGS += -std=c++11 -Dgoogle=ascend_private -Wno-deprecated-declarations +LOCAL_LDFLAGS := + +LOCAL_STATIC_LIBRARIES := +LOCAL_SHARED_LIBRARIES := libascend_protobuf \ + libc_sec \ + libslog \ + libgraph \ + +LOCAL_SRC_FILES := $(local_lib_src_files) $(tiling_src_files) + +generated_sources_dir := $(call local-generated-sources-dir) +LOCAL_EXPORT_C_INCLUDE_DIRS := $(generated_sources_dir)/proto/$(LOCAL_PATH) +LOCAL_C_INCLUDES := $(local_lib_inc_path) +LOCAL_C_INCLUDES += LOCAL_EXPORT_C_INCLUDE_DIRS + +include ${BUILD_LLT_SHARED_LIBRARY} diff --git a/metadef/register/op_kernel_registry.cpp b/metadef/register/op_kernel_registry.cpp new file mode 100644 index 00000000..f8f8fb0d --- /dev/null +++ b/metadef/register/op_kernel_registry.cpp @@ -0,0 +1,78 @@ +#include "register/op_kernel_registry.h" +#include +#include +#include "graph/debug/ge_log.h" + +namespace ge { +class OpKernelRegistry::OpKernelRegistryImpl { + public: + void RegisterHostCpuOp(const std::string &op_type, OpKernelRegistry::CreateFn create_fn) { + std::lock_guard lock(mu_); + create_fns_[op_type] = create_fn; + } + + OpKernelRegistry::CreateFn GetCreateFn(const std::string &op_type) { + std::lock_guard lock(mu_); + auto it = create_fns_.find(op_type); + if (it == create_fns_.end()) { + return nullptr; + } + + return it->second; + } + + private: + std::mutex mu_; + std::map create_fns_; +}; + +OpKernelRegistry::OpKernelRegistry() { + impl_ = std::unique_ptr(new(std::nothrow) OpKernelRegistryImpl); +} + +OpKernelRegistry::~OpKernelRegistry() { +} + +bool OpKernelRegistry::IsRegistered(const std::string &op_type) { + if (impl_ == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to invoke IsRegistered %s, OpKernelRegistry is not properly initialized", + op_type.c_str()); + return false; + } + + return impl_->GetCreateFn(op_type) != nullptr; +} + +void OpKernelRegistry::RegisterHostCpuOp(const std::string &op_type, CreateFn create_fn) { + if (impl_ == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to register %s, OpKernelRegistry is not properly initialized", op_type.c_str()); + return; + } + + impl_->RegisterHostCpuOp(op_type, create_fn); +} +std::unique_ptr OpKernelRegistry::CreateHostCpuOp(const std::string &op_type) { + if (impl_ == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to create op for %s, OpKernelRegistry is not properly initialized", + op_type.c_str()); + return nullptr; + } + + auto create_fn = impl_->GetCreateFn(op_type); + if (create_fn == nullptr) { + GELOGD("Host Cpu op is not registered. op type = %s", op_type.c_str()); + return nullptr; + } + + return std::unique_ptr(create_fn()); +} + +HostCpuOpRegistrar::HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()) { + if (op_type == nullptr) { + GELOGE(PARAM_INVALID, "Failed to register host cpu op, op type is null"); + return; + } + + OpKernelRegistry::GetInstance().RegisterHostCpuOp(op_type, create_fn); +} +} // namespace ge \ No newline at end of file diff --git a/metadef/register/op_tiling.cpp b/metadef/register/op_tiling.cpp new file mode 100644 index 00000000..38fc515f --- /dev/null +++ b/metadef/register/op_tiling.cpp @@ -0,0 +1,506 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/op_tiling.h" + +#include +#include +#include +#include "securec.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_log.h" +#include "graph/debug/ge_util.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/utils/tensor_utils.h" +#include + +#define LOG_ENABLED(loglvl) CheckLogLevel(GE_MODULE_NAME, loglvl) + +namespace optiling { + +const char *COMPILE_INFO_JSON = "compile_info_json"; +const char *COMPILE_INFO_KEY = "compile_info_key"; + +const std::map DATATYPE_STRING_MAP{{ge::DT_FLOAT, "float32"}, + {ge::DT_FLOAT16, "float16"}, + {ge::DT_INT8, "int8"}, + {ge::DT_INT16, "int16"}, + {ge::DT_INT32, "int32"}, + {ge::DT_INT64, "int64"}, + {ge::DT_UINT8, "uint8"}, + {ge::DT_UINT16, "uint16"}, + {ge::DT_UINT32, "uint32"}, + {ge::DT_UINT64, "uint64"}, + {ge::DT_BOOL, "bool"}, + {ge::DT_DOUBLE, "double"}, + {ge::DT_DUAL, "dual"}, + {ge::DT_DUAL_SUB_INT8, "dual_sub_int8"}, + {ge::DT_DUAL_SUB_UINT8, "dual_sub_uint8"}}; + +bool FeedTeOpTensorArg(ge::OpDesc::Vistor &tensor_desc, std::vector &tensor_arg) { + for (auto &desc : tensor_desc) { + TeOpTensorArg arg_tensor; + TeOpTensor tensor; + arg_tensor.arg_type = TA_SINGLE; + tensor.shape = desc->GetShape().GetDims(); + tensor.ori_shape = desc->GetOriginShape().GetDims(); + + tensor.format = ge::TypeUtils::FormatToSerialString(desc->GetFormat()); + + tensor.ori_format = ge::TypeUtils::FormatToSerialString(desc->GetOriginFormat()); + + ge::DataType dtype = desc->GetDataType(); + auto dataTypeIter = DATATYPE_STRING_MAP.find(dtype); + if (dataTypeIter == DATATYPE_STRING_MAP.end()) { + GE_LOGE("datatype error %d", static_cast(dtype)); + return false; + } + tensor.dtype = dataTypeIter->second; + if (LOG_ENABLED(DLOG_INFO)) { + std::stringstream shapestr; + shapestr << "shape:["; + for (auto &i : tensor.shape) { + shapestr << i << ","; + } + shapestr << "], ori_shape:["; + for (auto &i : tensor.ori_shape) { + shapestr << i << ","; + } + shapestr << "], format:" << tensor.format; + shapestr << ", ori_format:" << tensor.ori_format; + shapestr << ", dtype: " << tensor.dtype; + GELOGI("calling optiling shape info: %s", shapestr.str().c_str()); + } + + arg_tensor.tensor.emplace_back(tensor); + tensor_arg.emplace_back(arg_tensor); + } + return true; +} + +void FeedTeOpConstTensor(const ge::Node &node, const ge::OpDescPtr &op_desc, + std::map &const_inputs) { + ge::Operator op = ge::OpDescUtils::CreateOperatorFromNode(node.shared_from_this()); + std::vector inferDepends = op_desc->GetOpInferDepends(); + + for (auto &depend : inferDepends) { + ge::Tensor data; + ge::graphStatus rc = op.GetInputConstData(depend.c_str(), data); + GELOGI("GetInputConstData: %s, %d", depend.c_str(), rc); + if (rc != ge::GRAPH_SUCCESS) { + continue; + } + + const uint8_t *pbuf = data.GetData(); + size_t buflen = data.GetSize(); + + GELOGI("Const input tensor data: %s, %p %zu", depend.c_str(), pbuf, buflen); + const_inputs.emplace(depend, TeConstTensorData{pbuf, buflen, data}); + } +} + +bool GetCompileInfo(const ge::OpDescPtr &op_desc, const char *op_type, const char *op_name, + OpCompileInfo &op_compile_info) { + bool bres = ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_KEY, op_compile_info.key); + if (!bres) { + GE_LOGE("Can not find the attribute %s. op_type:%s, op_name:%s", COMPILE_INFO_KEY, op_type, op_name); + return false; + } + + bres = ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_JSON, op_compile_info.str); + if (!bres) { + GE_LOGE("Can not find the attribute %s. op_type:%s, op_name:%s", COMPILE_INFO_JSON, op_type, op_name); + return false; + } + return true; +} + +void ParseShapeDesc(const nlohmann::json &shape, std::vector &tensors) { + TeOpTensor tensor; + if (shape.contains("shape")) { + tensor.shape = shape["shape"].get>(); + } + if (shape.contains("ori_shape")) { + tensor.ori_shape = shape["ori_shape"].get>(); + } + if (shape.contains("format")) { + tensor.format = shape["format"].get(); + } + if (shape.contains("ori_format")) { + tensor.ori_format = shape["ori_format"].get(); + } + if (shape.contains("dtype")) { + tensor.dtype = shape["dtype"].get(); + } + tensors.emplace_back(tensor); +} + +void ParseShapeDescList(const nlohmann::json &shape_list, std::vector &op_args) { + for (const auto &elem : shape_list) { + TeOpTensorArg tensor_arg; + tensor_arg.arg_type = TA_NONE; + + if (elem.is_array()) { + tensor_arg.arg_type = TA_LIST; + for (const auto &shape : elem) { + ParseShapeDesc(shape, tensor_arg.tensor); + } + } else { + tensor_arg.arg_type = TA_SINGLE; + ParseShapeDesc(elem, tensor_arg.tensor); + } + op_args.emplace_back(tensor_arg); + } +} + +template +void GetConstDataPointer(const nlohmann::json &json_array, std::vector &const_value) { + std::vector value = json_array.get>(); + uint8_t *pv_begin = reinterpret_cast(value.data()); + uint8_t *pv_end = pv_begin + value.size() * sizeof(T); + const_value = std::move(std::vector(pv_begin, pv_end)); +} + +bool CopyConstData(const std::string &dtype, const nlohmann::json &json_array, std::vector &value) { + if (dtype == "int8") { + GetConstDataPointer(json_array, value); + } else if (dtype == "uint8") { + GetConstDataPointer(json_array, value); + } else if (dtype == "int16") { + GetConstDataPointer(json_array, value); + } else if (dtype == "uint16") { + GetConstDataPointer(json_array, value); + } else if (dtype == "int32") { + GetConstDataPointer(json_array, value); + } else if (dtype == "uint32") { + GetConstDataPointer(json_array, value); + } else if (dtype == "int64") { + GetConstDataPointer(json_array, value); + } else if (dtype == "uint64") { + GetConstDataPointer(json_array, value); + } else if (dtype == "float32") { + GetConstDataPointer(json_array, value); + } else if (dtype == "double") { + GetConstDataPointer(json_array, value); + } else { + GE_LOGE("Unknown dtype: %s", dtype.c_str()); + return false; + } + return true; +} + +void ParseConstShapeDesc(const nlohmann::json &shape_json, std::map &const_tensors, + std::map> &const_values) { + std::vector shape; + std::string format_str; + std::string dtype_str; + + if (!shape_json.contains("const_value")) { + GELOGI("Not const tenosr"); + return; + } + if (!shape_json.contains("name")) { + GE_LOGE("const tensor has no name"); + return; + } + std::string name = shape_json["name"]; + + if (shape_json.contains("shape")) { + shape = shape_json["shape"].get>(); + } + if (shape_json.contains("format")) { + format_str = shape_json["format"].get(); + } + if (shape_json.contains("dtype")) { + dtype_str = shape_json["dtype"].get(); + } + + std::vector value; + bool bres = CopyConstData(dtype_str, shape_json["const_value"], value); + if (!bres) { + GE_LOGE("CopyConstData faild. buffer is null"); + return; + } + auto res = const_values.emplace(name, std::move(value)); + if (res.first == const_values.end()) { + return; // CodeDEX complains 'CHECK_CONTAINER_EMPTY' + } + + ge::Shape ge_shape(shape); + std::transform(dtype_str.begin(), dtype_str.end(), dtype_str.begin(), ::toupper); + dtype_str = "DT_" + dtype_str; + ge::DataType ge_dtype = ge::TypeUtils::SerialStringToDataType(dtype_str); + std::transform(format_str.begin(), format_str.end(), format_str.begin(), ::toupper); + ge::Format ge_format = ge::TypeUtils::SerialStringToFormat(format_str); + ge::Tensor const_tensor(ge::TensorDesc(ge_shape, ge_format, ge_dtype), res.first->second); + const_tensors.emplace(name, std::make_tuple(const_tensor.GetData(), const_tensor.GetSize(), const_tensor)); + return; +} + +void ParseConstTensorList(const nlohmann::json &shape_list, std::map &const_tensors, + std::map> &const_values) { + for (const auto &elem : shape_list) { + if (elem.is_array()) { + for (const auto &shape : elem) { + ParseConstShapeDesc(shape, const_tensors, const_values); + } + } else { + ParseConstShapeDesc(elem, const_tensors, const_values); + } + } +} + +std::string DumpByteBuffer(const ByteBuffer &buf) { + static const char hex_digits[] = "0123456789ABCDEF"; + std::string str = buf.str(); + std::string output; + output.reserve(str.size() * 2); + for (unsigned char c : str) { + output.push_back(hex_digits[c >> 4]); + output.push_back(hex_digits[c & 15]); + } + return output; +} + +bool DumpRunInfo(const OpRunInfo &run_info, char *run_info_json, size_t run_info_len) { + if (run_info_json == nullptr) { + GE_LOGE("run_info buffer is null"); + return false; + } + + nlohmann::json json_obj; + json_obj["block_dim"] = run_info.block_dim; + json_obj["workspaces"] = run_info.workspaces; + json_obj["tiling_data"] = DumpByteBuffer(run_info.tiling_data); + json_obj["clear_atomic"] = run_info.clear_atomic; + + std::string str = json_obj.dump(); + if (str.size() >= run_info_len) { + GE_LOGE("runinfo too large. %zu/%zu", str.size(), run_info_len); + return false; + } + auto rc = memcpy_s(run_info_json, str.size() + 1, str.c_str(), str.size() + 1); + if (rc != EOK) { + return false; + } + return true; +} + +extern "C" int TbeOpTilingPyInterfaceEx2(const char *optype, const char *compile_info, const char *inputs, + const char *outputs, char *run_info_json, size_t run_info_len, + const char *compile_info_hash, uint64_t *elapse) { + if (optype == nullptr || compile_info == nullptr || inputs == nullptr || outputs == nullptr) { + GE_LOGE("optype/compile_info/inputs/outputs is null, %s, %s, %s, %s", optype, compile_info, inputs, outputs); + return 0; + } + + std::chrono::time_point before_tiling, after_tiling; + + std::string compile_info_str = compile_info; + TeOpParas op_params; + std::map> const_values; + try { + nlohmann::json inputs_json = nlohmann::json::parse(inputs); + nlohmann::json outputs_json = nlohmann::json::parse(outputs); + ParseShapeDescList(inputs_json, op_params.inputs); + ParseShapeDescList(outputs_json, op_params.outputs); + ParseConstTensorList(inputs_json, op_params.const_inputs, const_values); + } catch (...) { + GE_LOGE("Failed to parse json_str. %s, %s, %s", compile_info, inputs, outputs); + return 0; + } + + auto &interf = OpTilingRegistryInterf::RegisteredOpInterf(); + auto iter = interf.find(optype); + if (iter == interf.end()) { + iter = interf.find("AutoTiling"); + } + + if (iter == interf.end()) { + GE_LOGE("Optiling func not found. op_type:%s", optype); + return 0; + } + + GELOGI("Optiling func found, op_type:%s, func:[%s:%p]", optype, iter->first.c_str(), + iter->second.target()); + + OpCompileInfo op_compile_info{compile_info}; + if (compile_info_hash) { + op_compile_info.key = compile_info_hash; + } + + OpRunInfo run_info; + if (elapse) { + before_tiling = std::chrono::steady_clock::now(); + } + + bool rc = (iter->second)(op_params, op_compile_info, run_info); + + if (elapse) { + after_tiling = std::chrono::steady_clock::now(); + } + if (!rc) { + GE_LOGE("Optiling failed. op_type:%s", optype); + return 0; + } + + if (elapse) { + *elapse = std::chrono::duration_cast(after_tiling - before_tiling).count(); + *(elapse + 1) = last_op_tiling_perf; + last_op_tiling_perf = -1; + } + + GELOGI("Optiling succeed. op_type:%s", optype); + DumpRunInfo(run_info, run_info_json, run_info_len); + return 1; +} + +extern "C" int TbeOpTilingPyInterfaceEx(const char *optype, const char *compile_info, const char *inputs, + const char *outputs, char *run_info_json, size_t run_info_len, + uint64_t *elapse) { + return TbeOpTilingPyInterfaceEx2(optype, compile_info, inputs, outputs, run_info_json, run_info_len, + nullptr, elapse); +} + +extern "C" int TbeOpTilingPyInterface(const char *optype, const char *compile_info, const char *inputs, + const char *outputs, char *run_info_json, size_t run_info_len) { + return TbeOpTilingPyInterfaceEx(optype, compile_info, inputs, outputs, run_info_json, run_info_len, nullptr); +} + +extern "C" ge::graphStatus OpParaCalculate(const ge::Node &node, OpRunInfo &run_info) { + ge::OpDescPtr op_desc = node.GetOpDesc(); + std::string op_type = op_desc->GetType(); + std::string op_name = op_desc->GetName(); + TeOpParas op_param; + op_param.op_type = op_type; + + GELOGI("Do optiling, op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + + auto inputs = op_desc->GetAllInputsDescPtr(); + auto outputs = op_desc->GetAllOutputsDescPtr(); + + bool bres = false; + bres = FeedTeOpTensorArg(inputs, op_param.inputs); + if (!bres) { + GE_LOGE("Do optiling, op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + return ge::GRAPH_FAILED; + } + bres = FeedTeOpTensorArg(outputs, op_param.outputs); + if (!bres) { + return ge::GRAPH_FAILED; + } + + FeedTeOpConstTensor(node, op_desc, op_param.const_inputs); + + auto &interf = OpTilingRegistryInterf::RegisteredOpInterf(); + auto iter = interf.find(op_type); + if (iter == interf.end()) { + iter = interf.find("AutoTiling"); + } + if (iter == interf.end()) { + GE_LOGE("Optiling func not found. op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + return ge::GRAPH_FAILED; + } + + OpCompileInfo op_compile_info; + bres = GetCompileInfo(op_desc, op_type.c_str(), op_name.c_str(), op_compile_info); + if (!bres) { + GE_LOGE("Failed to get compile_info, op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + return ge::GRAPH_FAILED; + } + + GELOGI("Optiling func found, op_type:%s, op_name:%s, func:[%s:%p]", op_type.c_str(), op_name.c_str(), + iter->first.c_str(), iter->second.target()); + bool rc = (iter->second)(op_param, op_compile_info, run_info); + if (rc) { + GELOGI("Optiling succeed. op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + } else { + GE_LOGE("Optiling failed. op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + } + return rc ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED; +} + +extern "C" ge::graphStatus OpAtomicCalculate(const ge::Node &node, OpRunInfo &run_info) { + ge::OpDescPtr op_desc = node.GetOpDesc(); + std::string op_type = "DynamicAtomicAddrClean"; + std::string op_name = op_desc->GetName(); + std::string origin_op_type = "DynamicAtomicAddrClean"; + TeOpParas op_param; + op_param.op_type = op_type; + + GELOGI("Do Atomic optiling. op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + std::vector atomic_output_indices; + (void)ge::AttrUtils::GetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); + if (atomic_output_indices.empty()) { + GE_LOGE("No ATOMIC_ATTR_OUTPUT_INDEX found, op_type:%s, op_name:%s", origin_op_type.c_str(), op_name.c_str()); + return ge::GRAPH_FAILED; + } + + auto tensor = op_desc->MutableOutputDesc(atomic_output_indices[0]); + if (tensor == nullptr) { + GE_LOGE("Get MutableOutputDesc failed. op_type:%s, op_name:%s", origin_op_type.c_str(), op_name.c_str()); + return ge::GRAPH_FAILED; + } + + int64_t clean_size = 0; + auto res = ge::TensorUtils::GetSize(*tensor, clean_size); + if (res != ge::GRAPH_SUCCESS) { + GE_LOGE("Get size of tensor desc failed. op_type:%s, op_name:%s", origin_op_type.c_str(), op_name.c_str()); + return ge::GRAPH_FAILED; + } + + GELOGI("Atomic clean size: %ld, op_type:%s, op_name:%s", clean_size, origin_op_type.c_str(), op_name.c_str()); + op_param.const_inputs.emplace("workspace_size", + TeConstTensorData(nullptr, static_cast(clean_size), ge::Tensor())); + + auto &interf = OpTilingRegistryInterf::RegisteredOpInterf(); + auto iter = interf.find(op_type); + if (iter == interf.end()) { + GE_LOGE("Atomic optiling func not found. op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + return ge::GRAPH_FAILED; + } + + ge::NodePtr atomic_clean_node = nullptr; + atomic_clean_node = op_desc->TryGetExtAttr("atomic_clean_node_ptr", atomic_clean_node); + if (atomic_clean_node == nullptr) { + GE_LOGE("This node has no atomice node. op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + return ge::GRAPH_FAILED; + } + + ge::OpDescPtr atomic_op_desc = atomic_clean_node->GetOpDesc(); + if (atomic_op_desc == nullptr) { + GE_LOGE("Failed to get op desc from node. op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + return ge::GRAPH_FAILED; + } + + OpCompileInfo op_compile_info; + bool bres = GetCompileInfo(atomic_op_desc, op_type.c_str(), op_name.c_str(), op_compile_info); + if (!bres) { + GE_LOGE("Failed to get compile_info, op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + return ge::GRAPH_FAILED; + } + + bool rc = (iter->second)(op_param, op_compile_info, run_info); + if (rc) { + GELOGI("Atomic optiling succeed. op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + } else { + GE_LOGE("Atomic optiling failed. op_type:%s, op_name:%s", op_type.c_str(), op_name.c_str()); + } + + return rc ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED; +} +} // namespace optiling diff --git a/metadef/register/op_tiling_registry.cpp b/metadef/register/op_tiling_registry.cpp new file mode 100644 index 00000000..4466a135 --- /dev/null +++ b/metadef/register/op_tiling_registry.cpp @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/op_tiling_registry.h" + +#include +#include "framework/common/debug/ge_log.h" + +namespace optiling { + +thread_local int64_t last_op_tiling_perf = -1; + +std::map &OpTilingRegistryInterf::RegisteredOpInterf() { + static std::map interf; + return interf; +} + +OpTilingRegistryInterf::OpTilingRegistryInterf(std::string op_type, OpTilingFunc func) { + auto &interf = RegisteredOpInterf(); + interf.emplace(op_type, func); + GELOGI("Register tiling function: op_type:%s, funcPointer:%p, registered count:%zu", op_type.c_str(), + func.target(), interf.size()); +} + +size_t ByteBufferGetAll(ByteBuffer &buf, char *dest, size_t dest_len) { + size_t nread = 0; + size_t rn = 0; + do { + rn = buf.readsome(dest + nread, dest_len - nread); + nread += rn; + } while (rn > 0 && dest_len > nread); + + return nread; +} +} // namespace optiling diff --git a/metadef/register/ops_kernel_builder_registry.cc b/metadef/register/ops_kernel_builder_registry.cc new file mode 100644 index 00000000..ef55590d --- /dev/null +++ b/metadef/register/ops_kernel_builder_registry.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "register/ops_kernel_builder_registry.h" +#include "graph/debug/ge_log.h" + +namespace ge { +OpsKernelBuilderRegistry::~OpsKernelBuilderRegistry() { + for (auto &it : kernel_builders_) { + GELOGW("%s was not unregistered", it.first.c_str()); + // to avoid coredump when unregister is not called when so was close + // this is called only when app is shutting down, so no release would be leaked + new (std::nothrow) std::shared_ptr(it.second); + } +} +void OpsKernelBuilderRegistry::Register(const string &lib_name, const OpsKernelBuilderPtr &instance) { + auto it = kernel_builders_.emplace(lib_name, instance); + if (it.second) { + GELOGI("Done registering OpsKernelBuilder successfully, kernel lib name = %s", lib_name.c_str()); + } else { + GELOGW("OpsKernelBuilder already registered. kernel lib name = %s", lib_name.c_str()); + } +} + +void OpsKernelBuilderRegistry::UnregisterAll() { + kernel_builders_.clear(); + GELOGI("All builders are unregistered"); +} + +void OpsKernelBuilderRegistry::Unregister(const string &lib_name) { + kernel_builders_.erase(lib_name); + GELOGI("OpsKernelBuilder of %s is unregistered", lib_name.c_str()); +} + +const std::map &OpsKernelBuilderRegistry::GetAll() const { + return kernel_builders_; +} +OpsKernelBuilderRegistry &OpsKernelBuilderRegistry::GetInstance() { + static OpsKernelBuilderRegistry instance; + return instance; +} + +OpsKernelBuilderRegistrar::OpsKernelBuilderRegistrar(const string &kernel_lib_name, + OpsKernelBuilderRegistrar::CreateFn fn) + : kernel_lib_name_(kernel_lib_name) { + GELOGI("To register OpsKernelBuilder, kernel lib name = %s", kernel_lib_name.c_str()); + std::shared_ptr builder; + if (fn != nullptr) { + builder.reset(fn()); + if (builder == nullptr) { + GELOGE(INTERNAL_ERROR, "Failed to create OpsKernelBuilder, kernel lib name = %s", kernel_lib_name.c_str()); + } + } else { + GELOGE(INTERNAL_ERROR, "Creator is nullptr. kernel lib name = %s", kernel_lib_name.c_str()); + } + + // May add empty ptr, so that error can be found afterward + OpsKernelBuilderRegistry::GetInstance().Register(kernel_lib_name, builder); +} + +OpsKernelBuilderRegistrar::~OpsKernelBuilderRegistrar() { + GELOGI("OpsKernelBuilderRegistrar destroyed. KernelLibName = %s", kernel_lib_name_.c_str()); + OpsKernelBuilderRegistry::GetInstance().Unregister(kernel_lib_name_); +} +} // namespace ge diff --git a/metadef/register/proto/om.proto b/metadef/register/proto/om.proto new file mode 100644 index 00000000..e15e5f80 --- /dev/null +++ b/metadef/register/proto/om.proto @@ -0,0 +1,396 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. + CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/metadef/register/proto/task.proto b/metadef/register/proto/task.proto new file mode 100644 index 00000000..d0c09840 --- /dev/null +++ b/metadef/register/proto/task.proto @@ -0,0 +1,165 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/metadef/register/proto/tensorflow/attr_value.proto b/metadef/register/proto/tensorflow/attr_value.proto new file mode 100644 index 00000000..1cc67d62 --- /dev/null +++ b/metadef/register/proto/tensorflow/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/metadef/register/proto/tensorflow/function.proto b/metadef/register/proto/tensorflow/function.proto new file mode 100644 index 00000000..075897c6 --- /dev/null +++ b/metadef/register/proto/tensorflow/function.proto @@ -0,0 +1,100 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. + reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. +} diff --git a/metadef/register/proto/tensorflow/graph.proto b/metadef/register/proto/tensorflow/graph.proto new file mode 100644 index 00000000..d639a7d6 --- /dev/null +++ b/metadef/register/proto/tensorflow/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "node_def.proto"; +import "function.proto"; +import "versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/metadef/register/proto/tensorflow/graph_library.proto b/metadef/register/proto/tensorflow/graph_library.proto new file mode 100644 index 00000000..e393d38d --- /dev/null +++ b/metadef/register/proto/tensorflow/graph_library.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package domi.tensorflow; + +import "graph.proto"; + +message GeGraphDef { + string name = 1; + GraphDef graph = 2; +} + +message GraphDefLibrary { + repeated GeGraphDef graph_def = 1; +}; \ No newline at end of file diff --git a/metadef/register/proto/tensorflow/node_def.proto b/metadef/register/proto/tensorflow/node_def.proto new file mode 100644 index 00000000..b9bc97ee --- /dev/null +++ b/metadef/register/proto/tensorflow/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // Add some examples here showing best practices. + map attr = 5; +}; diff --git a/metadef/register/proto/tensorflow/op_def.proto b/metadef/register/proto/tensorflow/op_def.proto new file mode 100644 index 00000000..3485d045 --- /dev/null +++ b/metadef/register/proto/tensorflow/op_def.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/metadef/register/proto/tensorflow/resource_handle.proto b/metadef/register/proto/tensorflow/resource_handle.proto new file mode 100644 index 00000000..a3452351 --- /dev/null +++ b/metadef/register/proto/tensorflow/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/metadef/register/proto/tensorflow/tensor.proto b/metadef/register/proto/tensorflow/tensor.proto new file mode 100644 index 00000000..d0a4d024 --- /dev/null +++ b/metadef/register/proto/tensorflow/tensor.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/metadef/register/proto/tensorflow/tensor_shape.proto b/metadef/register/proto/tensorflow/tensor_shape.proto new file mode 100644 index 00000000..4225a2e3 --- /dev/null +++ b/metadef/register/proto/tensorflow/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package domi.tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/metadef/register/proto/tensorflow/types.proto b/metadef/register/proto/tensorflow/types.proto new file mode 100644 index 00000000..ba7a72b3 --- /dev/null +++ b/metadef/register/proto/tensorflow/types.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/metadef/register/proto/tensorflow/versions.proto b/metadef/register/proto/tensorflow/versions.proto new file mode 100644 index 00000000..48061218 --- /dev/null +++ b/metadef/register/proto/tensorflow/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +}; diff --git a/metadef/register/register.cpp b/metadef/register/register.cpp new file mode 100644 index 00000000..08a82366 --- /dev/null +++ b/metadef/register/register.cpp @@ -0,0 +1,1083 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "external/register/register.h" +#include +#include "debug/ge_util.h" +#include "debug/ge_op_types.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_log.h" +#include "graph/debug/ge_util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/type_utils.h" +#include "proto/tensorflow/attr_value.pb.h" +#include "proto/tensorflow/node_def.pb.h" +#include "register/auto_mapping_util.h" +#include "register/op_registry.h" +#include "graph/graph.h" + +using namespace domi::tensorflow; +namespace domi { +/*lint -e1073*/ +namespace { +const std::string kDefaultFormat = "ND"; +const std::string kSrcFormat = "src_format"; +const std::string kDstFormat = "dst_format"; +const std::string kDataFormat = "data_format"; +const std::string kTfInputDesc = "input_tensor_desc"; +const std::string kTfOutputDesc = "output_tensor_desc"; +const std::string kFuncNameKey = "name"; + +struct DynamicInfo { + DynamicType type; + uint32_t inset_index; + uint32_t tensor_num; + DynamicInfo() : type(kInvalid), inset_index(0), tensor_num(0) {} + DynamicInfo(DynamicType type, uint32_t index, uint32_t num) : type(type), inset_index(index), tensor_num(num) {} +}; + +std::set GetSubgraphAttrNames(const ge::Operator &op) { + if (op.GetSubgraphNamesCount() == 0) { + return std::set(); + } + auto subgraph_names = op.GetSubgraphNames(); + return std::set(subgraph_names.begin(), subgraph_names.end()); +} + +/// there are two forms to represent functions in TF: +/// case 1(subgraph of a `if` node) normal subgraph: +/// attr { +/// key: "else_branch" +/// value { +/// func { +/// name: "cond_false_9" +/// } +/// } +/// } +/// +/// case 2(subgraph of a `case` node) dynamic subgraph: +/// attr { +/// key: "branches" +/// value { +/// list { +/// func { +/// name: "two_J6Sc96RZs5g" +/// } +/// func { +/// name: "three_3pYv7KFNs2M" +/// } +/// func { +/// name: "four_MdtG6T4LHxA" +/// } +/// } +/// } +/// } +/// \param func_attr +/// \param op_desc +/// \return +Status AutoMappingFunction(const std::pair &func_attr, + std::shared_ptr &op_desc) { + switch (func_attr.second.value_case()) { + case domi::tensorflow::AttrValue::kFunc: + { + const auto &func_signature = func_attr.second.func().name(); + auto ret = ge::OpDescUtils::SetSubgraphInstanceName(func_attr.first, func_signature, op_desc); + if (ret != ge::GRAPH_SUCCESS) { + GE_LOGE("Failed to set subgraph instance %s for node %s type %s, instance name %s", + func_attr.first.c_str(), op_desc->GetName().c_str(), + op_desc->GetType().c_str(), func_signature.c_str()); + return FAILED; + } + break; + } + case domi::tensorflow::AttrValue::kList: + { + uint32_t i = 0; + for (auto &dyn_func_attr : func_attr.second.list().func()) { + const auto &func_signature = dyn_func_attr.name(); + auto subgraph_name = func_attr.first + std::to_string(i++); + auto ret = op_desc->AddSubgraphName(subgraph_name); + if (ret != ge::GRAPH_SUCCESS) { + GE_LOGE("Failed to add subgraph name %s to node %s type %s", + subgraph_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str()); + return FAILED; + } + ret = ge::OpDescUtils::SetSubgraphInstanceName(subgraph_name, func_signature, op_desc); + if (ret != ge::GRAPH_SUCCESS) { + GE_LOGE("Failed to set dynamic subgraph instance %s for node %s type %s, instance name %s", + func_attr.first.c_str(), op_desc->GetName().c_str(), + op_desc->GetType().c_str(), func_signature.c_str()); + return FAILED; + } + } + break; + } + default: + GE_LOGE("Unexpected attr value type %d for func", static_cast(func_attr.second.value_case())); + return FAILED; + } + return SUCCESS; +} + +Status CheckDynamicInfo(const vector &dynamic_name_attr_value) { + for (const auto &dynamic_info : dynamic_name_attr_value) { + if (dynamic_info.port_name_len == 0 || dynamic_info.port_name_len > kMaxNameLength || + dynamic_info.attr_name_len == 0 || dynamic_info.attr_name_len > kMaxNameLength) { + GELOGE(PARAM_INVALID, "Invalid Param, port_name_len[%ld], attr_name_len[%ld].", + dynamic_info.port_name_len, dynamic_info.attr_name_len); + return PARAM_INVALID; + } + + int64_t port_name_len = strlen(dynamic_info.port_name); + if (dynamic_info.port_name == nullptr || port_name_len != dynamic_info.port_name_len) { + GELOGE(PARAM_INVALID, "Invalid Param, port_name[%s], port_name_len[%ld]", + dynamic_info.port_name, dynamic_info.port_name_len); + return PARAM_INVALID; + } + + int64_t attr_name_len = strlen(dynamic_info.attr_name); + if (dynamic_info.attr_name == nullptr || attr_name_len != dynamic_info.attr_name_len) { + GELOGE(PARAM_INVALID, "Invalid Param, attr_name[%s], attr_name_len[%ld]", + dynamic_info.attr_name, dynamic_info.attr_name_len); + return PARAM_INVALID; + } + } + + return SUCCESS; +} + +Status GetDynamicTensorNum(const std::shared_ptr &op_desc, const string &attr_name, uint32_t &tensor_num) { + GE_CHECK_NOTNULL(op_desc); + + ge::GeAttrValue attr_value; + ge::graphStatus ret = op_desc->GetAttr(attr_name, attr_value); + if (ret != SUCCESS) { + GELOGE(FAILED, "Op[%s] get attr name[%s] value failed.", op_desc->GetName().c_str(), attr_name.c_str()); + return FAILED; + } + + ge::GeAttrValue::ValueType value_type = attr_value.GetValueType(); + switch (value_type) { + case ge::GeAttrValue::VT_LIST_DATA_TYPE: { + vector vec_d; + (void)ge::AttrUtils::GetListDataType(op_desc, attr_name, vec_d); + tensor_num = static_cast(vec_d.size()); + break; + } + case ge::GeAttrValue::VT_INT: { + (void)ge::AttrUtils::GetInt(op_desc, attr_name, tensor_num); + break; + } + default: + GELOGI("Default other value type: %d", static_cast(value_type)); + break; + } + + return SUCCESS; +} + +Status UpdateDynamicInputOutPutIndex(const std::shared_ptr &op_desc, + const vector &dynamic_name_attrs, map &port_dynamic_info) { + GE_CHECK_NOTNULL(op_desc); + for (const auto &dynamic_name_attr : dynamic_name_attrs) { + const std::string attr_name = dynamic_name_attr.attr_name; + uint32_t dynamic_tensor_num = 0; + if (op_desc->HasAttr(attr_name)) { + if (GetDynamicTensorNum(op_desc, attr_name, dynamic_tensor_num) != SUCCESS) { + GELOGE(FAILED, "Get dynamic tensor num failed."); + return FAILED; + } + } else { + GELOGW("In op %s dynamic attr [%s] is not exist.", op_desc->GetName().c_str(), attr_name.c_str()); + continue; + } + GELOGI("In Op %s dynamic attr [%s] is exist, tensor num: %u.", op_desc->GetName().c_str(), attr_name.c_str(), + dynamic_tensor_num); + port_dynamic_info[dynamic_name_attr.port_name] = DynamicInfo(dynamic_name_attr.type, 0, dynamic_tensor_num); + } + + const vector register_input_names = op_desc->GetRegisterInputName(); + uint32_t input_index = 0; + uint32_t input_increment = 0; + for (const auto &input_name : register_input_names) { + if (port_dynamic_info.find(input_name) != port_dynamic_info.end()) { + port_dynamic_info[input_name].inset_index = input_index + input_increment; + uint32_t tensor_num = port_dynamic_info[input_name].tensor_num; + input_increment += tensor_num > 0 ? tensor_num - 1 : 0; + GELOGI("Dynamic input name[%s] insert index: %u, tensor num: %u, op proto index: %u", input_name.c_str(), + port_dynamic_info[input_name].inset_index, tensor_num, input_index); + } + input_index++; + } + const vector register_output_names = op_desc->GetRegisterOutputName(); + uint32_t output_index = 0; + uint32_t out_increment = 0; + for (const auto &output_name : register_output_names) { + if (port_dynamic_info.find(output_name) != port_dynamic_info.end()) { + port_dynamic_info[output_name].inset_index = output_index + out_increment; + uint32_t tensor_num = port_dynamic_info[output_name].tensor_num; + out_increment += tensor_num > 0 ? tensor_num - 1 : 0; + GELOGI("Dynamic output name[%s] insert index: %u, tensor num: %u, op proto index: %u", output_name.c_str(), + port_dynamic_info[output_name].inset_index, tensor_num, output_index); + } + output_index++; + } + return SUCCESS; +} + +Status SetOpdescInputOutputFormat(std::shared_ptr &op_desc) { + GE_CHECK_NOTNULL(op_desc); + + auto inputDescsPtr = op_desc->GetAllInputsDescPtr(); + auto outputDescsPtr = op_desc->GetAllOutputsDescPtr(); + + string src_data_format = kDefaultFormat; + string dst_data_format = kDefaultFormat; + if (op_desc->HasAttr(kSrcFormat)) { + (void)ge::AttrUtils::GetStr(op_desc, kSrcFormat, src_data_format); + } + if (op_desc->HasAttr(kDstFormat)) { + (void)ge::AttrUtils::GetStr(op_desc, kDstFormat, dst_data_format); + } + if (op_desc->HasAttr(kDataFormat)) { + (void)ge::AttrUtils::GetStr(op_desc, kDataFormat, src_data_format); + dst_data_format = src_data_format; + } + ge::Format format = ge::TypeUtils::DataFormatToFormat(src_data_format); + for (auto inputDescPtr : inputDescsPtr) { + inputDescPtr->SetOriginFormat(format); + inputDescPtr->SetFormat(format); + } + format = ge::TypeUtils::DataFormatToFormat(dst_data_format); + for (auto outputDescPtr : outputDescsPtr) { + outputDescPtr->SetOriginFormat(format); + outputDescPtr->SetFormat(format); + } + return SUCCESS; +} +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingFnDynamic( + const google::protobuf::Message *op_src, ge::Operator &op, + std::map> dynamic_name_attr_value, int in_pos, int out_pos) { + // 1. automapping for parser + std::shared_ptr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); + GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(op_src); + Status ret = AutoMappingFn(op_src, op); + if (ret != SUCCESS) { + GE_LOGE("Op: %s call auto mapping function failed.", op_desc->GetName().c_str()); + return FAILED; + } + + GELOGI("op[%s] call auto mapping function success.", op_desc->GetName().c_str()); + + if (dynamic_name_attr_value.size() > 2) { // attr value size should be less than 2 + GE_LOGE("attr set size [%zu] should be less than 2.", dynamic_name_attr_value.size()); + return FAILED; + } + + // add dynamci input and output + const NodeDef *node = reinterpret_cast(op_src); + for (auto it : dynamic_name_attr_value) { + std::string flag = it.first; + std::pair name_value = it.second; + std::string dynamic_name = name_value.first; + std::string attr_name = name_value.second; + + tensorflow::AttrValue attr_num; + int32_t dynamic_tensor_num = 0; + if (!(ge::AutoMappingUtil::FindAttrValue(node, attr_name, attr_num))) { + GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", node->name().c_str(), attr_name.c_str()); + } + + if (attr_num.has_list()) { + dynamic_tensor_num = attr_num.list().type_size(); + } else { + dynamic_tensor_num = static_cast(attr_num.i()); + } + + if (dynamic_tensor_num <= 0) { + GELOGW("In NodeDef %s dynamic num %d is less than 0.", node->name().c_str(), dynamic_tensor_num); + continue; + } + + GELOGI("In NodeDef %s dynamic attr [%s] is exist: %d.", node->name().c_str(), attr_name.c_str(), + dynamic_tensor_num); + + if (flag == "in") { + bool is_pushback = (in_pos == -1); + (void)op_desc->AddDynamicInputDesc(dynamic_name, static_cast(dynamic_tensor_num), is_pushback); + ge::AttrUtils::SetInt(op_desc, DYNAMIC_INPUT_TD_NUM(dynamic_name), dynamic_tensor_num); + GELOGI("In NodeDef %s add dynamic input[%d]", node->name().c_str(), dynamic_tensor_num); + } else if (flag == "out") { + bool is_pushback = (out_pos == -1); + (void)op_desc->AddDynamicOutputDesc(dynamic_name, static_cast(dynamic_tensor_num), is_pushback); + ge::AttrUtils::SetInt(op_desc, DYNAMIC_OUTPUT_TD_NUM(dynamic_name), dynamic_tensor_num); + GELOGI("In NodeDef %s add dynamic output[%d]", node->name().c_str(), dynamic_tensor_num); + } + } + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, + ge::Operator &op, const vector &dynamic_name_attr_value) { + // 1. auto mapping for parser + std::shared_ptr op_desc_dst = ge::OpDescUtils::GetOpDescFromOperator(op); + GE_CHECK_NOTNULL(op_desc_dst); + + Status ret = AutoMappingByOpFn(op_src, op); + if (ret != SUCCESS) { + GELOGE(ret, "Op[%s] call auto mapping function failed.", op_desc_dst->GetName().c_str()); + return FAILED; + } + + GELOGI("Op[%s] call auto mapping function success.", op_desc_dst->GetName().c_str()); + // 2. check dynamic input output info; + if (CheckDynamicInfo(dynamic_name_attr_value) != SUCCESS) { + GELOGE(FAILED, "Check dynamic info param failed."); + return FAILED; + } + // 3. update dynamic input output index by tensor num; + map port_dynamic_info; + if (UpdateDynamicInputOutPutIndex(op_desc_dst, dynamic_name_attr_value, port_dynamic_info) != SUCCESS) { + GELOGE(FAILED, "Update dynamic input output index failed."); + return FAILED; + } + // 4. sort map by port name insert index. + vector> port_dynamic_info_vec(port_dynamic_info.begin(), port_dynamic_info.end()); + std::sort(port_dynamic_info_vec.begin(), port_dynamic_info_vec.end(), + [](const pair &p1, const pair &p2) + { return p1.second.inset_index < p2.second.inset_index; }); + // 5. add dynamic input and output + for (const auto &dynamic_info : port_dynamic_info_vec) { + string port_name = dynamic_info.first; + DynamicType dynamic_type = dynamic_info.second.type; + uint32_t insert_index = dynamic_info.second.inset_index; + uint32_t tensor_num = dynamic_info.second.tensor_num; + if (tensor_num == 0) { + GELOGW("In op[%s] tensor num of port[%s] is equal 0.", op_desc_dst->GetName().c_str(), port_name.c_str()); + continue; + } + if (dynamic_type == kInput) { + (void)op_desc_dst->AddInputDescMiddle(port_name, tensor_num, insert_index); + GELOGI("Op[%s] add dynamic input[%u]", op_desc_dst->GetName().c_str(), tensor_num); + } else if (dynamic_type == kOutput) { + (void)op_desc_dst->AddOutputDescMiddle(port_name, tensor_num, insert_index); + GELOGI("Op[%s] add dynamic output[%u]", op_desc_dst->GetName().c_str(), tensor_num); + } + } + + return SUCCESS; +} + +// Convert tensorflow property to ge property +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingFn(const Message *op_src, ge::Operator &op) { + std::shared_ptr op_dst = ge::OpDescUtils::GetOpDescFromOperator(op); + // Analysis of tensorflow operator parameters based on key value + GE_CHECK_NOTNULL(op_src); + GE_CHECK_NOTNULL(op_dst); + + auto subgraph_attr_names = GetSubgraphAttrNames(op); + const NodeDef *node_src = reinterpret_cast(op_src); + op_dst->SetName(node_src->name()); + for (const auto &attr_pair : node_src->attr()) { + if (attr_pair.first == kTfInputDesc || attr_pair.first == kTfOutputDesc) { + continue; + } + if (subgraph_attr_names.count(attr_pair.first) > 0) { + auto ret = AutoMappingFunction(attr_pair, op_dst); + if (ret != SUCCESS) { + return ret; + } + } else { + ge::AutoMappingUtil::ConvertValue(attr_pair.first, attr_pair.second, op_dst); + } + } + + Status ret = SetOpdescInputOutputFormat(op_dst); + if (ret != SUCCESS) { + GELOGE(FAILED, "Set op[%s] desc input output format failed.", op_dst->GetName().c_str()); + return FAILED; + } + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingByOpFn(const ge::Operator &op_src, + ge::Operator &op) { + std::shared_ptr op_desc_src = ge::OpDescUtils::GetOpDescFromOperator(op_src); + std::shared_ptr op_desc_dst = ge::OpDescUtils::GetOpDescFromOperator(op); + GE_CHECK_NOTNULL(op_desc_src); + GE_CHECK_NOTNULL(op_desc_dst); + + op_desc_dst->SetName(op_desc_src->GetName()); + const auto subgraph_name_indexs = op_desc_src->GetSubgraphNameIndexes(); + for (const auto &subgraph_name_index : subgraph_name_indexs) { + auto ret = op_desc_dst->AddSubgraphName(subgraph_name_index.first); + if (ret != ge::GRAPH_SUCCESS) { + GELOGW("Subgraph with name %s for node %s type %s has already added.", + subgraph_name_index.first.c_str(), op_desc_dst->GetName().c_str(), op_desc_dst->GetType().c_str()); + } + } + + const auto subgraph_instance_names = op_desc_src->GetSubgraphInstanceNames(); + uint32_t index = 0; + for (const auto &subgraph_instance_name : subgraph_instance_names) { + auto ret = op_desc_dst->SetSubgraphInstanceName(index, subgraph_instance_name); + if (ret != ge::GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to add subgraph instance name: %s, index: %u, for node %s type %s.", + subgraph_instance_name.c_str(), index, op_desc_dst->GetType().c_str(), op_desc_dst->GetName().c_str()); + return FAILED; + } + index++; + } + + map attr_values = op_desc_src->GetAllAttrs(); + for (auto &attr_value : attr_values) { + ge::AutoMappingUtil::CopyAttrValue(attr_value.first, attr_value.second, op_desc_src, op_desc_dst); + } + + Status ret = SetOpdescInputOutputFormat(op_desc_dst); + if (ret != SUCCESS) { + GELOGE(FAILED, "Set op desc Input output failed."); + return FAILED; + } + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY +Status AutoMappingSubgraphIndex(const ge::Graph &graph, + const std::function &input, + const std::function &output) { + GE_CHECK_NOTNULL(input); + GE_CHECK_NOTNULL(output); + return AutoMappingSubgraphIndex(graph, + [&](int i, int &o) -> Status { + o = input(i); + return SUCCESS; + }, + [&](int i, int &o) -> Status { + o = output(i); + return SUCCESS; + }); +} + +namespace { + const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; + std::vector> FindNodesByType(const ge::ComputeGraphPtr &graph, const std::string &type) { + std::vector> nodes; + for (const auto &node : graph->GetDirectNode()) { + GELOGI("Find node %s, node type is %s.", type.c_str(), node->GetOpDesc()->GetType().c_str()); + if (node->GetOpDesc()->GetType() == type) { + nodes.push_back(node); + continue; + } + if (node->GetOpDesc()->GetType() == "FrameworkOp") { + std::string original_type; + if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type)) { + // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. + continue; + } + if (original_type == type) { + nodes.push_back(node); + } + } + } + return nodes; + } +} + +Status AutoMappingSubgraphOutput(const ge::ComputeGraphPtr &graph, + const std::function &output) { + GE_CHECK_NOTNULL(graph); + GE_CHECK_NOTNULL(output); + const auto &output_node = graph->FindFirstNodeMatchType(ge::NETOUTPUT); + if (output_node == nullptr) { // Graph from parser no NetOutput. + return SUCCESS; + } + + const auto &op_desc = output_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) { + int parent_index = -1; + auto ret = output(index, parent_index); + if (ret != SUCCESS) { + GELOGE(FAILED, "Failed to get parent index for net output index %ld, error code %u", index, ret); + return FAILED; + } + + GELOGI("Generate subgraph output map for subgraph %s, index %ld, parent node index %d", + graph->GetName().c_str(), index, parent_index); + if (parent_index == -1) { + continue; + } + + ge::GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); + GE_CHECK_NOTNULL(tensor); + if (!ge::AttrUtils::SetInt(tensor, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Failed to add parent node index for graph %s", graph->GetName().c_str()); + return FAILED; + } + } + + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY +Status AutoMappingSubgraphIndex(const ge::Graph &graph, + const std::function &input, + const std::function &output) { + GE_CHECK_NOTNULL(input); + GE_CHECK_NOTNULL(output); + + auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + auto nodes = FindNodesByType(compute_graph, "Data"); + for (size_t i = 0; i < nodes.size(); ++i) { + int parent_index = -1; + int index = -1; + if (!ge::AttrUtils::GetInt(nodes[i]->GetOpDesc(), "index", index)) { + GELOGE(FAILED, "Failed to get index from data[%d], failed to get the attr", i); + return FAILED; + } + GELOGI("Get index %d from data[%d]", index, i); + auto ret = input(index, parent_index); + if (ret != SUCCESS) { + GELOGE(FAILED, "Failed to get parent index from data index %d, error code %u", i, ret); + return FAILED; + } + if (!ge::AttrUtils::SetInt(nodes[i]->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Failed to add parent node index for node %s", nodes[i]->GetName().c_str()); + return FAILED; + } + GELOGI("Generate subgraph input map for subgraph %s, data index %zu, parent node index %d", + graph.GetName().c_str(), i, parent_index); + + } + + nodes = FindNodesByType(compute_graph, "_Retval"); + for (auto &retval : nodes) { + int64_t index = -1; + if (!ge::AttrUtils::GetInt(retval->GetOpDesc(), "retval_index", index)) { + GELOGE(FAILED, "Failed to get parent index from retval index %ld, failed to get the attr", index); + return FAILED; + } + int parent_index = -1; + auto ret = output(index, parent_index); + if (ret != SUCCESS) { + GELOGE(FAILED, "Failed to get parent index from retval index %ld, error code %u", index, ret); + return FAILED; + } + if (!ge::AttrUtils::SetInt(retval->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Failed to add parent node index for node %s", retval->GetName().c_str()); + return FAILED; + } + GELOGI("Generate subgraph output map for subgraph %s, retval index %ld, parent node index %d", + graph.GetName().c_str(), index, parent_index); + } + + return nodes.empty() ? AutoMappingSubgraphOutput(compute_graph, output) : SUCCESS; +} + +OpReceiver::OpReceiver(OpRegistrationData ®_data) { OpRegistry::Instance()->registrationDatas.push_back(reg_data); } + +class OpRegistrationDataImpl { + public: + OpRegistrationDataImpl() = default; + ~OpRegistrationDataImpl() = default; + explicit OpRegistrationDataImpl(const std::string &om_optype); + + domi::FrameworkType fmk_type_; + std::set ori_optype_set_; // OP type in the original model, there may be multiple + std::string om_optype_; // OP type in OM model + domi::ImplyType imply_type_; // execution type + ParseParamFunc parseParamFn_; // parseParam function + ParseParamByOpFunc parse_param_by_op_fn_; // parse param by op function + FusionParseParamFunc fusionParseParamFn_; // fusion parseParam function + FusionParseParamByOpFunc fusion_parse_param_by_op_fn_; // fusion parseParam by op function + ParseSubgraphFunc parse_subgraph_post_fn_; // a function called after the subgraph was generated + ParseSubgraphFuncV2 parse_subgraph_post_fn_v2_; // a function called after the subgraph was generated + std::vector remove_input_configure_vec_; + ParseOpToGraphFunc parse_op_to_graph_fn_; +}; + +OpRegistrationDataImpl::OpRegistrationDataImpl(const std::string &om_optype) + : fmk_type_(FRAMEWORK_RESERVED), + om_optype_(om_optype), + imply_type_(domi::ImplyType::BUILDIN), + parseParamFn_(nullptr), + parse_param_by_op_fn_(nullptr), + fusionParseParamFn_(nullptr), + fusion_parse_param_by_op_fn_(nullptr), + parse_subgraph_post_fn_(nullptr), + parse_subgraph_post_fn_v2_(nullptr), + parse_op_to_graph_fn_(nullptr) {} + +OpRegistrationData::~OpRegistrationData() = default; + +OpRegistrationData::OpRegistrationData(const std::string &om_optype) { + impl_ = ComGraphMakeShared(om_optype); + if (impl_ == nullptr) { + GELOGW("OpRegistrationDataImpl make shared failed!"); + } +} + +OpRegistrationData::OpRegistrationData(const char *om_op_type) { + std::string op_type; + if (om_op_type != nullptr) { + op_type = om_op_type; + } + impl_ = ComGraphMakeShared(op_type); + if (impl_ == nullptr) { + GELOGW("OpRegistrationDataImpl make shared failed!"); + } +} + +std::string OpRegistrationData::GetOmOptype() const { + if (impl_ != nullptr) { + return impl_->om_optype_; + } + return ""; +} + +Status OpRegistrationData::GetOmOptype(ge::AscendString &om_op_type) const { + if (impl_ != nullptr) { + om_op_type = ge::AscendString(impl_->om_optype_.c_str()); + } + return SUCCESS; +} + +OpRegistrationData &OpRegistrationData::FrameworkType(const domi::FrameworkType &fmk_type) { + if (impl_ != nullptr) { + impl_->fmk_type_ = fmk_type; + } + return *this; +} + +domi::FrameworkType OpRegistrationData::GetFrameworkType() const { + if (impl_ != nullptr) { + return impl_->fmk_type_; + } + return FRAMEWORK_RESERVED; +} + +OpRegistrationData &OpRegistrationData::OriginOpType(const std::initializer_list &ori_optype_list) { + if (impl_ != nullptr) { + for (auto ori_optype : ori_optype_list) { + (void)impl_->ori_optype_set_.insert(ori_optype); + } + } + return *this; +} + +OpRegistrationData &OpRegistrationData::OriginOpType(const std::vector &ori_op_type_list) { + if (impl_ != nullptr) { + for (auto &ori_op_type : ori_op_type_list) { + std::string tmp_ori_op_type; + if (ori_op_type.GetString() != nullptr) { + tmp_ori_op_type = ori_op_type.GetString(); + } + (void)impl_->ori_optype_set_.insert(tmp_ori_op_type); + } + } + return *this; +} + +OpRegistrationData &OpRegistrationData::OriginOpType(const std::string &ori_optype) { + if (impl_ != nullptr) { + (void)impl_->ori_optype_set_.insert(ori_optype); + } + return *this; +} + +OpRegistrationData &OpRegistrationData::OriginOpType(const char *ori_op_type) { + if (impl_ != nullptr) { + std::string tmp_ori_op_type; + if (ori_op_type != nullptr) { + tmp_ori_op_type = ori_op_type; + } + (void)impl_->ori_optype_set_.insert(tmp_ori_op_type); + } + return *this; +} + +std::set OpRegistrationData::GetOriginOpTypeSet() const { + std::set ori_optype_set; + if (impl_ != nullptr) { + return impl_->ori_optype_set_; + } + return ori_optype_set; +} + +Status OpRegistrationData::GetOriginOpTypeSet(std::set &ori_op_type) const { + std::set ori_op_type_set; + if (impl_ != nullptr) { + ori_op_type_set = impl_->ori_optype_set_; + } + for (auto &op_type : ori_op_type_set) { + ori_op_type.insert(ge::AscendString(op_type.c_str())); + } + return SUCCESS; +} + +OpRegistrationData &OpRegistrationData::ParseParamsFn(const ParseParamFunc &parseParamFn) { + if (impl_ != nullptr) { + impl_->parseParamFn_ = parseParamFn; + } + return *this; +} + +ParseParamFunc OpRegistrationData::GetParseParamFn() const { + if (impl_ != nullptr) { + return impl_->parseParamFn_; + } + return nullptr; +} + +OpRegistrationData &OpRegistrationData::ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn) { + if (impl_ != nullptr) { + impl_->parse_param_by_op_fn_ = parse_param_by_op_fn; + } + return *this; +} + +ParseParamByOpFunc OpRegistrationData::GetParseParamByOperatorFn() const { + if (impl_ != nullptr) { + return impl_->parse_param_by_op_fn_; + } + return nullptr; +} + +OpRegistrationData &OpRegistrationData::FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn) { + if (impl_ != nullptr) { + impl_->fusionParseParamFn_ = fusionParseParamFn; + } + return *this; +} + +FusionParseParamFunc OpRegistrationData::GetFusionParseParamFn() const { + if (impl_ != nullptr) { + return impl_->fusionParseParamFn_; + } + return nullptr; +} + +OpRegistrationData &OpRegistrationData::FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn) { + if (impl_ != nullptr) { + impl_->fusion_parse_param_by_op_fn_ = fusion_parse_param_fn; + } + return *this; +} + +FusionParseParamByOpFunc OpRegistrationData::GetFusionParseParamByOpFn() const { + if (impl_ != nullptr) { + return impl_->fusion_parse_param_by_op_fn_; + } + return nullptr; +} + +OpRegistrationData &OpRegistrationData::ImplyType(const domi::ImplyType &imply_type) { + if (impl_ != nullptr) { + impl_->imply_type_ = imply_type; + } + return *this; +} + +domi::ImplyType OpRegistrationData::GetImplyType() const { + domi::ImplyType imply_type = domi::ImplyType::BUILDIN; + if (impl_ != nullptr) { + return impl_->imply_type_; + } + return imply_type; +} + +OpRegistrationData &OpRegistrationData::DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue) { + if (impl_ != nullptr) { + struct RemoveInputConfigure registerStu; + registerStu.inputIdx = inputIdx; + registerStu.attrName = attrName; + registerStu.moveType = OMG_REMOVE_TYPE_WITH_COND; + registerStu.attrValue = attrValue; + impl_->remove_input_configure_vec_.push_back(registerStu); + } + return *this; +} + +OpRegistrationData &OpRegistrationData::DelInputWithCond(int input_idx, const char *attr_name, bool attr_value) { + std::string tmp_attr_name; + if (attr_name != nullptr) { + tmp_attr_name = attr_name; + } + if (impl_ != nullptr) { + struct RemoveInputConfigure registerStu; + registerStu.inputIdx = input_idx; + registerStu.attrName = tmp_attr_name; + registerStu.moveType = OMG_REMOVE_TYPE_WITH_COND; + registerStu.attrValue = attr_value; + impl_->remove_input_configure_vec_.push_back(registerStu); + } + return *this; +} + +OpRegistrationData &OpRegistrationData::InputReorderVector(const vector &input_order) { + if (impl_ != nullptr) { + struct RemoveInputConfigure register_input; + register_input.inputIdx = 0; + register_input.input_order = input_order; + register_input.moveType = OMG_INPUT_REORDER; + impl_->remove_input_configure_vec_.push_back(register_input); + } + return *this; +} + +OpRegistrationData &OpRegistrationData::DelInputWithOriginalType(int input_idx, const std::string &ori_type) { + if (impl_ != nullptr) { + struct RemoveInputConfigure register_input; + register_input.inputIdx = input_idx; + register_input.originalType = ori_type; + register_input.moveType = OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE; + impl_->remove_input_configure_vec_.push_back(register_input); + } + return *this; +} + +OpRegistrationData &OpRegistrationData::DelInputWithOriginalType(int input_idx, const char *ori_type) { + std::string tmp_ori_type; + if (ori_type != nullptr) { + tmp_ori_type = ori_type; + } + if (impl_ != nullptr) { + struct RemoveInputConfigure register_input; + register_input.inputIdx = input_idx; + register_input.originalType = tmp_ori_type; + register_input.moveType = OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE; + impl_->remove_input_configure_vec_.push_back(register_input); + } + return *this; +} + +OpRegistrationData &OpRegistrationData::ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn) { + if (impl_ != nullptr) { + impl_->parse_subgraph_post_fn_ = subgraph_post_fn; + } + return *this; +} + +ParseSubgraphFunc OpRegistrationData::GetParseSubgraphPostFn() const { + if (impl_ == nullptr) { + return nullptr; + } + return impl_->parse_subgraph_post_fn_; +} + +OpRegistrationData &OpRegistrationData::ParseOpToGraphFn(const ParseOpToGraphFunc &parse_op_to_graph_fn) { + if (impl_ != nullptr) { + impl_->parse_op_to_graph_fn_ = parse_op_to_graph_fn; + } + return *this; +} + +OpRegistrationData &OpRegistrationData::ParseSubgraphPostFn(const ParseSubgraphFuncV2 &subgraph_post_fn) { + if (impl_ != nullptr) { + impl_->parse_subgraph_post_fn_v2_ = subgraph_post_fn; + } + return *this; +} + +ParseOpToGraphFunc OpRegistrationData::GetParseOpToGraphFn() const { + if (impl_ == nullptr) { + return nullptr; + } + return impl_->parse_op_to_graph_fn_; +} + +Status OpRegistrationData::GetParseSubgraphPostFn(ParseSubgraphFuncV2 &func) const { + if (impl_ == nullptr) { + return FAILED; + } + func = impl_->parse_subgraph_post_fn_v2_; + return SUCCESS; +} + +OpRegistry *OpRegistry::Instance() { + static OpRegistry instance; + return &instance; +} + +namespace { +std::string GetParserKey(const std::string &om_type, const std::string &ori_type) { + return om_type + "_" + ori_type; +} +} // namespace + +bool OpRegistry::Register(const OpRegistrationData ®_data) { + if (reg_data.impl_ == nullptr) { + return false; + } + for (auto ori_type : reg_data.impl_->ori_optype_set_) { + std::string om_ori_type = GetParserKey(reg_data.impl_->om_optype_, ori_type); + if (op_parse_params_fn_map_.find(om_ori_type) != op_parse_params_fn_map_.end()) { + GELOGW("The plugin of op type:%s original type:%s is already registered and will be skipped.", + reg_data.impl_->om_optype_.c_str(), ori_type.c_str()); + continue; + } + + GELOGD("The plugin of type:%s will be registered.", om_ori_type.c_str()); + op_parse_params_fn_map_[om_ori_type] = reg_data.impl_->parseParamFn_; + fusion_op_parse_params_fn_map_[om_ori_type] = reg_data.impl_->fusionParseParamFn_; + fusion_parse_params_by_op_fn_map_[om_ori_type] = reg_data.impl_->fusion_parse_param_by_op_fn_; + parse_params_by_op_func_map_[om_ori_type] = reg_data.impl_->parse_param_by_op_fn_; + remove_input_configure_map_[om_ori_type] = reg_data.impl_->remove_input_configure_vec_; + parse_op_to_graph_fn_map_[om_ori_type] = reg_data.impl_->parse_op_to_graph_fn_; + + if (origin_type_to_om_type_.find(ori_type) == origin_type_to_om_type_.end()) { + origin_type_to_om_type_[ori_type] = reg_data.impl_->om_optype_; + } + } + + if (op_run_mode_map_.find(reg_data.impl_->om_optype_) != op_run_mode_map_.end()) { + GELOGW("The plugin of %s is already registered and will be skipped.", reg_data.impl_->om_optype_.c_str()); + return true; + } + op_run_mode_map_[reg_data.impl_->om_optype_] = reg_data.impl_->imply_type_; + op_types_to_parse_subgraph_post_func_[reg_data.impl_->om_optype_] = reg_data.impl_->parse_subgraph_post_fn_; + op_types_to_parse_subgraph_post_func_v2_[reg_data.impl_->om_optype_] = reg_data.impl_->parse_subgraph_post_fn_v2_; + return true; +} + +domi::ImplyType OpRegistry::GetImplyTypeByOriOpType(const std::string &ori_optype) { + domi::ImplyType result = domi::ImplyType::BUILDIN; + auto iter = origin_type_to_om_type_.find(ori_optype); + if (iter != origin_type_to_om_type_.end()) { + result = GetImplyType(iter->second); + } + return result; +} + +domi::ImplyType OpRegistry::GetImplyType(const std::string &op_type) { + auto it_find = op_run_mode_map_.find(op_type); + if (it_find == op_run_mode_map_.end()) { + return domi::ImplyType::BUILDIN; + } + return it_find->second; +} + +domi::ParseParamByOpFunc OpRegistry::GetParseParamByOperatorFunc(const std::string &ori_type) { + std::string om_type; + auto iter = origin_type_to_om_type_.find(ori_type); + if (iter != origin_type_to_om_type_.end()) { + om_type = iter->second; + } + std::string type = GetParserKey(om_type, ori_type); + auto it_find = parse_params_by_op_func_map_.find(type); + if (it_find == parse_params_by_op_func_map_.end()) { + return nullptr; + } + return it_find->second; +} + +domi::ParseParamFunc OpRegistry::GetParseParamFunc(const std::string &op_type, const std::string &ori_type) { + std::string type = GetParserKey(op_type, ori_type); + auto it_find = op_parse_params_fn_map_.find(type); + if (it_find == op_parse_params_fn_map_.end()) { + return nullptr; + } + return it_find->second; +} + +domi::FusionParseParamFunc OpRegistry::GetFusionParseParamFunc(const std::string &op_type, + const std::string &ori_type) { + std::string type = GetParserKey(op_type, ori_type); + auto it_find = fusion_op_parse_params_fn_map_.find(type); + if (it_find == fusion_op_parse_params_fn_map_.end()) { + return nullptr; + } + return it_find->second; +} + +domi::FusionParseParamByOpFunc OpRegistry::GetFusionParseParamByOpFunc(const std::string &op_type, + const std::string &ori_type) { + std::string type = GetParserKey(op_type, ori_type); + auto it_find = fusion_parse_params_by_op_fn_map_.find(type); + if (it_find == fusion_parse_params_by_op_fn_map_.end()) { + return nullptr; + } + return it_find->second; +} + +domi::ParseSubgraphFunc OpRegistry::GetParseSubgraphPostFunc(const std::string &op_type) { + auto it_find = op_types_to_parse_subgraph_post_func_.find(op_type); + if (it_find == op_types_to_parse_subgraph_post_func_.end()) { + return nullptr; + } + return it_find->second; +} + +Status OpRegistry::GetParseSubgraphPostFunc(const std::string &op_type, + domi::ParseSubgraphFuncV2 &parse_subgraph_func) { + auto it_find = op_types_to_parse_subgraph_post_func_v2_.find(op_type); + if (it_find == op_types_to_parse_subgraph_post_func_v2_.end()) { + return FAILED; + } + parse_subgraph_func = it_find->second; + return SUCCESS; +} + +void OpRegistry::GetOpTypeByImplyType(std::vector &vec_op_type, const domi::ImplyType &imply_type) { + for (auto iter = op_run_mode_map_.begin(); iter != op_run_mode_map_.end(); iter++) { + if (iter->second == imply_type) { + vec_op_type.push_back(iter->first); + } + } + return; +} + +const std::vector &OpRegistry::GetRemoveInputConfigure(const std::string &ori_optype) const { + static const std::vector empty_ = {}; + auto iter = origin_type_to_om_type_.find(ori_optype); + if (iter != origin_type_to_om_type_.end()) { + std::string type = GetParserKey(iter->second, ori_optype); + auto it = remove_input_configure_map_.find(type); + if (it != remove_input_configure_map_.end()) { + return it->second; + } + } + return empty_; +} + +bool OpRegistry::GetOmTypeByOriOpType(const std::string &ori_optype, std::string &om_type) { + auto iter = origin_type_to_om_type_.find(ori_optype); + if (iter != origin_type_to_om_type_.end()) { + om_type = iter->second; + return true; + } + return false; +} + +ParseOpToGraphFunc OpRegistry::GetParseOpToGraphFunc(const std::string &op_type, const std::string &ori_type) { + std::string type = GetParserKey(op_type, ori_type); + auto iter = parse_op_to_graph_fn_map_.find(type); + if (iter == parse_op_to_graph_fn_map_.end()) { + return nullptr; + } + return iter->second; +} +/*lint +e1073*/ +} // namespace domi diff --git a/metadef/register/register_format_transfer.cc b/metadef/register/register_format_transfer.cc new file mode 100644 index 00000000..5d12196d --- /dev/null +++ b/metadef/register/register_format_transfer.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/register_format_transfer.h" + +namespace ge { +namespace formats { +namespace { +struct FormatTransferRegistry { + Status RegisterBuilder(Format src, Format dst, FormatTransferBuilder builder) { + src_dst_builder[src][dst] = std::move(builder); + return SUCCESS; + } + std::map> src_dst_builder; +}; + +FormatTransferRegistry &GetFormatTransferRegistry() { + static FormatTransferRegistry registry; + return registry; +} +} // namespace + +FormatTransferRegister::FormatTransferRegister(FormatTransferBuilder builder, Format src, Format dst) { + (void)GetFormatTransferRegistry().RegisterBuilder(src, dst, std::move(builder)); + // RegisterBuilder() always return success, no need to check value +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr BuildFormatTransfer( + const TransArgs &args) { + auto registry = GetFormatTransferRegistry(); + auto dst_builder = registry.src_dst_builder.find(args.src_format); + if (dst_builder == registry.src_dst_builder.end()) { + return nullptr; + } + auto builder_iter = dst_builder->second.find(args.dst_format); + if (builder_iter == dst_builder->second.end()) { + return nullptr; + } + return builder_iter->second(); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool FormatTransferExists(const TransArgs &args) { + auto registry = GetFormatTransferRegistry(); + auto dst_builder = registry.src_dst_builder.find(args.src_format); + if (dst_builder == registry.src_dst_builder.end()) { + return false; + } + return dst_builder->second.count(args.dst_format) > 0; +} +} // namespace formats +} // namespace ge diff --git a/metadef/register/scope/scope_graph.cc b/metadef/register/scope/scope_graph.cc new file mode 100644 index 00000000..cb320b1c --- /dev/null +++ b/metadef/register/scope/scope_graph.cc @@ -0,0 +1,1161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "register/scope/scope_graph_impl.h" +#include +#include "external/register/register.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/string_util.h" +#include "graph/debug/ge_util.h" +#include "graph/ge_tensor.h" +#include "graph/utils/op_desc_utils.h" + +namespace ge { +namespace { +const char *const kTfIdentityType = "Identity"; +const char *const kTfConstType = "Const"; +const char *const kNumerics = "0123456789"; +} // namespace + +Status Scope::ScopeImpl::Init(const std::string &name, const std::string &sub_type, Scope *father_scope) { + name_ = name; + sub_type_ = sub_type; + father_scope_ = father_scope; + return SUCCESS; +} + +Scope::ScopeImpl::~ScopeImpl() { + for (auto &scope : sub_scopes_) { + if (scope.second != nullptr) { + delete scope.second; + scope.second = nullptr; + } + } +} + +void Scope::ScopeImpl::ClearTypeAndSubType() { + sub_type_ = ""; + const std::vector &sub_scopes = GetAllSubScopes(); + for (auto &sub_scope : sub_scopes) { + auto &impl = sub_scope->impl_; + impl->SetSubType(""); + } +} + +void Scope::ScopeImpl::AddNode(ge::OperatorPtr &node_def) { + if (node_def == nullptr) { + GELOGE(PARAM_INVALID, "Input node_def is nullptr."); + return; + } + + nodes_.push_back(node_def); +} + +const std::unordered_map &Scope::ScopeImpl::AllNodesMap() { + if (!all_nodes_map_.empty()) { + return all_nodes_map_; + } + + if (!nodes_.empty()) { + for (auto node : nodes_) { + all_nodes_map_.insert(std::pair(std::string(node->GetName()), node)); + } + } + const std::vector &scopes = GetAllSubScopes(); + for (auto &scope : scopes) { + auto &impl = scope->impl_; + const std::vector &sub_nodes = impl->Nodes(); + if (!sub_nodes.empty()) { + for (auto sub_node : sub_nodes) { + all_nodes_map_.insert(std::pair(std::string(sub_node->GetName()), sub_node)); + } + } + } + return all_nodes_map_; +} + +Scope *Scope::ScopeImpl::GetSubScope(const std::string &scope_name) const { + auto iter = sub_scopes_.find(scope_name); + if (iter != sub_scopes_.end()) { + return iter->second; + } + return nullptr; +} + +const std::vector &Scope::ScopeImpl::GetAllSubScopes() { + if (!all_sub_scopes_.empty()) { + return all_sub_scopes_; + } + + for (auto &iter : sub_scopes_) { + Scope *scope = iter.second; + all_sub_scopes_.push_back(scope); + + std::stack scopes; + scopes.push(scope); + while (!scopes.empty()) { + Scope *scope = scopes.top(); + scopes.pop(); + auto &impl = scope->impl_; + const std::unordered_map &sub_scopes = impl->GetSubScopes(); + for (auto &iter_sub : sub_scopes) { + all_sub_scopes_.push_back(iter_sub.second); + scopes.push(iter_sub.second); + } + } + } + return all_sub_scopes_; +} + +int32_t Scope::ScopeImpl::GetOpTypeNum(const std::string &op_type) const { + auto iter = op_nums_.find(op_type); + if (iter != op_nums_.end()) { + return iter->second; + } else { + return -1; + } +} + +void Scope::ScopeImpl::OpsNumInc(const std::string &op_type) { + auto iter = op_nums_.find(op_type); + if (iter != op_nums_.end()) { + op_nums_[op_type] = iter->second + 1; + } else { + op_nums_[op_type] = 1; + } +} + +const std::string Scope::ScopeImpl::LastName() const { + std::vector names = ge::StringUtils::Split(name_, '/'); + // if vector size is less than 2, there is no multilevel directory, return origin name. + if (names.size() < 2) { + GELOGI("Input name is already the last name, input name:%s.", name_.c_str()); + return name_; + } + std::string last_name = names[names.size() - 2]; // minus 2 to get the last name + return ScopeImpl::TrimScopeIndex(last_name); +} + +std::string Scope::ScopeImpl::TrimScopeIndex(const std::string &scope_name) { + std::string scope_name_new = scope_name; + // deal D_index, only keep name D + auto index = scope_name.find_last_of("_"); + if (index != std::string::npos) { + // index_str after "_" is integer + std::string index_str = scope_name.substr(index + 1, scope_name.length()); + if (index_str.find_first_not_of(kNumerics) != std::string::npos) { + return scope_name; + } + try { + if (std::stoi(index_str.c_str()) > 0) { + scope_name_new = scope_name.substr(0, index); + } + } catch (std::invalid_argument &e) { + scope_name_new = scope_name; + } catch (std::out_of_range &e) { + scope_name_new = scope_name; + } + } + return scope_name_new; +} + +Scope::Scope() {} + +Status Scope::Init(const std::string &name, const std::string &sub_type, Scope *father_scope) { + impl_ = std::unique_ptr(new (std::nothrow) ScopeImpl); + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Make unique_ptr of ScopeImpl failed."); + return ge::MEMALLOC_FAILED; + } + + return impl_->Init(name, sub_type, father_scope); +} + +Status Scope::Init(const char *name, const char *sub_type, Scope *father_scope) { + std::string scope_name; + std::string scope_sub_type; + if (name != nullptr) { + scope_name = name; + } + if (sub_type != nullptr) { + scope_sub_type = sub_type; + } + impl_ = std::unique_ptr(new (std::nothrow) ScopeImpl); + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Make unique_ptr of ScopeImpl failed."); + return ge::MEMALLOC_FAILED; + } + + return impl_->Init(scope_name, scope_sub_type, father_scope); +} + +Scope::~Scope() {} + +const std::string &Scope::Name() const { + return impl_->Name(); +} + +Status Scope::Name(AscendString &name) const { + name = AscendString(impl_->Name().c_str()); + return SUCCESS; +} + +const std::string &Scope::SubType() const { + return impl_->SubType(); +} + +Status Scope::SubType(AscendString &sub_type) const { + sub_type = AscendString(impl_->SubType().c_str()); + return SUCCESS; +} + +const std::unordered_map &Scope::AllNodesMap() const { + return impl_->AllNodesMap(); +} + +Status Scope::AllNodesMap(std::unordered_map &node_map) const { + std::unordered_map nodes = impl_->AllNodesMap(); + for (auto &node : nodes) { + AscendString tmp(node.first.c_str()); + node_map[tmp] = node.second; + } + return SUCCESS; +} + +Scope *Scope::GetSubScope(const std::string &scope_name) const { + return impl_->GetSubScope(scope_name); +} + +Scope *Scope::GetSubScope(const char *scope_name) const { + std::string str_scope_name; + if (scope_name != nullptr) { + str_scope_name = scope_name; + } + return impl_->GetSubScope(str_scope_name); +} + +const std::string Scope::LastName() const { + return impl_->LastName(); +} + +Status Scope::LastName(AscendString &name) const { + name = AscendString(impl_->LastName().c_str()); + return SUCCESS; +} + +const Scope *Scope::GetFatherScope() const { + return impl_->GetFatherScope(); +} + +const std::vector &Scope::GetAllSubScopes() const { + return impl_->GetAllSubScopes(); +} + +FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::~InnerNodeInfoImpl() { + operator_.BreakConnect(); +} + +std::string FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::GetFullNodeName(const std::string &relative_name) { + if (fusion_node_name_.empty()) { + return relative_name; + } + return (fusion_node_name_.at(fusion_node_name_.size() - 1) == '/') ? (fusion_node_name_ + relative_name) + : (fusion_node_name_ + "/" + relative_name); +} + +void FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::InsertInput(const std::string &input_node, + int32_t peer_out_idx) { + std::string input_name = (input_node != kInputFromFusionScope) ? GetFullNodeName(input_node) : input_node; + inner_node_inputs_.emplace_back(std::make_pair(input_name, peer_out_idx)); +} + +void FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::InsertOutput(const std::string &output_node, + int32_t peer_in_idx) { + std::string output_name = (output_node != kOutputToFusionScope) ? GetFullNodeName(output_node) : output_node; + inner_node_outputs_.emplace_back(std::make_pair(output_name, peer_in_idx)); +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::BuildOperator() { + operator_ = ge::OperatorFactory::CreateOperator(name_, type_); + if (operator_.GetName() != name_) { + GELOGE(ge::GRAPH_FAILED, "IR for op is not registered, op name:%s, op type:%s", name_.c_str(), type_.c_str()); + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::SetInputFormat(const std::string &input_name, + const std::string &format) { + ge::TensorDesc input_tesor_desc = operator_.GetInputDesc(input_name); + auto ge_format = ge::TypeUtils::SerialStringToFormat(format); + input_tesor_desc.SetOriginFormat(ge_format); + input_tesor_desc.SetFormat(ge_format); + return operator_.UpdateInputDesc(input_name, input_tesor_desc); +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::SetOutputFormat(const std::string &output_name, + const std::string &format) { + ge::TensorDesc output_tesor_desc = operator_.GetOutputDesc(output_name); + auto ge_format = ge::TypeUtils::SerialStringToFormat(format); + output_tesor_desc.SetOriginFormat(ge_format); + output_tesor_desc.SetFormat(ge_format); + return operator_.UpdateOutputDesc(output_name, output_tesor_desc); +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::SetDynamicInputFormat( + const std::string &input_name, uint32_t index, const std::string &format) { + ge::TensorDesc input_tesor_desc = operator_.GetDynamicInputDesc(input_name, index); + auto ge_format = ge::TypeUtils::SerialStringToFormat(format); + input_tesor_desc.SetOriginFormat(ge_format); + input_tesor_desc.SetFormat(ge_format); + return operator_.UpdateDynamicInputDesc(input_name, index, input_tesor_desc); +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::SetDynamicOutputFormat( + const std::string &output_name, uint32_t index, const std::string &format) { + ge::TensorDesc output_tesor_desc = operator_.GetDynamicOutputDesc(output_name, index); + auto ge_format = ge::TypeUtils::SerialStringToFormat(format); + output_tesor_desc.SetOriginFormat(ge_format); + output_tesor_desc.SetFormat(ge_format); + return operator_.UpdateDynamicOutputDesc(output_name, index, output_tesor_desc); +} + +FusionScopesResult::InnerNodeInfo::InnerNodeInfo(const std::string &fusion_node_name) { + impl_ = std::unique_ptr(new (std::nothrow) InnerNodeInfoImpl(fusion_node_name)); +} + +FusionScopesResult::InnerNodeInfo::InnerNodeInfo(const char *fusion_node_name) { + std::string str_fusion_node_name; + if (fusion_node_name != nullptr) { + str_fusion_node_name = fusion_node_name; + } + impl_ = std::unique_ptr(new (std::nothrow) InnerNodeInfoImpl(str_fusion_node_name)); +} + +FusionScopesResult::InnerNodeInfo::InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, + const std::string &type) { + impl_ = std::unique_ptr(new (std::nothrow) InnerNodeInfoImpl(fusion_node_name, name, type)); +} + +FusionScopesResult::InnerNodeInfo::InnerNodeInfo(const char *fusion_node_name, const char *name, + const char *type) { + std::string node_name; + if (fusion_node_name != nullptr) { + node_name = fusion_node_name; + } + std::string str_name; + if (name != nullptr) { + str_name = name; + } + std::string str_type; + if (type != nullptr) { + str_type = type; + } + impl_ = std::unique_ptr(new (std::nothrow) InnerNodeInfoImpl(node_name, + str_name, str_type)); +} + +FusionScopesResult::InnerNodeInfo::InnerNodeInfo(FusionScopesResult::InnerNodeInfo &&other) noexcept + : impl_(std::move(other.impl_)) {} + +FusionScopesResult::InnerNodeInfo::~InnerNodeInfo() {} + +FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::operator=( + FusionScopesResult::InnerNodeInfo &&other) noexcept { + if (&other != this) { + impl_ = std::move(other.impl_); + } + return *this; +} + +FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::SetName(const std::string &name) { + if (impl_ != nullptr) { + impl_->SetName(name); + } + return *this; +} + +FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::SetName(const char *name) { + if (impl_ != nullptr && name != nullptr) { + std::string str_name = name; + impl_->SetName(str_name); + } + return *this; +} + +FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::SetType(const std::string &type) { + if (impl_ != nullptr) { + impl_->SetType(type); + } + return *this; +} + +FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::SetType(const char *type) { + if (impl_ != nullptr && type != nullptr) { + std::string str_type = type; + impl_->SetType(str_type); + } + return *this; +} + +FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::InsertInput(const std::string &input_node, + int32_t peer_out_idx) { + if (impl_ != nullptr) { + impl_->InsertInput(input_node, peer_out_idx); + } + return *this; +} + +FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::InsertInput(const char *input_node, + int32_t peer_out_idx) { + if (impl_ != nullptr && input_node != nullptr) { + std::string str_input_node = input_node; + impl_->InsertInput(str_input_node, peer_out_idx); + } + return *this; +} + +FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::InsertOutput(const std::string &output_node, + int32_t peer_in_idx) { + if (impl_ != nullptr) { + impl_->InsertOutput(output_node, peer_in_idx); + } + return *this; +} + +FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::InsertOutput(const char *output_node, + int32_t peer_in_idx) { + if (impl_ != nullptr && output_node != nullptr) { + std::string str_output_node = output_node; + impl_->InsertOutput(str_output_node, peer_in_idx); + } + return *this; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::BuildInnerNode() { + if (impl_ != nullptr) { + return impl_->BuildOperator(); + } + return ge::GRAPH_PARAM_INVALID; +} + +ge::Operator *FusionScopesResult::InnerNodeInfo::MutableOperator() { + if (impl_ != nullptr) { + return impl_->MutableOperator(); + } + return nullptr; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::SetInputFormat(const std::string &input_name, + const std::string &format) { + if (impl_ != nullptr) { + return impl_->SetInputFormat(input_name, format); + } + return ge::GRAPH_PARAM_INVALID; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::SetInputFormat(const char *input_name, + const char *format) { + if (impl_ != nullptr && input_name != nullptr && format != nullptr) { + std::string str_input_name = input_name; + std::string str_format = format; + return impl_->SetInputFormat(str_input_name, str_format); + } + return ge::GRAPH_PARAM_INVALID; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::SetOutputFormat(const std::string &output_name, + const std::string &format) { + if (impl_ != nullptr) { + return impl_->SetOutputFormat(output_name, format); + } + return ge::GRAPH_PARAM_INVALID; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::SetOutputFormat(const char *output_name, + const char *format) { + if (impl_ != nullptr && output_name != nullptr && format != nullptr) { + std::string str_output_name = output_name; + std::string str_format = format; + return impl_->SetOutputFormat(str_output_name, str_format); + } + return ge::GRAPH_PARAM_INVALID; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::SetDynamicInputFormat(const std::string &input_name, uint32_t index, + const std::string &format) { + if (impl_ != nullptr) { + return impl_->SetDynamicInputFormat(input_name, index, format); + } + return ge::GRAPH_PARAM_INVALID; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::SetDynamicInputFormat(const char *input_name, uint32_t index, + const char *format) { + if (impl_ != nullptr && input_name != nullptr && format != nullptr) { + std::string str_input_name = input_name; + std::string str_format = format; + return impl_->SetDynamicInputFormat(str_input_name, index, str_format); + } + return ge::GRAPH_PARAM_INVALID; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::SetDynamicOutputFormat(const std::string &output_name, + uint32_t index, const std::string &format) { + if (impl_ != nullptr) { + return impl_->SetDynamicOutputFormat(output_name, index, format); + } + return ge::GRAPH_PARAM_INVALID; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::SetDynamicOutputFormat(const char *output_name, + uint32_t index, const char *format) { + if (impl_ != nullptr && output_name != nullptr && format != nullptr) { + std::string str_output_name = output_name; + std::string str_format = format; + return impl_->SetDynamicOutputFormat(str_output_name, index, str_format); + } + return ge::GRAPH_PARAM_INVALID; +} + +std::string FusionScopesResult::InnerNodeInfo::GetName() const { + if (impl_ != nullptr) { + return impl_->GetName(); + } + return ""; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::GetName(AscendString &name) const { + if (impl_ != nullptr) { + name = AscendString(impl_->GetName().c_str()); + } + return GRAPH_SUCCESS; +} + +std::string FusionScopesResult::InnerNodeInfo::GetType() const { + if (impl_ != nullptr) { + return impl_->GetType(); + } + return ""; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::GetType(AscendString &type) const { + if (impl_ != nullptr) { + type = AscendString(impl_->GetType().c_str()); + } + return GRAPH_SUCCESS; +} + +std::vector> FusionScopesResult::InnerNodeInfo::GetInputs() const { + std::vector> tmp; + if (impl_ != nullptr) { + return impl_->GetInputs(); + } + return tmp; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::GetInputs( + std::vector> &inputs) const { + std::vector> tmps; + if (impl_ != nullptr) { + tmps = impl_->GetInputs(); + } + for (auto &tmp : tmps) { + inputs.emplace_back(std::pair(AscendString(tmp.first.c_str()), tmp.second)); + } + return GRAPH_SUCCESS; +} + +std::vector> FusionScopesResult::InnerNodeInfo::GetOutputs() const { + std::vector> tmp; + if (impl_ != nullptr) { + return impl_->GetOutputs(); + } + return tmp; +} + +ge::graphStatus FusionScopesResult::InnerNodeInfo::GetOutputs( + std::vector> &outputs) const { + std::vector> tmps; + if (impl_ != nullptr) { + tmps = impl_->GetOutputs(); + } + for (auto &tmp : tmps) { + outputs.emplace_back(std::pair(tmp.first.c_str(), tmp.second)); + } + return GRAPH_SUCCESS; +} + +void FusionScopesResult::FusionScopesResultImpl::AddNodes(std::vector nodes) { + nodes_.insert(nodes_.end(), nodes.begin(), nodes.end()); +} + +void FusionScopesResult::FusionScopesResultImpl::InsertInputs(const std::string &inner_op_name, + const std::vector &index_map) { + inputs_.insert(make_pair(inner_op_name, index_map)); +} +void FusionScopesResult::FusionScopesResultImpl::InsertOutputs(const std::string &inner_op_name, + const std::vector &index_map) { + outputs_.insert(make_pair(inner_op_name, index_map)); +} + +bool FusionScopesResult::FusionScopesResultImpl::FindNodes(const std::string &node_name) const { + for (auto &node : nodes_) { + if (node->GetName() == node_name) { + return true; + } + } + return false; +} + +bool FusionScopesResult::FusionScopesResultImpl::FindScopes(const std::string &scope_name) const { + for (auto &scope : scopes_) { + if (scope->Name().length() < scope_name.length() && scope_name.find(scope->Name()) == 0) { + return true; + } + } + return false; +} + +FusionScopesResult::InnerNodeInfo *FusionScopesResult::FusionScopesResultImpl::AddInnerNode(const std::string &name, + const std::string &type) { + inner_node_infos_.emplace_back(InnerNodeInfo(name_, name, type)); + return &(inner_node_infos_[inner_node_infos_.size() - 1]); +} + +FusionScopesResult::InnerNodeInfo *FusionScopesResult::FusionScopesResultImpl::MutableRecentInnerNode() { + size_t size = inner_node_infos_.size(); + if (size >= 1) { + return &(inner_node_infos_[size - 1]); + } + return nullptr; +} + +FusionScopesResult::InnerNodeInfo *FusionScopesResult::FusionScopesResultImpl::MutableInnerNode(uint32_t index) { + if (index < inner_node_infos_.size()) { + return &(inner_node_infos_[index]); + } + return nullptr; +} + +FusionInnerNodesInfo FusionScopesResult::FusionScopesResultImpl::GetInnerNodesInfo() { + FusionInnerNodesInfo nodes_info; + for (auto &info : inner_node_infos_) { + nodes_info.emplace_back( + std::make_tuple(info.GetName(), info.GetType(), info.GetInputs(), info.GetOutputs(), info.MutableOperator())); + } + return nodes_info; +} + +ge::graphStatus FusionScopesResult::FusionScopesResultImpl::CheckInnerNodesInfo() { + size_t input_from_scope = 0; + size_t output_to_scope = 0; + std::set name_set; + for (const auto &info : inner_node_infos_) { + if (!(name_set.insert(info.GetName()).second)) { + GELOGE(ge::GRAPH_PARAM_INVALID, "There are duplicate internal node name, please check."); + return ge::GRAPH_PARAM_INVALID; + } + for (auto input : info.GetInputs()) { + input_from_scope += (input.first == kInputFromFusionScope) ? 1 : 0; + } + for (auto input : info.GetOutputs()) { + output_to_scope += (input.first == kOutputToFusionScope) ? 1 : 0; + } + } + size_t scope_input = 0; + size_t scope_output = 0; + for (const auto &input : inputs_) { + for (const auto &idx : input.second) { + scope_input += (idx != kFusionDisableIndex) ? 1 : 0; + } + } + for (const auto &output : outputs_) { + for (const auto &idx : output.second) { + scope_output += (idx != kFusionDisableIndex) ? 1 : 0; + } + } + if ((input_from_scope != scope_input) || (output_to_scope != scope_output)) { + GELOGE(ge::GRAPH_PARAM_INVALID, + "Input or Output mismatched, please check. " + "Inner input_from_scope:%zu, scope input:%zu, " + "inner output_to_scope:%zu, scope output:%zu.", + input_from_scope, scope_input, output_to_scope, scope_output); + return ge::GRAPH_PARAM_INVALID; + } + return ge::GRAPH_SUCCESS; +} + +FusionScopesResult::FusionScopesResult() {} + +Status FusionScopesResult::Init() { + impl_ = std::unique_ptr(new (std::nothrow) FusionScopesResultImpl); + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Make unique_ptr of FusionScopesResultImpl failed."); + return ge::MEMALLOC_FAILED; + } + + return SUCCESS; +} + +FusionScopesResult::~FusionScopesResult() {} + +void FusionScopesResult::SetName(const std::string &name) { + impl_->SetName(name); +} + +void FusionScopesResult::SetName(const char *name) { + std::string str_name; + if (name != nullptr) { + str_name = name; + } + impl_->SetName(str_name); +} + +void FusionScopesResult::SetType(const std::string &type) { + impl_->SetType(type); +} + +void FusionScopesResult::SetType(const char *type) { + std::string str_type; + if (type != nullptr) { + str_type = type; + } + impl_->SetType(str_type); +} + +void FusionScopesResult::SetDescription(const std::string &description) { + impl_->SetDescription(description); +} + +void FusionScopesResult::SetDescription(const char *description) { + std::string str_desc; + if (description != nullptr) { + str_desc = description; + } + impl_->SetDescription(str_desc); +} + +const std::string &FusionScopesResult::Name() const { + return impl_->Name(); +} + +const Status FusionScopesResult::Name(AscendString &name) const { + name = AscendString(impl_->Name().c_str()); + return SUCCESS; +} + +const std::vector &FusionScopesResult::Nodes() const { + return impl_->Nodes(); +} + +void FusionScopesResult::InsertInputs(const std::string &inner_op_name, const std::vector &index_map) { + impl_->InsertInputs(inner_op_name, index_map); +} + +void FusionScopesResult::InsertInputs(const char *inner_op_name, const std::vector &index_map) { + std::string op_name; + if (inner_op_name != nullptr) { + op_name = inner_op_name; + } + impl_->InsertInputs(op_name, index_map); +} + +void FusionScopesResult::InsertOutputs(const std::string &inner_op_name, const std::vector &index_map) { + impl_->InsertOutputs(inner_op_name, index_map); +} + +void FusionScopesResult::InsertOutputs(const char *inner_op_name, const std::vector &index_map) { + std::string op_name; + if (inner_op_name != nullptr) { + op_name = inner_op_name; + } + impl_->InsertOutputs(op_name, index_map); +} + +FusionScopesResult::InnerNodeInfo *FusionScopesResult::AddInnerNode(const std::string &name, const std::string &type) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); + return nullptr; + } + return impl_->AddInnerNode(name, type); +} + +FusionScopesResult::InnerNodeInfo *FusionScopesResult::AddInnerNode(const char *name, const char *type) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); + return nullptr; + } + std::string str_name; + if (name != nullptr) { + str_name = name; + } + std::string str_type; + if (type != nullptr) { + str_type = type; + } + return impl_->AddInnerNode(str_name, str_type); +} + +FusionScopesResult::InnerNodeInfo *FusionScopesResult::MutableRecentInnerNode() { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); + return nullptr; + } + return impl_->MutableRecentInnerNode(); +} + +FusionScopesResult::InnerNodeInfo *FusionScopesResult::MutableInnerNode(uint32_t index) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); + return nullptr; + } + return impl_->MutableInnerNode(index); +} + +ge::graphStatus FusionScopesResult::CheckInnerNodesInfo() { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); + return ge::GRAPH_PARAM_INVALID; + } + return impl_->CheckInnerNodesInfo(); +} + +Status ScopeTree::ScopeTreeImpl::Init() { + root_ = new (std::nothrow) Scope(); + if (root_ == nullptr) { + GELOGE(FAILED, "Alloc root scope failed."); + return FAILED; + } + if (root_->Init("root") != SUCCESS) { + GELOGE(FAILED, "Init root scope failed."); + return FAILED; + } + scopes_.push_back(root_); + return SUCCESS; +} + +ScopeTree::ScopeTreeImpl::~ScopeTreeImpl() { + if (root_ != nullptr) { + delete root_; + root_ = nullptr; + } +} + +void ScopeTree::ScopeTreeImpl::AddNodeToScope(ge::OperatorPtr &node_def) { + if (node_def == nullptr) { + GELOGE(PARAM_INVALID, "Input node_def is nullptr."); + return; + } + const std::string &node_name = node_def->GetName(); + Scope *super_scope = root_; + + std::vector scopes = SplitNodeName(node_name, '/'); + for (uint32_t i = 0; i < scopes.size(); ++i) { + auto &impl = super_scope->impl_; + impl->OpsNumInc(node_def->GetOpType()); + + if (i == (scopes.size() - 1)) { + impl->AddNode(node_def); + } else { + Scope *sub_scope = impl->GetSubScope(scopes[i]); + if (sub_scope == nullptr) { + sub_scope = new (std::nothrow) Scope(); + if (sub_scope == nullptr) { + GELOGE(FAILED, "Alloc Scope failed."); + return; + } + if (sub_scope->Init(scopes[i], "", super_scope) != SUCCESS) { + GELOGE(FAILED, "Init Scope failed."); + delete sub_scope; + sub_scope = nullptr; + return; + } + scopes_.push_back(sub_scope); + impl->AddSubScope(sub_scope); + } + super_scope = sub_scope; + } + } +} + +std::vector ScopeTree::ScopeTreeImpl::SplitNodeName(const std::string &node_name, const char delim) const { + std::vector items; + std::vector scopes; + if (node_name == "") return items; + + items = ge::StringUtils::Split(node_name, delim); + std::string scope; + for (uint32_t i = 0; i < items.size(); ++i) { + if (items[i].length() == 0) { + continue; + } + + if (i == 0) { + scope = items[i]; + } else { + scope = scope + items[i]; + } + + if (i != (items.size() - 1)) { + scope = scope + delim; + } + + scopes.push_back(scope); + } + + return scopes; +} + +ScopeTree::ScopeTree() {} + +ScopeTree::~ScopeTree() {} + +Status ScopeTree::Init() { + impl_ = std::unique_ptr(new (std::nothrow) ScopeTreeImpl); + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Make unique_ptr of FusionScopesResultImpl failed."); + return ge::MEMALLOC_FAILED; + } + return impl_->Init(); +} + +const std::vector &ScopeTree::GetAllScopes() const { + return impl_->GetAllScopes(); +} + +Status ScopeGraph::ScopeGraphImpl::Init() { + scope_tree_ = new (std::nothrow) ScopeTree(); + if (scope_tree_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Alloc scope tree failed."); + return ge::MEMALLOC_FAILED; + } + Status ret = scope_tree_->Init(); + if (ret != SUCCESS) { + GELOGE(FAILED, "Scope tree init failed."); + return FAILED; + } + return SUCCESS; +} + +ScopeGraph::ScopeGraphImpl::~ScopeGraphImpl() { + if (scope_tree_ != nullptr) { + delete scope_tree_; + scope_tree_ = nullptr; + } + + for (auto &fusion_result : fusion_results_) { + if (fusion_result.second != nullptr) { + delete fusion_result.second; + fusion_result.second = nullptr; + } + } + + for (auto item : nodes_map_) { + item.second->BreakConnect(); + } +} + +void ScopeGraph::ScopeGraphImpl::BuildScopeGraph(domi::tensorflow::GraphDef *graph_def) { + if (graph_def == nullptr) { + GELOGE(PARAM_INVALID, "Input graph_def is nullptr."); + return; + } + + for (int i = 0; i < graph_def->node_size(); ++i) { + const domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i); + ge::OperatorPtr op(new (std::nothrow) ge::Operator(node_def->name(), node_def->op())); + if (op == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Make shared_ptr falied."); + return; + } + auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(*op); + Status ret = domi::AutoMappingFn(node_def, *op); + if (ret != SUCCESS) { + GELOGE(FAILED, "Op: %s call auto mapping function failed.", op_desc->GetName().c_str()); + return; + } + + for (int i = 0; i < node_def->input_size(); i++) { + ge::GeTensorDesc tensor_desc; + tensor_desc.SetName(node_def->input(i)); + op_desc->AddInputDesc(tensor_desc); + } + + nodes_map_.emplace(op->GetName(), op); + if (op->GetOpType() != kTfIdentityType || op->GetOpType() != kTfConstType) { + auto &impl = scope_tree_->impl_; + impl->AddNodeToScope(op); + } + } +} + +void ScopeGraph::ScopeGraphImpl::AddFusionScopesResult(FusionScopesResult *result) { + if (result == nullptr) { + GELOGE(PARAM_INVALID, "Input params invalid, result is nullptr."); + return; + } + fusion_results_[result->Name()] = result; +} + +bool ScopeGraph::ScopeGraphImpl::IsFusionOpChild(const std::string &node_name, + std::vector &info_list) { + bool find = false; + for (auto &fusion_result : fusion_results_) { + FusionScopesResult *fusion_node = fusion_result.second; + auto &impl = fusion_node->impl_; + + if (impl->FindNodes(node_name) || impl->FindScopes(node_name)) { + ScopeFusionOpInfo info; + info.fusion_node_name = fusion_node->Name(); + info.fusion_op_type = impl->Type(); + info.node_name = node_name; + info.description = impl->Description(); + info.scope_pass = true; + info_list.push_back(info); + + find = true; + } + } + + return find; +} + +bool ScopeGraph::ScopeGraphImpl::FusionOpChildIgnore(const ScopeFusionOpInfo &info) { + if (!(GetFusionResultInputOrOutput(info, true).empty()) || !(GetFusionResultInputOrOutput(info, false).empty())) { + return false; + } + return true; +} + +std::vector ScopeGraph::ScopeGraphImpl::GetFusionResultInputOrOutput(const ScopeFusionOpInfo &info, + bool input) { + std::vector indexs; + auto fusion_iter = fusion_results_.find(info.fusion_node_name); + if (fusion_iter == fusion_results_.end()) { + GELOGE(FAILED, "Get fusion result failed, not found node:%s", info.fusion_node_name.c_str()); + return indexs; + } + + FusionScopesResult *fusion_node = fusion_iter->second; + std::unordered_map> inout_map; + auto &impl = fusion_node->impl_; + if (input) { + inout_map = impl->GetInputs(); + } else { + inout_map = impl->GetOutputs(); + } + + for (auto &iter : inout_map) { + std::string input_name = iter.first; + std::string op_name = (info.node_name.length() > input_name.length()) + ? info.node_name.substr(info.node_name.length() - input_name.length()) + : info.node_name; + if (input_name == op_name) { + indexs.insert(indexs.end(), iter.second.begin(), iter.second.end()); + break; + } + } + + return indexs; +} + +bool ScopeGraph::ScopeGraphImpl::IsFusionOp(const domi::tensorflow::NodeDef *node_def) { + if (node_def == nullptr) { + GELOGE(PARAM_INVALID, "Input node_def is nullptr."); + return false; + } + for (auto &fusion_result : fusion_results_) { + FusionScopesResult *fusion_node = fusion_result.second; + auto &impl = fusion_node->impl_; + if (impl->Type() == node_def->op() && fusion_node->Name() == node_def->name()) { + return true; + } + } + return false; +} + +Status ScopeGraph::ScopeGraphImpl::GetInputOrOutputIndex(const ScopeFusionOpInfo &info, int32_t old_index, + bool input, int32_t &new_index) { + if (old_index == -1) { + new_index = -1; + return SUCCESS; + } + + std::vector indexs = GetFusionResultInputOrOutput(info, input); + GELOGD("GetNodeindex, node_name:%s, fusion_node_name:%s, fusion_op_type:%s, old_index:%d, size:%zu.", + info.node_name.c_str(), info.fusion_node_name.c_str(), info.fusion_op_type.c_str(), old_index, indexs.size()); + if ((int32_t)indexs.size() < (old_index + 1)) { + GELOGD("GetNodeindex fusionDisableIndex, node_name:%s, fusion_node_name:%s, fusion_op_type:%s, old_index:%d .", + info.node_name.c_str(), info.fusion_node_name.c_str(), info.fusion_op_type.c_str(), old_index); + new_index = kFusionDisableIndex; + } else { + new_index = indexs[old_index]; + } + GELOGD("RESULT: new index:%d.", new_index); + return SUCCESS; +} + +FusionScopesResult *ScopeGraph::ScopeGraphImpl::GetFusionScopesResults( + const domi::tensorflow::NodeDef *node_def) const { + if (node_def == nullptr) { + return nullptr; + } + return GetFusionScopesResults(node_def->name()); +} + +FusionScopesResult *ScopeGraph::ScopeGraphImpl::GetFusionScopesResults(const string &node_name) const { + auto iter = fusion_results_.find(node_name); + if (iter != fusion_results_.end()) { + return iter->second; + } else { + return nullptr; + } +} + +ScopeGraph::ScopeGraph() {} + +ScopeGraph::~ScopeGraph() {} + +Status ScopeGraph::Init() { + impl_ = std::unique_ptr(new (std::nothrow) ScopeGraphImpl); + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Make unique_ptr of ScopeGraphImpl failed."); + return ge::MEMALLOC_FAILED; + } + return impl_->Init(); +} + +const ScopeTree *ScopeGraph::GetScopeTree() const { + return impl_->GetScopeTree(); +} + +const std::unordered_map &ScopeGraph::GetNodesMap() const { + return impl_->GetNodesMap(); +} + +Status ScopeGraph::GetNodesMap(std::unordered_map &nodes_map) const { + std::unordered_map tmps; + if (impl_ != nullptr) { + tmps = impl_->GetNodesMap(); + } + for (auto &tmp : tmps) { + AscendString node(tmp.first.c_str()); + nodes_map[node] = tmp.second; + } + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/metadef/register/scope/scope_pass.cc b/metadef/register/scope/scope_pass.cc new file mode 100644 index 00000000..ab06e794 --- /dev/null +++ b/metadef/register/scope/scope_pass.cc @@ -0,0 +1,327 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "register/scope/scope_pass_impl.h" +#include +#include +#include "register/scope/scope_graph_impl.h" +#include "register/scope/scope_pattern_impl.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_util.h" + +namespace ge { +ScopesResult::ScopesResult() { + impl_ = std::unique_ptr(new (std::nothrow) ScopesResultImpl); +} + +ScopesResult::ScopesResult(ScopesResult const &result) { + impl_ = std::unique_ptr(new (std::nothrow) ScopesResultImpl); + const std::vector &scopes = result.impl_->GetScopes(); + const std::vector &nodes = result.impl_->GetNodes(); + impl_->SetScopes(scopes); + impl_->SetNodes(nodes); +} +ScopesResult &ScopesResult::operator=(ScopesResult const &result) { + if (&result == this) { + return *this; + } + + const std::vector &scopes = result.impl_->GetScopes(); + const std::vector &nodes = result.impl_->GetNodes(); + impl_->SetScopes(scopes); + impl_->SetNodes(nodes); + return *this; +} + +ScopesResult::~ScopesResult() {} + +void ScopesResult::SetScopes(std::vector &scopes) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetScopes(), ScopesResult is not properly initialized."); + return; + } + + impl_->SetScopes(scopes); +} + +void ScopesResult::SetNodes(std::vector &nodes) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetNodes(), ScopesResult is not properly initialized."); + return; + } + + impl_->SetNodes(nodes); +} + +ScopeBasePass::ScopeBasePassImpl::~ScopeBasePassImpl() { + for (auto &scope_patterns : patterns_) { + for (auto &batch_patterns : scope_patterns) { + for (auto &pattern : batch_patterns) { + if (pattern != nullptr) { + delete pattern; + pattern = nullptr; + } + } + } + } +} + +Status ScopeBasePass::ScopeBasePassImpl::AddFusionScopesResultToScopeGraph(std::shared_ptr &scope_graph, + std::vector &scope_results) { + for (auto &rlt : scope_results) { + FusionScopesResult *fusion_rlt = new (std::nothrow) FusionScopesResult(); + if (fusion_rlt == nullptr) { + GELOGE(FAILED, "Alloc fusion_rlt failed."); + return FAILED; + } + if (fusion_rlt->Init() != SUCCESS) { + GELOGE(FAILED, "Init fusion_rlt failed."); + delete fusion_rlt; + fusion_rlt = nullptr; + return FAILED; + } + auto &impl_fusion_rlt = fusion_rlt->impl_; + auto &impl_scope_rlt = rlt.impl_; + if (impl_scope_rlt == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "ScopesResult is not properly initialized."); + delete fusion_rlt; + fusion_rlt = nullptr; + continue; + } + + impl_fusion_rlt->AddNodes(impl_scope_rlt->GetNodes()); + impl_fusion_rlt->AddScopes(impl_scope_rlt->GetScopes()); + + parent_->GenerateFusionResult(impl_scope_rlt->GetScopes(), fusion_rlt); + if (impl_fusion_rlt->Type() == kScopeInvalidType) { + GELOGE(FAILED, "Failed to set inner node for fusion op %s.", impl_fusion_rlt->Type().c_str()); + delete fusion_rlt; + return FAILED; + } + auto &impl_scope_graph = scope_graph->impl_; + impl_scope_graph->AddFusionScopesResult(fusion_rlt); + } + + return SUCCESS; +} + +Status ScopeBasePass::ScopeBasePassImpl::Run(std::shared_ptr &scope_graph) { + GE_CHECK_NOTNULL(scope_graph); + const ScopeTree *scope_tree = scope_graph->GetScopeTree(); + GE_CHECK_NOTNULL(scope_tree); + GE_CHECK_NOTNULL(parent_); + patterns_ = parent_->DefinePatterns(); + std::vector results; + if (!MatchAllBatches(scope_tree, results)) { + GELOGI("[scope_fusion] Scope pass %s's patterns is not matched and ignored.", parent_->PassName().c_str()); + return domi::SCOPE_NOT_CHANGED; + } + GELOGI("[scope_fusion] Scope pass %s's patterns is matched.", parent_->PassName().c_str()); + + std::vector scope_results; + Status ret = parent_->LastMatchScopesAndOPs(scope_graph, scope_results); + if (ret != SUCCESS) { + for (auto &result : results) { + GE_CHECK_NOTNULL(result); + auto &impl_scope = result->impl_; + impl_scope->ClearTypeAndSubType(); + } + GELOGW("[scope_fusion] Scope pass %s's patterns is ignored, because LastMatchScopesAndOPs failed.", + parent_->PassName().c_str()); + return domi::SCOPE_NOT_CHANGED; + } + + if (!results.empty()) { + ret = AddFusionScopesResultToScopeGraph(scope_graph, scope_results); + if (ret != SUCCESS) { + GELOGE(FAILED, "Scope pass %s add fusion scopes result to scope graph failed.", parent_->PassName().c_str()); + return FAILED; + } + } else { + GELOGI("[scope_fusion] Scope pass %s not match any scope.", parent_->PassName().c_str()); + } + + ret = PrintFusionScopeInfo(scope_graph); + if (ret != SUCCESS) { + GELOGI("[scope_fusion] Print scope pass %s fusion info failed.", parent_->PassName().c_str()); + return FAILED; + } + + return SUCCESS; +} + +bool ScopeBasePass::ScopeBasePassImpl::MatchAllBatches(const ScopeTree *scope_tree, std::vector &results) { + if (scope_tree == nullptr) { + GELOGE(PARAM_INVALID, "Input param [scope_tree] is nullptr."); + return false; + } + + for (auto &scope_patterns : patterns_) { + std::vector tmp_results; + std::vector last_results; + uint32_t batch_num = 0; + for (auto &batch_patterns : scope_patterns) { + ++batch_num; + std::vector one_results; + bool is_matched = MatchOneBatch(scope_tree, batch_patterns, one_results); + if (!is_matched) { + break; + } + if (batch_num == scope_patterns.size()) { + last_results.insert(last_results.end(), one_results.begin(), one_results.end()); + } else { + tmp_results.insert(tmp_results.end(), one_results.begin(), one_results.end()); + } + } + for (auto &tmp : tmp_results) { + bool rollback = true; + for (auto &result : last_results) { + if ((result->Name().length() <= tmp->Name().length()) && (tmp->Name().find(result->Name()) == 0)) { + rollback = false; + break; + } + } + if (rollback) { + auto &impl = tmp->impl_; + impl->SetSubType(""); + } + } + results.insert(results.end(), last_results.begin(), last_results.end()); + } + + return !(results.empty()); +} + +bool ScopeBasePass::ScopeBasePassImpl::MatchOneBatch(const ScopeTree *scope_tree, + const std::vector &patternlist, + std::vector &results) { + if (scope_tree == nullptr) { + GELOGE(PARAM_INVALID, "Input param [scope_tree] is nullptr"); + return false; + } + + int32_t find = 0; + auto &impl_scope_tree = scope_tree->impl_; + const Scope *root = impl_scope_tree->Root(); + if (root != nullptr) { + auto &impl_scope = root->impl_; + const std::unordered_map &sub_scopes = impl_scope->GetSubScopes(); + for (auto &pattern : patternlist) { + for (auto &scope : sub_scopes) { + if (MatchOneScope(pattern, scope.second, results)) { + ++find; + } + } + } + } + + return find > 0 ? true : false; +} + +bool ScopeBasePass::ScopeBasePassImpl::MatchOneScope(const ScopePattern *pattern, Scope *scope, + std::vector &results) { + if (pattern == nullptr || scope == nullptr) { + GELOGE(PARAM_INVALID, "Input param is nullptr"); + return false; + } + auto &impl_scope_pattern = pattern->impl_; + if (impl_scope_pattern == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "ScopePattern is not properly initialized."); + return false; + } + if (impl_scope_pattern->Match(scope)) { + auto &scope_impl = scope->impl_; + scope_impl->SetSubType(impl_scope_pattern->SubType()); + results.push_back(scope); + return true; + } + int32_t find = 0; + std::stack scopes; + scopes.push(scope); + while (!scopes.empty()) { + Scope *current_scope = scopes.top(); + scopes.pop(); + auto ¤t_scope_impl = current_scope->impl_; + const std::unordered_map &sub_scopes = current_scope_impl->GetSubScopes(); + for (auto &sub_scope : sub_scopes) { + if (impl_scope_pattern->Match(sub_scope.second)) { + auto &sub_scope_impl = sub_scope.second->impl_; + sub_scope_impl->SetSubType(impl_scope_pattern->SubType()); + results.push_back(sub_scope.second); + ++find; + } else { + scopes.push(sub_scope.second); + } + } + } + return find > 0 ? true : false; +} + +Status ScopeBasePass::ScopeBasePassImpl::PrintFusionScopeInfo(std::shared_ptr &scope_graph) { + if (scope_graph == nullptr) { + GELOGE(PARAM_INVALID, "Input param scope_graph is nullptr."); + return PARAM_INVALID; + } + auto &impl_scope_graph = scope_graph->impl_; + const std::unordered_map &final_results = impl_scope_graph->FusionScopesResults(); + for (auto &result : final_results) { + if (result.second == nullptr) { + GELOGE(PARAM_INVALID, "Fusion scope is nullptr."); + return PARAM_INVALID; + } + GELOGI("FusionScope:%s", result.second->Name().c_str()); + auto &impl = result.second->impl_; + const std::unordered_map> &inputs = impl->GetInputs(); + for (auto &input : inputs) { + std::vector indexs = input.second; + for (int32_t index : indexs) { + GELOGI("FusionScope input node:%s,%d", input.first.c_str(), index); + } + } + + const std::unordered_map> &outputs = impl->GetOutputs(); + for (auto &output : outputs) { + std::vector indexs = output.second; + for (int32_t index : indexs) { + GELOGI("FusionScope output node:%s,%d", output.first.c_str(), index); + } + } + + for (auto &scope : impl->Scopes()) { + if (scope == nullptr) { + GELOGE(PARAM_INVALID, "Scope in fusion scope is nullptr."); + return PARAM_INVALID; + } + GELOGI("FusionScope GetScope:%s", scope->Name().c_str()); + } + + for (auto &node : result.second->Nodes()) { + if (node == nullptr) { + GELOGE(PARAM_INVALID, "Node in scope is nullptr."); + return PARAM_INVALID; + } + GELOGI("FusionScope Node:%s", node->GetName().c_str()); + } + } + return SUCCESS; +} + +ScopeBasePass::ScopeBasePass() { + impl_ = std::unique_ptr(new (std::nothrow) ScopeBasePassImpl(this)); +} + +ScopeBasePass::~ScopeBasePass() {} +} // namespace ge diff --git a/metadef/register/scope/scope_pass_registry.cc b/metadef/register/scope/scope_pass_registry.cc new file mode 100644 index 00000000..311654c7 --- /dev/null +++ b/metadef/register/scope/scope_pass_registry.cc @@ -0,0 +1,147 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "register/scope/scope_pass_registry_impl.h" +#include +#include +#include +#include +#include "graph/debug/ge_log.h" +#include "external/register/scope/scope_fusion_pass_register.h" + +using ge::MEMALLOC_FAILED; + +namespace ge { +struct CreatePassFnPack { + bool is_enable; + ScopeFusionPassRegistry::CreateFn create_fn; +}; + +void ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl::RegisterScopeFusionPass( + const std::string &pass_name, ScopeFusionPassRegistry::CreateFn create_fn, bool is_general) { + std::lock_guard lock(mu_); + auto iter = std::find(pass_names_.begin(), pass_names_.end(), pass_name); + if (iter != pass_names_.end()) { + GELOGW("The scope fusion pass has been registered and will not overwrite the previous one, pass name = %s.", + pass_name.c_str()); + return; + } + + CreatePassFnPack create_fn_pack; + create_fn_pack.is_enable = is_general; + create_fn_pack.create_fn = create_fn; + create_fn_packs_[pass_name] = create_fn_pack; + pass_names_.push_back(pass_name); + GELOGI("Register scope fusion pass, pass name = %s, is_enable = %d.", pass_name.c_str(), is_general); +} + +ScopeFusionPassRegistry::CreateFn ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl::GetCreateFn( + const std::string &pass_name) { + std::lock_guard lock(mu_); + auto it = create_fn_packs_.find(pass_name); + if (it == create_fn_packs_.end()) { + GELOGW("Scope fusion pass is not registered. pass name = %s.", pass_name.c_str()); + return nullptr; + } + + CreatePassFnPack &create_fn_pack = it->second; + if (create_fn_pack.is_enable) { + return create_fn_pack.create_fn; + } else { + GELOGW("The scope fusion pass is disabled, pass name = %s", pass_name.c_str()); + return nullptr; + } +} + +std::vector ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl::GetAllRegisteredPasses() { + std::lock_guard lock(mu_); + std::vector all_passes; + for (size_t i = 0; i < pass_names_.size(); ++i) { + if (create_fn_packs_[pass_names_[i]].is_enable) { + all_passes.push_back(pass_names_[i]); + } + } + + return all_passes; +} + +bool ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl::SetPassEnableFlag( + const std::string pass_name, const bool flag) { + std::lock_guard lock(mu_); + auto it = create_fn_packs_.find(pass_name); + if (it == create_fn_packs_.end()) { + GELOGW("Scope fusion pass is not registered. pass name = %s.", pass_name.c_str()); + return false; + } + + CreatePassFnPack &create_fn_pack = it->second; + create_fn_pack.is_enable = flag; + GELOGI("enable flag of scope fusion pass:%s is set with %s.", pass_name.c_str(), flag ? "true" : "false"); + + return true; +} + +std::unique_ptr ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl::CreateScopeFusionPass( + const std::string &pass_name) { + auto create_fn = GetCreateFn(pass_name); + if (create_fn == nullptr) { + GELOGD("Create scope fusion pass failed, pass name = %s.", pass_name.c_str()); + return nullptr; + } + GELOGI("Create scope fusion pass, pass name = %s.", pass_name.c_str()); + return std::unique_ptr(create_fn()); +} + +ScopeFusionPassRegistry::ScopeFusionPassRegistry() { + impl_ = std::unique_ptr(new (std::nothrow) ScopeFusionPassRegistryImpl); +} + +ScopeFusionPassRegistry::~ScopeFusionPassRegistry() {} + +void ScopeFusionPassRegistry::RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, + bool is_general) { + if (impl_ == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to register %s, ScopeFusionPassRegistry is not properly initialized.", + pass_name.c_str()); + return; + } + impl_->RegisterScopeFusionPass(pass_name, create_fn, is_general); +} + +void ScopeFusionPassRegistry::RegisterScopeFusionPass(const char *pass_name, CreateFn create_fn, + bool is_general) { + if (impl_ == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to register %s, ScopeFusionPassRegistry is not properly initialized.", + pass_name); + return; + } + std::string str_pass_name; + if (pass_name != nullptr) { + str_pass_name = pass_name; + } + impl_->RegisterScopeFusionPass(str_pass_name, create_fn, is_general); +} + +ScopeFusionPassRegistrar::ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), + bool is_general) { + if (pass_name == nullptr) { + GELOGE(PARAM_INVALID, "Failed to register scope fusion pass, pass name is null."); + return; + } + + ScopeFusionPassRegistry::GetInstance().RegisterScopeFusionPass(pass_name, create_fn, is_general); +} +} // namespace ge \ No newline at end of file diff --git a/metadef/register/scope/scope_pattern.cc b/metadef/register/scope/scope_pattern.cc new file mode 100644 index 00000000..26a02370 --- /dev/null +++ b/metadef/register/scope/scope_pattern.cc @@ -0,0 +1,464 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "register/scope/scope_pattern_impl.h" +#include "register/scope/scope_graph_impl.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_util.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/attr_utils.h" + +namespace ge { +namespace { +#define CHECK_NODE_ATTR_FEATURE_DATA(DTYPE, TYPE, FUNC_NAME, INIT_VALUE) \ + case DTYPE: { \ + TYPE value = INIT_VALUE; \ + if (!ge::AttrUtils::FUNC_NAME(op_desc, attr_name_, value)) { \ + GELOGE(ge::PARAM_INVALID, "op:%s %s attr is null", op_desc->GetName().c_str(), attr_name_.c_str()); \ + return false; \ + } \ + if (attr_value_.impl_->FUNC_NAME##Value() == value) { \ + GELOGI("NodeAttrFeature, match scope:%s", scope->Name().c_str()); \ + return true; \ + } \ + break; \ + } +} // namespace +ScopeAttrValue::ScopeAttrValue() { + impl_ = std::unique_ptr(new (std::nothrow) ScopeAttrValueImpl); +} + +ScopeAttrValue::ScopeAttrValue(ScopeAttrValue const &attr_value) { + impl_ = std::unique_ptr(new (std::nothrow) ScopeAttrValueImpl); + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "ScopeAttrValue is not properly initialized."); + return; + } + impl_->SetIntValue(attr_value.impl_->GetIntValue()); + impl_->SetFloatValue(attr_value.impl_->GetFloatValue()); + impl_->SetStringValue(attr_value.impl_->GetStrValue()); + impl_->SetBoolValue(attr_value.impl_->GetBoolValue()); +} + +ScopeAttrValue &ScopeAttrValue::operator=(ScopeAttrValue const &attr_value) { + if (&attr_value == this) { + return *this; + } + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "ScopeAttrValue is not properly initialized."); + return *this; + } + impl_->SetIntValue(attr_value.impl_->GetIntValue()); + impl_->SetFloatValue(attr_value.impl_->GetFloatValue()); + impl_->SetStringValue(attr_value.impl_->GetStrValue()); + impl_->SetBoolValue(attr_value.impl_->GetBoolValue()); + return *this; +} + +ScopeAttrValue::~ScopeAttrValue() {} + +void ScopeAttrValue::SetIntValue(int64_t value) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetIntValue(), ScopeAttrValue is not properly initialized."); + return; + } + impl_->SetIntValue(value); +} + +void ScopeAttrValue::SetFloatValue(float value) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetFloatValue(), ScopeAttrValue is not properly initialized."); + return; + } + impl_->SetFloatValue(value); +} + +void ScopeAttrValue::SetStringValue(std::string value) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetStringValue(), ScopeAttrValue is not properly initialized."); + return; + } + impl_->SetStringValue(value); +} + +void ScopeAttrValue::SetStringValue(const char *value) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetStringValue(), ScopeAttrValue is not properly initialized."); + return; + } + std::string str_value; + if (value != nullptr) { + str_value = value; + } + impl_->SetStringValue(str_value); +} + +void ScopeAttrValue::SetBoolValue(bool value) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetBoolValue(), ScopeAttrValue is not properly initialized."); + return; + } + impl_->SetBoolValue(value); +} + +bool NodeOpTypeFeature::NodeOpTypeFeatureImpl::Match(const Scope *scope) { + if (scope == nullptr) { + GELOGE(PARAM_INVALID, "Input scope is nullptr."); + return false; + } + auto &impl = scope->impl_; + + if (step_ == 0) { + if (impl->GetOpTypeNum(node_type_) == num_) { + GELOGI("NodeOpTypeFeature, node type:%s, num:%d, match scope:%s", + node_type_.c_str(), num_, scope->Name().c_str()); + return true; + } + } else { + if ((impl->GetOpTypeNum(node_type_) != -1) && (impl->GetOpTypeNum(node_type_) % step_ == num_)) { + GELOGI("NodeOpTypeFeature, node type:%s, num:%d, match scope:%s", + node_type_.c_str(), num_, scope->Name().c_str()); + return true; + } + } + + return false; +} + +NodeOpTypeFeature::NodeOpTypeFeature(std::string nodeType, int num, int step) { + impl_ = std::unique_ptr(new (std::nothrow) NodeOpTypeFeatureImpl(nodeType, num, step)); +} + +NodeOpTypeFeature::NodeOpTypeFeature(const char *node_type, int num, int step) { + std::string op_type; + if (node_type != nullptr) { + op_type = node_type; + } + impl_ = std::unique_ptr(new (std::nothrow) NodeOpTypeFeatureImpl(op_type, num, step)); +} + +NodeOpTypeFeature::NodeOpTypeFeature(NodeOpTypeFeature const &feature) { + impl_ = std::unique_ptr(new (std::nothrow) NodeOpTypeFeatureImpl(feature.impl_->node_type_, + feature.impl_->num_, + feature.impl_->step_)); +} + +NodeOpTypeFeature &NodeOpTypeFeature::operator=(NodeOpTypeFeature const &feature) { + if (&feature == this) { + return *this; + } + + impl_->node_type_ = feature.impl_->node_type_; + impl_->num_ = feature.impl_->num_; + impl_->step_ = feature.impl_->step_; + return *this; +} + +NodeOpTypeFeature::~NodeOpTypeFeature() {} + +bool NodeOpTypeFeature::Match(const Scope *scope) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke Match(), NodeOpTypeFeature is not properly initialized."); + return false; + } + + return impl_->Match(scope); +} + +bool NodeAttrFeature::NodeAttrFeatureImpl::Match(const Scope *scope) { + if (scope == nullptr) { + GELOGE(ge::PARAM_INVALID, "Input scope is nullptr."); + return false; + } + auto &impl = scope->impl_; + const std::vector &nodes = impl->Nodes(); + for (auto &node_op : nodes) { + if (node_type_ != node_op->GetOpType()) { + continue; + } + auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(*node_op); + if (op_desc == nullptr) { + GELOGE(ge::PARAM_INVALID, "Op desc is nullptr."); + return false; + } + + switch (datatype_) { + CHECK_NODE_ATTR_FEATURE_DATA(ge::DT_FLOAT, float, GetFloat, 0.0) + CHECK_NODE_ATTR_FEATURE_DATA(ge::DT_INT32, int64_t, GetInt, 0) + CHECK_NODE_ATTR_FEATURE_DATA(ge::DT_STRING, std::string, GetStr, "") + CHECK_NODE_ATTR_FEATURE_DATA(ge::DT_BOOL, bool, GetBool, false) + default: + break; + } + } + return false; +} + +NodeAttrFeature::NodeAttrFeature(std::string nodeType, std::string attr_name, + ge::DataType datatype, ScopeAttrValue &attr_value) { + impl_ = std::unique_ptr(new (std::nothrow) NodeAttrFeatureImpl(nodeType, attr_name, + datatype, attr_value)); +} + +NodeAttrFeature::NodeAttrFeature(const char *node_type, const char *attr_name, + ge::DataType data_type, ScopeAttrValue &attr_value) { + std::string str_node_type; + if (node_type != nullptr) { + str_node_type = node_type; + } + std::string str_attr_name; + if (attr_name != nullptr) { + str_attr_name = attr_name; + } + impl_ = std::unique_ptr(new (std::nothrow) NodeAttrFeatureImpl(str_node_type, str_attr_name, + data_type, attr_value)); +} + +NodeAttrFeature::NodeAttrFeature(NodeAttrFeature const &feature) { + impl_ = std::unique_ptr(new (std::nothrow) NodeAttrFeatureImpl(feature.impl_->node_type_, + feature.impl_->attr_name_, + feature.impl_->datatype_, + feature.impl_->attr_value_)); +} + +NodeAttrFeature &NodeAttrFeature::operator=(NodeAttrFeature const &feature) { + if (&feature == this) { + return *this; + } + impl_->node_type_ = feature.impl_->node_type_; + impl_->attr_name_ = feature.impl_->attr_name_; + impl_->datatype_ = feature.impl_->datatype_; + impl_->attr_value_ = feature.impl_->attr_value_; + return *this; +} + +NodeAttrFeature::~NodeAttrFeature() {} + +bool NodeAttrFeature::Match(const Scope *scope) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke Match(), NodeAttrFeature is not properly initialized."); + return false; + } + + return impl_->Match(scope); +} + +bool ScopeFeature::ScopeFeatureImpl::SubScopesMatch(const std::vector &scopes) { + int32_t count = 0; + bool sub_scope_name_matched = false; + for (auto &scp : scopes) { + if (sub_type_.length() > 0 && sub_type_ == scp->SubType()) { + ++count; + } + if (sub_scope_name_matched) { + continue; + } + auto &sub_impl = scp->impl_; + sub_scope_name_matched = (sub_scope_mask_.length() > 0) && + (sub_scope_mask_.length() < scp->Name().length()) && + (sub_impl->LastName().find(sub_scope_mask_) != std::string::npos); + } + + if ((sub_type_.length() > 0) && (step_ == 0) && (count != num_)) { + return false; + } + if ((sub_scope_mask_.length() > 0) && !sub_scope_name_matched) { + return false; + } + + return true; +} + +bool ScopeFeature::ScopeFeatureImpl::Match(const Scope *scope) { + auto &impl = scope->impl_; + std::string scope_name = scope->Name(); + if (suffix_.length() > scope_name.length()) { + return false; + } + if (suffix_.length() > 0) { + const std::string &last_name = impl->LastName(); + if (suffix_ != last_name) { + return false; + } + } + + const std::vector &scopes = impl->GetAllSubScopes(); + if (SubScopesMatch(scopes)) { + GELOGI("ScopeFeature, match scope:%s", scope->Name().c_str()); + return true; + } + + return false; +} + +ScopeFeature::ScopeFeature(std::string sub_type, int32_t num, std::string suffix, + std::string sub_scope_mask, int step) { + impl_ = std::unique_ptr(new (std::nothrow) ScopeFeatureImpl(sub_type, num, suffix, + sub_scope_mask, step)); +} + +ScopeFeature::ScopeFeature(const char *sub_type, int32_t num, const char *suffix, + const char *sub_scope_mask, int step) { + std::string str_sub_type; + if (sub_type != nullptr) { + str_sub_type = sub_type; + } + std::string str_suffix; + if (suffix != nullptr) { + str_suffix = suffix; + } + std::string str_sub_scope_mask; + if (sub_scope_mask != nullptr) { + str_sub_scope_mask = sub_scope_mask; + } + impl_ = std::unique_ptr(new (std::nothrow) ScopeFeatureImpl(str_sub_type, num, str_suffix, + str_sub_scope_mask, step)); +} + +ScopeFeature::ScopeFeature(ScopeFeature const &feature) { + impl_ = std::unique_ptr(new (std::nothrow) ScopeFeatureImpl(feature.impl_->sub_type_, + feature.impl_->num_, + feature.impl_->suffix_, + feature.impl_->sub_scope_mask_, + feature.impl_->step_)); +} + +ScopeFeature &ScopeFeature::operator=(ScopeFeature const &feature) { + if (&feature == this) { + return *this; + } + impl_->sub_type_ = feature.impl_->sub_type_; + impl_->num_ = feature.impl_->num_; + impl_->suffix_ = feature.impl_->suffix_; + impl_->sub_scope_mask_ = feature.impl_->sub_scope_mask_; + impl_->step_ = feature.impl_->step_; + return *this; +} + +ScopeFeature::~ScopeFeature() {} + +bool ScopeFeature::Match(const Scope *scope) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke Match(), ScopeFeature is not properly initialized."); + return false; + } + + return impl_->Match(scope); +} + +bool ScopePattern::ScopePatternImpl::Match(const Scope *scope) const { + if (scope == nullptr) { + GELOGE(PARAM_INVALID, "Input scope is nullptr."); + return false; + } + for (auto feature : node_optype_features_) { + if (!feature.Match(scope)) { + return false; + } + } + + for (auto feature : node_attr_features_) { + if (!feature.Match(scope)) { + return false; + } + } + + for (auto feature : scopes_features_) { + if (!feature.Match(scope)) { + return false; + } + } + + // If there is a _Retval node in the scope, the scope will not be fused. + NodeOpTypeFeature comm_node_feature = NodeOpTypeFeature("_Retval", -1, 0); + if (!comm_node_feature.Match(scope)) { + return false; + } + + return true; +} + +void ScopePattern::ScopePatternImpl::SetSubType(const std::string &sub_type) { + sub_type_ = sub_type; +} + +void ScopePattern::ScopePatternImpl::AddNodeOpTypeFeature(NodeOpTypeFeature &feature) { + node_optype_features_.push_back(feature); +} + +void ScopePattern::ScopePatternImpl::AddNodeAttrFeature(NodeAttrFeature &feature) { + node_attr_features_.push_back(feature); +} + +void ScopePattern::ScopePatternImpl::AddScopeFeature(ScopeFeature &feature) { + scopes_features_.push_back(feature); +} + +ScopePattern::ScopePattern() { + impl_ = std::unique_ptr(new (std::nothrow) ScopePatternImpl); +} + +ScopePattern::~ScopePattern() {} + +ScopePattern &ScopePattern::SetSubType(const std::string &sub_type) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetSubType(), ScopePattern is not properly initialized."); + return *this; + } + impl_->SetSubType(sub_type); + return *this; +} + +ScopePattern &ScopePattern::SetSubType(const char *sub_type) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetSubType(), ScopePattern is not properly initialized."); + return *this; + } + std::string str_sub_type; + if (sub_type != nullptr) { + str_sub_type = sub_type; + } + impl_->SetSubType(str_sub_type); + return *this; +} + +ScopePattern &ScopePattern::AddNodeOpTypeFeature(NodeOpTypeFeature feature) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke AddNodeOpTypeFeature(), ScopePattern is not properly initialized."); + return *this; + } + impl_->AddNodeOpTypeFeature(feature); + return *this; +} + +ScopePattern &ScopePattern::AddNodeAttrFeature(NodeAttrFeature feature) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke AddNodeAttrFeature(), ScopePattern is not properly initialized."); + return *this; + } + impl_->AddNodeAttrFeature(feature); + return *this; +} + +ScopePattern &ScopePattern::AddScopeFeature(ScopeFeature feature) { + if (impl_ == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke AddScopeFeature(), ScopePattern is not properly initialized."); + return *this; + } + impl_->AddScopeFeature(feature); + return *this; +} +} // namespace ge diff --git a/metadef/register/scope/scope_util.cc b/metadef/register/scope/scope_util.cc new file mode 100644 index 00000000..5c7da9fb --- /dev/null +++ b/metadef/register/scope/scope_util.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "external/register/scope/scope_fusion_pass_register.h" + +#include + +#include "framework/common/debug/ge_log.h" +#include "framework/common/string_util.h" + +namespace ge { +std::string ScopeUtil::StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value) { + return ge::StringUtils::ReplaceAll(str, old_value, new_value); +} + +AscendString ScopeUtil::StringReplaceAll(const char *str, const char *old_value, const char *new_value) { + std::string tmp_str; + if (str != nullptr) { + tmp_str = str; + } + std::string tmp_old_value; + if (old_value != nullptr) { + tmp_old_value = old_value; + } + std::string tmp_new_value; + if (new_value != nullptr) { + tmp_new_value = new_value; + } + std::string ret = ge::StringUtils::ReplaceAll(tmp_str, tmp_old_value, tmp_new_value); + return AscendString(ret.c_str()); +} + +void ScopeUtil::FreeScopePatterns(ScopeFusionPatterns &patterns) { + for (auto &batch_pattern : patterns) { + FreeOneBatchPattern(batch_pattern); + } + patterns.clear(); +} + +void ScopeUtil::FreeOneBatchPattern(std::vector &one_batch_pattern) { + for (auto &one_pattern : one_batch_pattern) { + if (one_pattern != nullptr) { + delete one_pattern; + one_pattern = nullptr; + } + } + one_batch_pattern.clear(); +} +} // namespace ge diff --git a/metadef/register/tensor_assign.cpp b/metadef/register/tensor_assign.cpp new file mode 100644 index 00000000..368f6038 --- /dev/null +++ b/metadef/register/tensor_assign.cpp @@ -0,0 +1,423 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include +#include +#include "securec.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_log.h" +#include "graph/debug/ge_util.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/utils/attr_utils.h" +#include "register/register_error_codes.h" +#include "register/tensor_assign.h" + +using GeTensorDesc = ge::GeTensorDesc; +using GeShape = ge::GeShape; + +namespace domi { +namespace { +const uint32_t kExtraBytesForString = sizeof(int64_t) + 1; +const char *const kOriginElementNumAttrName = "origin_element_num"; +const std::map data_type_map = { + {domi::tensorflow::DataType::DT_FLOAT, ge::DataType::DT_FLOAT}, + {domi::tensorflow::DataType::DT_HALF, ge::DataType::DT_FLOAT16}, + {domi::tensorflow::DataType::DT_INT8, ge::DataType::DT_INT8}, + {domi::tensorflow::DataType::DT_INT16, ge::DataType::DT_INT16}, + {domi::tensorflow::DataType::DT_UINT16, ge::DataType::DT_UINT16}, + {domi::tensorflow::DataType::DT_UINT8, ge::DataType::DT_UINT8}, + {domi::tensorflow::DataType::DT_INT32, ge::DataType::DT_INT32}, + {domi::tensorflow::DataType::DT_INT64, ge::DataType::DT_INT64}, + {domi::tensorflow::DataType::DT_UINT32, ge::DataType::DT_UINT32}, + {domi::tensorflow::DataType::DT_UINT64, ge::DataType::DT_UINT64}, + {domi::tensorflow::DataType::DT_BOOL, ge::DataType::DT_BOOL}, + {domi::tensorflow::DataType::DT_DOUBLE, ge::DataType::DT_DOUBLE}, + {domi::tensorflow::DataType::DT_COMPLEX64, ge::DataType::DT_COMPLEX64}, + {domi::tensorflow::DataType::DT_QINT8, ge::DataType::DT_INT8}, + {domi::tensorflow::DataType::DT_QUINT8, ge::DataType::DT_UINT8}, + {domi::tensorflow::DataType::DT_QINT32, ge::DataType::DT_INT32}, + {domi::tensorflow::DataType::DT_QINT16, ge::DataType::DT_INT16}, + {domi::tensorflow::DataType::DT_QUINT16, ge::DataType::DT_UINT16}, + {domi::tensorflow::DataType::DT_COMPLEX128, ge::DataType::DT_COMPLEX128}, + {domi::tensorflow::DataType::DT_RESOURCE, ge::DataType::DT_RESOURCE}, + {domi::tensorflow::DataType::DT_BFLOAT16, ge::DataType::DT_FLOAT16}, + {domi::tensorflow::DataType::DT_STRING, ge::DataType::DT_STRING}, + {domi::tensorflow::DataType::DT_FLOAT_REF, ge::DataType::DT_FLOAT}, + {domi::tensorflow::DataType::DT_DOUBLE_REF, ge::DataType::DT_DOUBLE}, + {domi::tensorflow::DataType::DT_INT32_REF, ge::DataType::DT_INT32}, + {domi::tensorflow::DataType::DT_INT8_REF, ge::DataType::DT_INT8}, + {domi::tensorflow::DataType::DT_UINT8_REF, ge::DataType::DT_UINT8}, + {domi::tensorflow::DataType::DT_INT16_REF, ge::DataType::DT_INT16}, + {domi::tensorflow::DataType::DT_UINT16_REF, ge::DataType::DT_UINT16}, + {domi::tensorflow::DataType::DT_COMPLEX64_REF, ge::DataType::DT_COMPLEX64}, + {domi::tensorflow::DataType::DT_QINT8_REF, ge::DataType::DT_INT8}, + {domi::tensorflow::DataType::DT_QUINT8_REF, ge::DataType::DT_UINT8}, + {domi::tensorflow::DataType::DT_QINT32_REF, ge::DataType::DT_INT32}, + {domi::tensorflow::DataType::DT_QINT16_REF, ge::DataType::DT_INT16}, + {domi::tensorflow::DataType::DT_QUINT16_REF, ge::DataType::DT_UINT16}, + {domi::tensorflow::DataType::DT_COMPLEX128_REF, ge::DataType::DT_COMPLEX128}, + {domi::tensorflow::DataType::DT_RESOURCE_REF, ge::DataType::DT_RESOURCE}, + {domi::tensorflow::DataType::DT_BFLOAT16_REF, ge::DataType::DT_FLOAT16}, + {domi::tensorflow::DataType::DT_UINT32_REF, ge::DataType::DT_UINT32}, + {domi::tensorflow::DataType::DT_UINT64_REF, ge::DataType::DT_UINT64}, + {domi::tensorflow::DataType::DT_INT64_REF, ge::DataType::DT_INT64}, + {domi::tensorflow::DataType::DT_BOOL_REF, ge::DataType::DT_BOOL}, + {domi::tensorflow::DataType::DT_HALF_REF, ge::DataType::DT_FLOAT16}, + {domi::tensorflow::DataType::DT_STRING_REF, ge::DataType::DT_STRING}, +}; +} // namespace + +ge::DataType TensorAssign::ConvertTensorflowDataType(uint32_t tf_data_type) { + auto search = data_type_map.find(tf_data_type); + if (search != data_type_map.end()) { + return search->second; + } else { + return ge::DataType::DT_UNDEFINED; + } +} + +bool TensorAssign::CheckBoolVal(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_BOOL) || (data_type == tensorflow::DT_BOOL_REF)); +} + +bool TensorAssign::CheckHalfVal(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_HALF) || (data_type == tensorflow::DT_BFLOAT16) || + (data_type == tensorflow::DT_HALF_REF) || (data_type == tensorflow::DT_BFLOAT16_REF)); +} + +bool TensorAssign::CheckFloatVal(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_FLOAT) || (data_type == tensorflow::DT_FLOAT_REF)); +} + +bool TensorAssign::CheckDoubleVal(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_DOUBLE) || (data_type == tensorflow::DT_DOUBLE_REF)); +} + +bool TensorAssign::CheckComplex64Val(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_COMPLEX64) || (data_type == tensorflow::DT_COMPLEX64_REF)); +} + +bool TensorAssign::CheckComplex128Val(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_COMPLEX128) || (data_type == tensorflow::DT_COMPLEX128_REF)); +} + +bool TensorAssign::CheckStringVal(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_STRING) || (data_type == tensorflow::DT_STRING_REF)); +} + +bool TensorAssign::CheckByte(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_UINT8) || (data_type == tensorflow::DT_INT8) || + (data_type == tensorflow::DT_QINT8) || (data_type == tensorflow::DT_QUINT8) || + (data_type == tensorflow::DT_UINT8_REF) || (data_type == tensorflow::DT_INT8_REF) || + (data_type == tensorflow::DT_QINT8_REF) || (data_type == tensorflow::DT_QUINT8_REF)); +} + +bool TensorAssign::CheckDoubleByte(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_INT16) || (data_type == tensorflow::DT_UINT16) || + (data_type == tensorflow::DT_QINT16) || (data_type == tensorflow::DT_QUINT16) || + (data_type == tensorflow::DT_INT16_REF) || (data_type == tensorflow::DT_UINT16_REF) || + (data_type == tensorflow::DT_QINT16_REF) || (data_type == tensorflow::DT_QUINT16_REF)); +} + +bool TensorAssign::CheckSignedFourByte(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_INT32) || (data_type == tensorflow::DT_QINT32) || + (data_type == tensorflow::DT_INT32_REF) || (data_type == tensorflow::DT_QINT32_REF)); +} + +bool TensorAssign::CheckUnsignedFourByte(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_UINT32) || (data_type == tensorflow::DT_UINT32_REF)); +} + +bool TensorAssign::CheckSignedEightByte(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_INT64) || (data_type == tensorflow::DT_INT64_REF)); +} + +bool TensorAssign::CheckUnsignedEightByte(tensorflow::DataType data_type) { + return ((data_type == tensorflow::DT_UINT64) || (data_type == tensorflow::DT_UINT64_REF)); +} + +Status TensorAssign::GetDoubleByteVal(int32_t val_size, const google::protobuf::RepeatedField &val_vector, + int count, GeTensorPtr &weight) { + GE_CHECK_NOTNULL(weight); + bool zerosLike = (count != val_size && val_size == 1); + uint16_t *addr = new (std::nothrow) uint16_t[count](); + GE_CHECK_NOTNULL(addr); + int minCount = (count > val_size) ? val_size : count; + if (!zerosLike) { + for (int32_t i = 0; i < minCount; i++) { + *(addr + i) = static_cast(val_vector.Get(i)); + } + for (int32_t i = minCount; i < count; i++) { + *(addr + i) = static_cast(val_vector.Get(minCount - 1)); + } + } else { + for (int32_t i = 0; i < count; i++) { + *(addr + i) = static_cast(val_vector.Get(0)); + } + } + weight->SetData(reinterpret_cast(addr), count * sizeof(uint16_t)); + GE_DELETE_NEW_ARRAY(addr); + return SUCCESS; +} + +Status TensorAssign::GetByteVal(int32_t val_size, const google::protobuf::RepeatedField &val_vector, int count, + GeTensorPtr &weight) { + GE_CHECK_NOTNULL(weight); + bool zerosLike = (count != val_size && val_size == 1); + uint8_t *addr = new (std::nothrow) uint8_t[count](); + GE_CHECK_NOTNULL(addr); + int minCount = (count > val_size) ? val_size : count; + if (!zerosLike) { + for (int32_t i = 0; i < minCount; i++) { + *(addr + i) = static_cast(val_vector.Get(i)); + } + for (int32_t i = minCount; i < count; i++) { + *(addr + i) = static_cast(val_vector.Get(minCount - 1)); + } + } else { + for (int32_t i = 0; i < count; i++) { + *(addr + i) = static_cast(val_vector.Get(0)); + } + } + weight->SetData(addr, count * sizeof(uint8_t)); + GE_DELETE_NEW_ARRAY(addr); + return SUCCESS; +} + +Status TensorAssign::GetStringVal(int32_t val_size, const google::protobuf::RepeatedPtrField &val_vector, + int count, GeTensorPtr &weight) { + GE_CHECK_NOTNULL(weight); + bool flag = (count != val_size && val_size == 1); + int min_count = (count > val_size) ? val_size : count; + size_t total_size = 0; + if (!flag) { + for (int32_t i = 0; i < min_count; i++) { + // extra 8 bytes store pointer of string + // extra 1 byte store '\0' + total_size += (val_vector[i].size() + kExtraBytesForString); + } + total_size += (count - min_count) * kExtraBytesForString; + std::unique_ptr addr(new (std::nothrow) char[total_size]()); + GE_CHECK_NOTNULL(addr); + uint64_t *p = reinterpret_cast(addr.get()); + // front some bytes store pointer of each string + char *raw_data = addr.get() + count * sizeof(uint64_t); + for (int32_t i = 0; i < count; ++i) { + p[i] = reinterpret_cast(raw_data); + if (i < val_size) { + const string &str = val_vector.Get(i); + CHECK_FALSE_EXEC(memcpy_s(raw_data, str.size() + 1, str.c_str(), str.size() + 1) == EOK, + GELOGW("call memcpy_s fail!")); + raw_data += (str.size() + 1); + } else { + raw_data += 1; + } + } + weight->SetData(reinterpret_cast(addr.get()), total_size); + } else { + const string &str = val_vector.Get(0); + // extra 8 bytes store pointer of string + // extra 1 byte store '\0' + total_size = (str.size() + kExtraBytesForString) * count; + std::unique_ptr addr(new (std::nothrow) char[total_size]()); + GE_CHECK_NOTNULL(addr); + uint64_t *p = reinterpret_cast(addr.get()); + // front some bytes store pointer of each string + char *raw_data = addr.get() + count * sizeof(uint64_t); + for (int32_t i = 0; i < count; ++i) { + p[i] = reinterpret_cast(raw_data); + CHECK_FALSE_EXEC(memcpy_s(raw_data, str.size() + 1, str.c_str(), str.size() + 1) == EOK, + GELOGW("call memcpy_s fail!")); + raw_data += (str.size() + 1); + } + weight->SetData(reinterpret_cast(addr.get()), total_size); + } + return SUCCESS; +} + +void TensorAssign::SetGeTensorWeightData(const TensorProto &tensor, int32_t val_size, int count, GeTensorPtr &weight) { + tensorflow::DataType data_type = tensor.dtype(); + if (CheckFloatVal(data_type)) { + (void)GetVal(val_size, tensor.float_val(), count, weight); + } else if (CheckComplex64Val(data_type)) { + (void)GetVal(val_size, tensor.scomplex_val(), count, weight); + } else if (CheckSignedFourByte(data_type)) { + (void)GetVal(val_size, tensor.int_val(), count, weight); + } else if (CheckUnsignedFourByte(data_type)) { + (void)GetVal(val_size, tensor.uint32_val(), count, weight); + } else if (CheckSignedEightByte(data_type)) { + (void)GetVal(val_size, tensor.int64_val(), count, weight); + } else if (CheckUnsignedEightByte(data_type)) { + (void)GetVal(val_size, tensor.uint64_val(), count, weight); + } else if (CheckBoolVal(data_type)) { + (void)GetVal(val_size, tensor.bool_val(), count, weight); + } else if (CheckStringVal(data_type)) { + (void)GetStringVal(val_size, tensor.string_val(), count, weight); + } else if (CheckHalfVal(data_type)) { + (void)GetDoubleByteVal(val_size, tensor.half_val(), count, weight); + } else if (CheckDoubleByte(data_type)) { + (void)GetDoubleByteVal(val_size, tensor.int_val(), count, weight); + } else if (CheckByte(data_type)) { + (void)GetByteVal(val_size, tensor.int_val(), count, weight); + } else if (CheckDoubleVal(data_type)) { + (void)GetVal(val_size, tensor.double_val(), count, weight); + } else if (CheckComplex128Val(data_type)) { + (void)GetVal(val_size, tensor.dcomplex_val(), count, weight); + } else { + GELOGI("data_type:%s.", DataType_Name(data_type).c_str()); + } +} + +void TensorAssign::SetWeightData(tensorflow::DataType data_type, int count, const std::string &tensor_content, + GeTensorPtr &weight) { + if (weight == nullptr) { + GE_LOGE("weight is nullptr."); + return; + } + GELOGD("Set data from tensor_content, count = %d ,data_type = %s.", count, DataType_Name(data_type).c_str()); + if (CheckByte(data_type)) { + weight->SetData(reinterpret_cast(tensor_content.data()), count * sizeof(uint8_t)); + } else if (CheckBoolVal(data_type)) { + weight->SetData(reinterpret_cast(tensor_content.data()), count * sizeof(bool)); + } else if (CheckHalfVal(data_type) || CheckDoubleByte(data_type)) { + weight->SetData(reinterpret_cast(tensor_content.data()), count * sizeof(uint16_t)); + } else if (CheckSignedFourByte(data_type) || CheckUnsignedFourByte(data_type)) { + weight->SetData(reinterpret_cast(tensor_content.data()), count * sizeof(uint32_t)); + } else if (CheckSignedEightByte(data_type) || CheckUnsignedEightByte(data_type)) { + weight->SetData(reinterpret_cast(tensor_content.data()), count * sizeof(uint64_t)); + } else if (CheckDoubleVal(data_type) || CheckComplex128Val(data_type)) { + weight->SetData(reinterpret_cast(tensor_content.data()), count * sizeof(double)); + } else { + weight->SetData(reinterpret_cast(tensor_content.data()), count * sizeof(float)); + } +} + +Status TensorAssign::SetGeTensor(const TensorProto &tensor, GeTensorPtr &weight) { + GE_CHECK_NOTNULL(weight); + std::map datatype_val_size_map = { + {tensorflow::DT_FLOAT, tensor.float_val().size()}, + {tensorflow::DT_INT32, tensor.int_val().size()}, + {tensorflow::DT_INT64, tensor.int64_val().size()}, + {tensorflow::DT_BOOL, tensor.bool_val().size()}, + {tensorflow::DT_HALF, tensor.half_val().size()}, + {tensorflow::DT_INT8, tensor.int_val().size()}, + {tensorflow::DT_UINT8, tensor.int_val().size()}, + {tensorflow::DT_INT16, tensor.int_val().size()}, + {tensorflow::DT_UINT16, tensor.int_val().size()}, + {tensorflow::DT_DOUBLE, tensor.double_val().size()}, + {tensorflow::DT_STRING, tensor.string_val().size()}, + {tensorflow::DT_QINT8, tensor.int_val().size()}, + {tensorflow::DT_QINT16, tensor.int_val().size()}, + {tensorflow::DT_QINT32, tensor.int_val().size()}, + {tensorflow::DT_QUINT8, tensor.int_val().size()}, + {tensorflow::DT_QUINT16, tensor.int_val().size()}, + {tensorflow::DT_COMPLEX64, tensor.scomplex_val().size()}, + {tensorflow::DT_COMPLEX128, tensor.dcomplex_val().size()}, + {tensorflow::DT_BFLOAT16, tensor.half_val().size()}, + {tensorflow::DT_UINT32, tensor.uint32_val().size()}, + {tensorflow::DT_UINT64, tensor.uint64_val().size()}, + {tensorflow::DT_RESOURCE, tensor.resource_handle_val().size()}, + {tensorflow::DT_VARIANT, tensor.variant_val().size()}, + {tensorflow::DT_FLOAT_REF, tensor.float_val().size()}, + {tensorflow::DT_INT32_REF, tensor.int_val().size()}, + {tensorflow::DT_INT64_REF, tensor.int64_val().size()}, + {tensorflow::DT_BOOL_REF, tensor.bool_val().size()}, + {tensorflow::DT_HALF_REF, tensor.half_val().size()}, + {tensorflow::DT_INT8_REF, tensor.int_val().size()}, + {tensorflow::DT_UINT8_REF, tensor.int_val().size()}, + {tensorflow::DT_INT16_REF, tensor.int_val().size()}, + {tensorflow::DT_UINT16_REF, tensor.int_val().size()}, + {tensorflow::DT_DOUBLE_REF, tensor.double_val().size()}, + {tensorflow::DT_STRING_REF, tensor.string_val().size()}, + {tensorflow::DT_QINT8_REF, tensor.int_val().size()}, + {tensorflow::DT_QINT16_REF, tensor.int_val().size()}, + {tensorflow::DT_QINT32_REF, tensor.int_val().size()}, + {tensorflow::DT_QUINT8_REF, tensor.int_val().size()}, + {tensorflow::DT_QUINT16_REF, tensor.int_val().size()}, + {tensorflow::DT_COMPLEX64_REF, tensor.scomplex_val().size()}, + {tensorflow::DT_COMPLEX128_REF, tensor.dcomplex_val().size()}, + {tensorflow::DT_BFLOAT16_REF, tensor.half_val().size()}, + {tensorflow::DT_UINT32_REF, tensor.uint32_val().size()}, + {tensorflow::DT_UINT64_REF, tensor.uint64_val().size()}, + {tensorflow::DT_RESOURCE_REF, tensor.resource_handle_val().size()}, + {tensorflow::DT_VARIANT_REF, tensor.variant_val().size()}, + }; + tensorflow::DataType data_type = tensor.dtype(); + int32_t datatype_val_size = 0; + + auto iter = datatype_val_size_map.find(data_type); + if (iter != datatype_val_size_map.end()) { + datatype_val_size = iter->second; + } else { + GE_CHECK_GE(data_type, 0); + GE_LOGE("datatype:%s not support.", DataType_Name(data_type).c_str()); + return FAILED; + } + + std::vector shape_vec; + // There is tensor shape, get the dimension + int count = 1; + GE_IF_BOOL_EXEC( + tensor.has_tensor_shape(), const tensorflow::TensorShapeProto &tensor_shape = tensor.tensor_shape(); + for (int i = 0; i < tensor_shape.dim_size(); i++) { + const tensorflow::TensorShapeProto_Dim &shape_dim = tensor_shape.dim(i); + shape_vec.push_back(shape_dim.size()); + int64_t dim = shape_vec[i]; + // tensorflow support weights shape [0],have no weights + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(dim < 0, return FAILED, "Dim size invalid"); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((count != 0 && dim >= INT64_MAX / count), return FAILED, + "Dim size exceeds INT64_MAX"); + count *= dim; + }); + GeShape shape(shape_vec); + GeTensorDesc tmp_desc = weight->GetTensorDesc(); + tmp_desc.SetShape(shape); + + // Fixed input ND + tmp_desc.SetFormat(ge::Format::FORMAT_ND); + tmp_desc.SetOriginFormat(ge::Format::FORMAT_ND); + + weight->SetTensorDesc(tmp_desc); + + if (datatype_val_size > 0) { + SetGeTensorWeightData(tensor, datatype_val_size, count, weight); + int64_t origin_element_num = static_cast(datatype_val_size); + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(weight->MutableTensorDesc(), kOriginElementNumAttrName, origin_element_num), + return FAILED, "Set origin element num failed."); + } else if (!tensor.tensor_content().empty()) { + const auto &tensor_content = tensor.tensor_content(); + SetWeightData(data_type, count, tensor_content, weight); + } else { + if (count == 0) { + GELOGI("Empty tensor, has no data."); + return SUCCESS; + } + GE_LOGE("value Attr tensor should have val() or tensor_content"); + return FAILED; + } + + return SUCCESS; +} + +Status TensorAssign::SetGeTensorDataType(int64_t data_type, GeTensorPtr &weight) { + GE_CHECK_NOTNULL(weight); + GeTensorDesc tmp_desc = weight->GetTensorDesc(); + tmp_desc.SetDataType(ge::DataType(data_type)); + weight->SetTensorDesc(tmp_desc); + return SUCCESS; +} +} // namespace domi diff --git a/metadef/third_party/fwkacllib/inc/cce/aicpu_engine.h b/metadef/third_party/fwkacllib/inc/cce/aicpu_engine.h new file mode 100644 index 00000000..b83731a8 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/aicpu_engine.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef AICPU_ENGINE_H__ +#define AICPU_ENGINE_H__ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum { + AE_STATUS_SUCCESS = 0, + AE_STATUS_BAD_PARAM = 1, + AE_STATUS_OPEN_SO_FAILED = 2, + AE_STATUS_GET_KERNEL_NAME_FAILED = 3, + AE_STATUS_INNER_ERROR = 4, + AE_STATUS_KERNEL_API_INNER_ERROR = 5, + AE_STATUS_END_OF_SEQUENCE = 6, + AE_STATUS_DUMP_FAILED = 7, + AE_STATUS_RESERVED +} aeStatus_t; + +/** + * @ingroup aicpu engine + * @brief aeCallInterface: + * a interface to call a function in a op kernfel lib + * @param [in] addr void *, should be STR_KERNEL * format + * @return aeStatus_t + */ +aeStatus_t aeCallInterface(void *addr); + +/** + * @ingroup aicpu engine + * @brief aeBatchLoadKernelSo: + * a interface to load kernel so + * @param [in] loadSoNum load so number + * @param [in] soPaths load so paths + * @param [in] soNames load so names + * @return aeStatus_t + */ +aeStatus_t aeBatchLoadKernelSo(const uint32_t loadSoNum, const char *soPaths[], const char *soNames[]); + +#ifdef __cplusplus +} +#endif + +#endif // AICPU_ENGINE_H__ diff --git a/metadef/third_party/fwkacllib/inc/cce/aicpu_engine_struct.h b/metadef/third_party/fwkacllib/inc/cce/aicpu_engine_struct.h new file mode 100644 index 00000000..8c0c1847 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/aicpu_engine_struct.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef AICPU_ENGINE_STRUCT_H__ +#define AICPU_ENGINE_STRUCT_H__ + +#include "fwk_adpt_struct.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* + The different framwork we adapted for. +*/ +typedef enum { + FMK_KERNEL_TYPE_TF = 0, + FMK_KERNEL_TYPE_CF = 10, + FMK_KERNEL_TYPE_PT = 20, + FMK_KERNEL_TYPE_RESERVED +} FwkkernelType_t; + +#pragma pack(push, 1) +typedef struct { + uint32_t fwkKernelType; // FwkkernelType_t + union { + ::aicpu::FWKAdapter::FWKOperateParam fwk_kernel; + } fwkKernelBase; +} STR_FWK_OP_KERNEL; +#pragma pack(pop) + +#pragma pack(push, 1) +struct SessionInfo { + uint64_t sessionId; + uint64_t kernelId; + bool sessFlag; +}; +#pragma pack(pop) + +#ifdef __cplusplus +} +#endif +#endif // AICPU_ENGINE_STRUCT_H__ diff --git a/metadef/third_party/fwkacllib/inc/cce/blas_struct.h b/metadef/third_party/fwkacllib/inc/cce/blas_struct.h new file mode 100644 index 00000000..e0bcee4c --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/blas_struct.h @@ -0,0 +1,31 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CC_BLAS_STRUCT_API__ +#define CC_BLAS_STRUCT_API__ + +#include + +typedef enum { CCBLAS_FILL_MODE_LOWER = 0, CCBLAS_FILL_MODE_UPPER = 1 } ccblasFillMode_t; + +typedef enum { + CCBLAS_OP_N = 0, + CCBLAS_OP_T = 1, +} ccblasOperation_t; + +typedef enum { CCBLAS_DIAG_NON_UNIT = 0, CCBLAS_DIAG_UNIT = 1 } ccblasDiagType_t; + +#endif // CC_BLAS_STRUCT_API__ diff --git a/metadef/third_party/fwkacllib/inc/cce/cce.h b/metadef/third_party/fwkacllib/inc/cce/cce.h new file mode 100644 index 00000000..0cd9613a --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/cce.h @@ -0,0 +1,101 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CCE_H__ +#define CCE_H__ + +#include +#include "cce_def.hpp" + +namespace cce { + +/** + * @ingroup cce + * @brief create cc handler + * @param [in|out] handle point of cc handler + * @return ccStatus_t + */ +ccStatus_t ccCreate(ccHandle_t *handle); + +/** + * @ingroup cce + * @brief destroy cc handler + * @param [in] *handle cc handler + * @return ccStatus_t + */ +ccStatus_t ccDestroy(ccHandle_t *handle); + +/** + * @ingroup cce + * @brief bind stream with specified cc handler + * @param [in] handle cc handler + * @param [in] streamId stream + * @return ccStatus_t + */ +ccStatus_t ccSetStream(ccHandle_t handle, rtStream_t streamId); + +/** + * @ingroup cce + * @brief get the stream from cc handler + * @param [in] handle cc handler + * @param [in|out] streamId point of stream + * @return ccStatus_t + */ +ccStatus_t ccGetStream(ccHandle_t handle, rtStream_t *streamId); + +/** + * @ingroup cce + * @brief get the stream from cc handler + * @param [in] dataTypeTransMode mode of data type transform + * @param [in] inputData input data point + * @param [in] inputDataSize input data size + * @param [in|out] outputData output data point + * @param [in] outputDataSize output data size + * @return ccStatus_t + */ +ccStatus_t ccTransDataType(ccDataTypeTransMode_t dataTypeTransMode, const void *inputData, uint32_t inputDataSize, + void *outputData, const uint32_t outputDataSize); +/** + * @ingroup cce + * @brief cce sys init func + */ +void cceSysInit(); + +/** + * @ingroup cce + * @brief cce Log Start up func + */ +void cceLogStartup(); + +/** + * @ingroup cce + * @brief cce Log Shut down func + */ +void cceLogShutdown(); + +/** + * @ingroup cce + * @brief set the profiling on or off + * @param [in] const unsigned char* target: The engine gets it from ENV. Don't need care about it. + * @param const char* job_ctx: identifies profiling job + * @param [in] uint32_t flag: value: 0, on ; 1, off. + * @return ccStatus_t value: 0, success; 1, fail. + */ +ccStatus_t CceProfilingConfig(const char *target, const char *job_ctx, uint32_t flag); + +}; // namespace cce + +#endif // CCE_H__ diff --git a/metadef/third_party/fwkacllib/inc/cce/cce_def.hpp b/metadef/third_party/fwkacllib/inc/cce/cce_def.hpp new file mode 100644 index 00000000..7b1a1b8a --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/cce_def.hpp @@ -0,0 +1,152 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CCE_DEF_H__ +#define CCE_DEF_H__ + +#include "runtime/rt.h" + +namespace cce { + +/** + * @ingroup cce + * @brief memory configure for fusion + */ +typedef struct TagCceFusionMemCfg { + uint64_t memAddr; /**< memAddr */ + uint32_t memSize; /**< memSize */ + uint32_t addrChangeFlag; /**< op data addr change flag. value:0,valid;1,not valid */ + uint32_t poolFlag; /**< mempool flag : value:0,is valid; value: 1, not valid */ + TagCceFusionMemCfg() { + memAddr = 0; + memSize = 0; + addrChangeFlag = 0; + poolFlag = 0; + } +} CceFusionMemCfg_t; +/** + * @ingroup cce + * @brief return value + */ +typedef enum tagCcStatus { + CC_STATUS_SUCCESS = 0, /**< succ */ + CC_STATUS_NOT_INITIALIZED = 1, /**< not init */ + CC_STATUS_ALLOC_FAILED = 2, /**< alloc mem failed */ + CC_STATUS_BAD_PARAM = 3, /**< para check failed */ + CC_STATUS_INTERNAL_ERROR = 4, /**< internal error */ + CC_STATUS_KERNEL_ERROR = 5, /**< kernel error */ + CC_STATUS_RUNTIME_ERROR = 6, /**< runtime error */ + CC_STATUS_NOT_SUPPORTED = 7, /**< unsupport error */ + CC_STATUS_INVALID_VALUE = 7, /**< invalid value error for blas*/ + CC_STATUS_RESERVED /**< just for check */ +} ccStatus_t; + +/** + * @ingroup cce + * @brief original data type + */ +typedef enum tagCcDataType { + CC_DATA_FLOAT = 0, /**< float type */ + CC_DATA_HALF, /**< fp16 type */ + CC_DATA_INT8, /**< int8 type */ + CC_DATA_INT32, /**< int32 type */ + CC_DATA_UINT8, /**< uint8 type */ + CC_DATA_HALF_UINT16_PROPOSAL, /** +#include + +#define ERROR_CODE() __catch_error_code +#define ERROR_LINE_NO() __catch_error_line_no +#define ERROR_PROC() __catch_error_line_no = __LINE__; + +#define PROC \ + uint32_t __catch_error_code = 0x7FFFFFCC; \ + uint32_t __catch_error_line_no = 0xFFFFFFFF; \ + { +#define END_PROC \ + } \ + __tabErrorCode: +#define THROW(errcode) \ + { \ + __catch_error_code = (errcode); \ + ERROR_PROC(); \ + goto __tabErrorCode; \ + } +#define EXEC(func) \ + { \ + if (0 != (__catch_error_code = (func))) THROW(__catch_error_code) \ + } +#define EXEC_EX1(func, error_code) \ + { \ + if (0 != (func)) THROW(error_code) \ + } +#define EXEC_EX(func, succRet, error_code) \ + { \ + if (succRet != (__catch_error_code = (func))) THROW(error_code) \ + } +#define ASSERT_EXEC(func, succRet) \ + { \ + if (succRet != (__catch_error_code = (func))) /*GO_ASSERT_FALSE();*/ \ + THROW(__catch_error_code) \ + } \ + } +#define NEW_ERROR_EXEC(errcode, func, succRet) \ + { \ + if (succRet != (func)) { \ + THROW(errcode) \ + } \ + } +#define JUDGE(errcode, expr) \ + { \ + if (!(expr)) { \ + THROW(errcode) \ + } \ + } +#define ASSERT_JUDGE(errcode, expr) \ + { \ + if (!(expr)) { /*GO_ASSERT_FALSE();*/ \ + THROW(errcode) \ + } \ + } +#define JUDGE_FALSE(errcode, expr) \ + { \ + if (expr) { \ + THROW(errcode) \ + } \ + } +#define JUDGE_CONTINUE(expr) \ + { \ + if (expr) { \ + continue; \ + } \ + } +#define CATCH_ERROR(errcode) if (__catch_error_code == (errcode)) { // ERROR_LOG(); +#define CATCH_ALL_ERROR { +#define END_CATCH_ERROR } +#define FINAL \ + __tabFinal: +#define END_FINAL /*GO_ASSERT_FALSE()*/ ; +#define GOTO_FINAL() goto __tabFinal; +#endif // CATCH_HPP_ diff --git a/metadef/third_party/fwkacllib/inc/cce/compiler_stub.h b/metadef/third_party/fwkacllib/inc/cce/compiler_stub.h new file mode 100644 index 00000000..00ea467e --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/compiler_stub.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPILER_STUB_H__ +#define COMPILER_STUB_H__ + +namespace cce { + +/** + * @ingroup cce + * @brief compiler stub init func + */ +bool compilerStubInit(); + +/** + * @ingroup cce + * @brief compiler stub free func + */ +bool compilerStubFree(); + +}; // namespace cce + +#endif // COMPILER_STUB_H__ diff --git a/metadef/third_party/fwkacllib/inc/cce/customize.h b/metadef/third_party/fwkacllib/inc/cce/customize.h new file mode 100644 index 00000000..7dd97af1 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/customize.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CC_CUSTOMIZE_API__ +#define CC_CUSTOMIZE_API__ + +#include + +#define CC_DEVICE_DIM_MAX 8 +typedef enum tagOpTensorFormat +{ + OP_TENSOR_FORMAT_NC1HWC0 = 0, + OP_TENSOR_FORMAT_ND, + OP_TENSOR_FORMAT_RESERVED, + +} opTensorFormat_t; + + +typedef enum tagOpDataType +{ + OP_DATA_FLOAT = 0, /**< float type */ + OP_DATA_HALF, /**< fp16 type */ + OP_DATA_INT8, /**< int8 type */ + OP_DATA_INT32, /**< int32 type */ + OP_DATA_UINT8, /**< uint8 type */ + OP_DATA_HALF_UINT16_PROPOSAL, /**dimCnt, xDesc->dimCnt) + * @param [in] num the number of outputs + * @param [in] beta scaling factors + * @param [in] yDescArr descriptors of output tensors + * @param [in|out] yArr output data array in device memory + * @return ccStatus_t + */ +ccStatus_t ccSplitForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + int32_t axis, uint32_t num, const void *beta, const ccTensorDescriptor_t yDescArr[], + void *yArr[]); + +/** + * @ingroup dnn + * @brief get the output dimensions info of split + * @param [in] xDesc descriptor of input tensor + * @param [in] axis the dimension along which to split. Must be in the range [-xDesc->dimCnt, xDesc->dimCnt) + * @param [in] num the number of outputs + * @param [in] sizes Optional, used to specify the sizes of each output tensor along split dim. The tensor x would + * be split evenly along split dim if sizes is NULL + * @param [in|out] nArr point to the first element of batch sizes + * @param [in|out] cArr point to the first element of channels + * @param [in|out] hArr point to the first element of heights of feature map + * @param [in|out] wArr point to the first element of widths of feature map + * @return ccStatus_t + */ +ccStatus_t ccGetSplitForwardOutputDim(const ccTensorDescriptor_t xDesc, int32_t axis, uint32_t num, + const uint32_t sizes[], uint32_t nArr[], uint32_t cArr[], uint32_t hArr[], + uint32_t wArr[]); + +/** + * @ingroup dnn + * @brief Get split output shape(s). + * @param [in] xDesc input tensor, support ND and NC1HWC0 + * @param [in] axis split axis, negtive axis will increased by dimCnt once time. + * @param [in] num splited nums. + * @param [in] sizes splited dim size on axis. if NULL was set, The input will be divided into num equally. + * @param [output] dimCnt splited dimCnt array. One to one correspondence with the splited output. + * @param [output] dim array of splited dim array. One to one correspondence with the splited output. + * @param [in| dimlen length of dim(Pass in the length of the entire space pointed to by dim, + not just the length of the dim array, because dim is a level 2 array + dimlen = lengthof dim[][], not just lengthof dim[]) + * @return ccStatus_t + */ +ccStatus_t ccGetSplitForwardOutputDim(const ccTensorDescriptor_t xDesc, int32_t axis, uint32_t num, + const uint32_t sizes[], int32_t *dimCnt, int32_t *dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief create weight compress info + * @param [in|out] compressInfo point to CompressInfo + * @return ccStatus_t + */ +ccStatus_t ccCreateWeightCompressInfo(ccWeightCompressInfo_t **compressInfo); + +/** + * @ingroup dnn + * @brief destory weight compress info + * @param [in] *compressInfo point to CompressInfo + * @return ccStatus_t + */ +ccStatus_t ccDestroyWeightCompressInfo(ccWeightCompressInfo_t **compressInfo); + +/** + * @ingroup dnn + * @brief create compress table + * @param [in|out] compressTab point to weight compress table + * @return ccStatus_t + */ +ccStatus_t ccCreateWeightCompressTab(ccWeightCompressTab_t **compressTab); + +/** + * @ingroup dnn + * @brief destory compress table + * @param [in] compressTab point to weight compress table + * @return ccStatus_t + */ +ccStatus_t ccDestroyWeightCompressTab(ccWeightCompressTab_t **compressTab); + +/** + * @ingroup dnn + * @brief get fc compress info + * @param [in] xDesc descriptor of input tensor + * @param [in] wDesc descriptor of weight tensor + * @param [in] biasDesc descriptor of bias tensor + * @param [in] dataTypeTransmode mode of data type transform + * @param [in] weightCompressInfo compress info, compute based on tiling method + * @param [in|out] outputSize output data size in byte + * @param [in|out] infoTabSize compress info table + * @return ccStatus_t + */ +ccStatus_t ccGetCompressedFcWeightInfo(const ccTensorDescriptor_t xDesc, const ccFilterDescriptor_t wDesc, + const ccTensorDescriptor_t biasDesc, ccDataTypeTransMode_t dataTypeTransmode, + ccWeightCompressInfo_t *weightCompressInfo, uint32_t *outputSize, + uint32_t *infoTabSize); +/** + * @ingroup dnn + * @brief compress fc + * @param [in] wDesc descriptor of weight tensor + * @param [in] w filter data in device memory + * @param [in] weightCompressInfo compress info, compute based on tiling method + * @param [in] dataTypeTransmode mode of data type transform + * @param [in|out] y output data in device memory + * @param [in] ySize transformed data size in byte + * @param [in|out] yCompressedSize compressed output data size in byte + * @param [in|out] infoTab compressed info table + * @param [in] infoTabSize compressed info table size in byte + * @return ccStatus_t + */ +ccStatus_t ccCompressWeight(const ccFilterDescriptor_t wDesc, const void *w, + const ccWeightCompressInfo_t *weightCompressInfo, ccDataTypeTransMode_t dataTypeTransmode, + ccFilterDescriptor_t yDesc, void *y, uint32_t ySize, uint32_t *yCompressedSize, + void *infoTab, uint32_t infoTabSize); + +/** + * @ingroup dnn + * @brief restore compressed fc data + * @param [in] x input data in device memory + * @param [in] xSizeInBytes input compressed weight data size in byte + * @param [in|out] y output data in device memory + * @param [in] ySizeInBytes output data size in byte + * @return ccStatus_t + */ +ccStatus_t ccRestoreCompressedWeight(const void *x, uint32_t xSizeInBytes, void *y, uint32_t ySizeInBytes, + rtMemcpyKind_t kind); + +/** + * @ingroup dnn + * @brief create quantize parameters struct + * @param [in|out] quantizeInfo descriptor of quantize parameters + * @return ccStatus_t + */ +ccStatus_t ccCreateQuantizeInfoTab(ccQuantizeDescriptor_t *quantizeInfo); + +/** + * @ingroup dnn + * @brief destroy quantize parameters struct + * @param [in] quantizeInfo descriptor of quantize parameters + * @return ccStatus_t + */ +ccStatus_t ccDestoryQuantizeInfoTab(ccQuantizeDescriptor_t *quantizeInfo); + +/** + * @ingroup dnn + * @brief set quantize parameters + * @param [in] quantizeInfo descriptor of quantize parameters + * @param [in] scaleValMode enmu type for quantize scale value type (normal or sqrt) + * @param [in] scale quantize scale value + * @param [in] offset quantize offset(when quantize algorithm is half offset or full offset,this should be + * configed) + * @param [in] offsetPad padding value for load3d (only for half offset or full offset) + * @return ccStatus_t + */ +ccStatus_t ccSetQuantizeFactors(ccQuantizeDescriptor_t quantizeInfo, ccScaleValueMode_t scaleValMode, + const uint16_t *scale, const uint16_t *offset, const uint8_t *offsetPad); + +/** + * @ingroup dnn + * @brief set Requantize parameters + * @param [in] quantizeInfo descriptor of quantize parameters + * @param [in] scaleValMode enmu type for requantize scale value type (normal or sqrt) + * @param [in] scale quantize scale value + * @param [in] offset quantize offset(when quantize algorithm is half offset or full offset,this should be + * configed) + * @param [in] offsetw offset for filter (only config for full offset quantize) + * @return ccStatus_t + */ +ccStatus_t ccSetReQuantizeFactors(ccQuantizeDescriptor_t quantizeInfo, ccScaleValueMode_t scaleValMode, + const uint16_t *scaleRq, const uint16_t *nextLayerOffset, const int32_t *offsetw); + +/** + * @ingroup dnn + * @brief set Dequantize parameters + * @param [in] quantizeInfo descriptor of quantize parameters + * @param [in] scaleValMode enmu type for dequantize scale value type (normal or sqrt) + * @param [in] scaleDq quantize scale value + * @param [in] offsetw offset for filter (only config for full offset quantize) + * @return ccStatus_t + */ +ccStatus_t ccSetDeQuantizeFactors(ccQuantizeDescriptor_t quantizeInfo, ccScaleValueMode_t scaleValMode, + const uint16_t *scaleDq, const int32_t *offsetw); + +/** + * @ingroup dnn + * @brief set convolution desciptor's quantize parameters + * @param [in] convDesc convolution descriptor + * @param [in] quantizeInfo descriptor of quantize parameters + * @return ccStatus_t + */ +ccStatus_t ccSetConvolutionQuantizeInfo(ccConvolutionDescriptor_t convDesc, const ccQuantizeDescriptor_t QuantizeInfo); + +/** + * @ingroup dnn + * @brief set convolution desciptor's all offset quantize parameters + * @param [in] convDesc convolution descriptor + * @param [in] offsetw descriptor of quantize parameters + * @param [in] scaleReq descriptor of quantize parameters + * @param [in] offset_d_next descriptor of quantize parameters + * @return ccStatus_t + */ +ccStatus_t ccSetAllOffsetQuantizeFactors(ccQuantizeDescriptor_t quantizeInfo, const uint8_t *offsetW, + const uint8_t *offsetD, const uint16_t *scaleReq, const uint16_t *offsetDNext); + +/** + * @ingroup dnn + * @brief set full connection desciptor's quantize parameters + * @param [in] fcDesc full connection descriptor + * @param [in] quantizeInfo descriptor of quantize parameters + * @return ccStatus_t + */ +ccStatus_t ccSetFullConnectionQuantizeInfo(ccFullConnectionDescriptor_t fcDesc, + const ccQuantizeDescriptor_t QuantizeInfo); + +/** + * @ingroup dnn + * @brief set pooling desciptor's quantize parameters + * @param [in] poolingDesc pooling descriptor + * @param [in] quantizeInfo descriptor of quantize parameters + * @return ccStatus_t + */ +ccStatus_t ccSetPoolingQuantizeInfo(ccPoolingDescriptor_t poolingDesc, const ccQuantizeDescriptor_t QuantizeInfo); + +/** + * @ingroup dnn + * @brief set full connection desciptor's info table + * @param [in] fcDesc full connection descriptor + * @param [in] infoTabSize table size + * @param [in] infoTab pointer to info table + * @return ccStatus_t + */ +ccStatus_t ccSetFullConnectionDescriptor(ccFullConnectionDescriptor_t fcDesc, uint32_t infoTabSize, const void *infoTab, + ccFullConnectFwdAlgo_t algo = CC_FULLCONNECT_FWD_ALGO_HALF); + +/** + * @ingroup dnn + * @brief set full connection desciptor's relu flag + * @param [in] fcDesc full connection descriptor + * @param [in] opType operation type for append at convolution operation + * @param [in] opDesc operation descritpor for the opType + * @return ccStatus_t + */ +ccStatus_t ccFullConnectionAppendOp(ccFullConnectionDescriptor_t fcDesc, tagCcOpType opType, const void *opDesc); + +/** + * @ingroup dnn + * @brief check aipp basic info + * @param [in] inputFormat format of input image + * @param [in] loadStartPosH vertical start position in source image + * @param [in] loadStartPosW horizontal start position in source image + * @param [in] srcImageSizeH vertical size of source image + * @param [in] srcImageSizeW horizontal size of source image + * @param [in] cpaddingValue C direction padding value + * @param [in] cscSwitch csc enable or not + * @param [in] rbuvSwapSwitch swap R/U and B/V position of the image + * @param [in] axSwapSwitch swap RGBA->ARGB, YUVA->AYUV + * @param [in] singleLineMode when set this bit to 1, only read 1 line. Under this case, vertical size configuration is + * not useful. + * @return ccStatus_t + */ +ccStatus_t ccCheckConvolutionAippCommInfo(ccAippInputFormat_t inputFormat, int32_t loadStartPosW, int32_t loadStartPosH, + int32_t srcImageSizeW, int32_t srcImageSizeH, float cpaddingValue, + bool cscSwitch, bool rbuvSwapSwitch, bool axSwapSwitch, bool singleLineMode); + +/** + * @ingroup dnn + * @brief check aipp dtc info + * @param [in] dtcPixelMeanChnx Mean value for YUV or RGB data channel x + * @param [in] dtcPixelMinChnx Min value for YUV or RGB data channel x + * @param [in] dtcPixelVarReciChnx Reciprocal of variance or (max-min) for YUV or RGB data channel x + * @return ccStatus_t + */ +ccStatus_t ccCheckConvolutionAippDtcInfo(int32_t dtcPixelMeanChn0, int32_t dtcPixelMeanChn1, int32_t dtcPixelMeanChn2, + float dtcPixelMinChn0, float dtcPixelMinChn1, float dtcPixelMinChn2, + float dtcPixelVarReciChn0, float dtcPixelVarReciChn1, + float dtcPixelVarReciChn2); + +/** + * @ingroup dnn + * @brief check aipp pad info + * @param [in] paddingMode padding mode + * @param [in] leftPaddingSize left hblank/padding size + * @param [in] rightPaddingSize right hblank/padding size + * @param [in] topPaddingSize top padding size + * @param [in] bottomPaddingSize bottom padding size + * @return ccStatus_t + */ +ccStatus_t ccCheckConvolutionAippPadInfo(ccAippPaddingMode_t paddingMode, int32_t leftPaddingSize, + int32_t rightPaddingSize, int32_t topPaddingSize, int32_t bottomPaddingSize); + +/** + * @ingroup dnn + * @brief check aipp csc info + * @param [in] cscMatrixRmCn 3x3 CSC matrix for YUV to RGB or RGB to YUV, element of row m and column n + * @param [in] cscOutputBiasm output Bias for RGB to YUV, element of row m + * @param [in] cscInputBiasm input Bias for YUV to RGB, element of row m + * @return ccStatus_t + */ +ccStatus_t ccCheckConvolutionAippCscInfo(int32_t cscMatrixR0C0, int32_t cscMatrixR0C1, int32_t cscMatrixR0C2, + int32_t cscMatrixR1C0, int32_t cscMatrixR1C1, int32_t cscMatrixR1C2, + int32_t cscMatrixR2C0, int32_t cscMatrixR2C1, int32_t cscMatrixR2C2, + int32_t cscOutputBias0, int32_t cscOutputBias1, int32_t cscOutputBias2, + int32_t cscInputBias0, int32_t cscInputBias1, int32_t cscInputBias2); + +/** + * @ingroup dnn + * @brief check aipp scf info + * @param [in] scfSwitch scaling enable or not + * @param [in] scfInputW input width of scaling + * @param [in] scfInputH input height of scaling + * @param [in] scfOutputW output width of scaling + * @param [in] scfOutputH output height of scaling + * @return ccStatus_t + */ +ccStatus_t ccCheckConvolutionAippScfInfo(bool scfSwitch, int32_t scfInputW, int32_t scfInputH, int32_t scfOutputW, + int32_t scfOutputH); + +/** + * @ingroup dnn + * @brief check aipp param + * @param [in] convDesc descriptor of conv operator + * @param [in] xDesc input tensor info + * @param [in] yDesc output tensor info + * @return ccStatus_t + */ +ccStatus_t ccCheckConvFwdAippParam(const ccConvolutionDescriptor_t convDesc, const ccTensorDescriptor_t xDesc, + const ccTensorDescriptor_t yDesc); + +/** + * @ingroup dnn + * @brief init aipp basic info + * @param [in|out] convDesc descriptor of conv operator + * @param [in] inputFormat format of input image + * @param [in] loadStartPosH vertical start position in source image + * @param [in] loadStartPosW horizontal start position in source image + * @param [in] srcImageSizeH vertical size of source image + * @param [in] srcImageSizeW horizontal size of source image + * @param [in] cpaddingValue C direction padding value + * @param [in] cscSwitch csc enable or not + * @param [in] rbuvSwapSwitch swap R/U and B/V position of the image + * @param [in] axSwapSwitch swap RGBA->ARGB, YUVA->AYUV + * @param [in] singleLineMode when set this bit to 1, only read 1 line. Under this case, vertical size configuration is + * not useful. + * @return ccStatus_t + */ +ccStatus_t ccSetConvolutionAippCommInfo(ccConvolutionDescriptor_t convDesc, ccAippInputFormat_t inputFormat, + int32_t loadStartPosW, int32_t loadStartPosH, int32_t srcImageSizeW, + int32_t srcImageSizeH, float cpaddingValue, bool cscSwitch, bool rbuvSwapSwitch, + bool axSwapSwitch, bool singleLineMode); +/** + * @ingroup dnn + * @brief init aipp dtc info + * @param [in|out] convDesc descriptor of conv operator + * @param [in] dtcPixelMeanChnx Mean value for YUV or RGB data channel x + * @param [in] dtcPixelMinChnx Min value for YUV or RGB data channel x + * @param [in] dtcPixelVarReciChnx Reciprocal of variance or (max-min) for YUV or RGB data channel x + * @return ccStatus_t + */ +ccStatus_t ccSetConvolutionAippDtcInfo(ccConvolutionDescriptor_t convDesc, int32_t dtcPixelMeanChn0, + int32_t dtcPixelMeanChn1, int32_t dtcPixelMeanChn2, float dtcPixelMinChn0, + float dtcPixelMinChn1, float dtcPixelMinChn2, float dtcPixelVarReciChn0, + float dtcPixelVarReciChn1, float dtcPixelVarReciChn2); +/** + * @ingroup dnn + * @brief init aipp pad info + * @param [in|out] convDesc descriptor of conv operator + * @param [in] paddingMode padding mode + * @param [in] leftPaddingSize left hblank/padding size + * @param [in] rightPaddingSize right hblank/padding size + * @param [in] topPaddingSize top padding size + * @param [in] bottomPaddingSize bottom padding size + * @return ccStatus_t + */ +ccStatus_t ccSetConvolutionAippPadInfo(ccConvolutionDescriptor_t convDesc, ccAippPaddingMode_t paddingMode, + int32_t leftPaddingSize, int32_t rightPaddingSize, int32_t topPaddingSize, + int32_t bottomPaddingSize); + +/** + * @ingroup dnn + * @brief init aipp csc info + * @param [in|out] convDesc descriptor of conv operator + * @param [in] cscMatrixRmCn 3x3 CSC matrix for YUV to RGB or RGB to YUV, element of row m and column n + * @param [in] cscOutputBiasm output Bias for RGB to YUV, element of row m + * @param [in] cscInputBiasm input Bias for YUV to RGB, element of row m + * @return ccStatus_t + */ +ccStatus_t ccSetConvolutionAippCscInfo(ccConvolutionDescriptor_t convDesc, int32_t cscMatrixR0C0, int32_t cscMatrixR0C1, + int32_t cscMatrixR0C2, int32_t cscMatrixR1C0, int32_t cscMatrixR1C1, + int32_t cscMatrixR1C2, int32_t cscMatrixR2C0, int32_t cscMatrixR2C1, + int32_t cscMatrixR2C2, int32_t cscOutputBias0, int32_t cscOutputBias1, + int32_t cscOutputBias2, int32_t cscInputBias0, int32_t cscInputBias1, + int32_t cscInputBias2); + +/** + * @ingroup dnn + * @brief init aipp scf info + * @param [in|out] convDesc descriptor of conv operator + * @param [in] scfSwitch scaling enable or not + * @param [in] scfInputW input width of scaling + * @param [in] scfInputH input height of scaling + * @param [in] scfOutputW output width of scaling + * @param [in] scfOutputH output height of scaling + * @return ccStatus_t + */ +ccStatus_t ccSetConvolutionAippScfInfo(ccConvolutionDescriptor_t convDesc, bool scfSwitch, int32_t scfInputW, + int32_t scfInputH, int32_t scfOutputW, int32_t scfOutputH); + +/** + * @ingroup dnn + * @brief set dynamic aipp parameter address and enflag info + * @param [in|out] convDesc descriptor of conv operator + * @param [in] dyncParaAddr aipp parameter address + * @param [in] dyncAippFlag flag to show whether to use dynamic aipp + * @return ccStatus_t + */ +ccStatus_t ccSetConvolutionAippDyncParaAddr(ccConvolutionDescriptor_t convDesc, const void *dyncParaAddr, + bool dyncAippFlag, bool rotationFlag = false); + +/** + * @ingroup dnn + * @brief check dynamic aipp parameter + * @param [in] dyncParaAddr aipp parameter address + * @param [in] dataLength parameter lenght + * @param [in] convolutionDimW convDimW + * @param [in] convolutionDimH convDimH + * @return ccStatus_t + */ +ccStatus_t ccCheckDynamicAippParam(const void *dynamicParamAddr, uint32_t dataLength, int64_t convolutionDimW, + int64_t convolutionDimH); + +/*** @ingroup dnn + * @brief trans mean and var + * @param [in|out] mean' = bnScale/sqrt(var) + * @param [in|out] var' = -bnScale * mean / sqrt(var) + bnBias + * @return ccStatus_t + */ + +ccStatus_t ccTransBatchnormMeanAndVar(void *mean, void *var, const ccTensorDescriptor_t bnScaleBiasMeanVarDesc, + const void *alpha, const void *beta, void *bnScale, void *bnBias, double epsilon); + +/** + * @ingroup dnn + * @brief init deconvolution adj or targetShape info. + * @param [in] convDesc conv descriptor. + * @param [in] adjH, adjust H output. + * @param [in] adjW, adjust W output. + * @param [in] targetShape, values of output shape, if this pointer was set, ignore adj. + * @return ccStatus_t + */ +ccStatus_t ccSetDeconvolutionOutShapeInfo(ccConvolutionDescriptor_t convDesc, uint32_t adjSize, const uint32_t *adj, + uint32_t targetShapeSize, const uint32_t *targetShape); + +/** + * @ingroup dnn + * @brief gather elements according to the indices. + * @param [in] alpha reserved. + * @param [in] xDesc description of the tensor from which to gather elements. + * @param [in] x data point of the tensor from which to gather elements. + * @param [in] indicesDesc description of the tensor of indices. + * @param [in] indices data point of the tensor of indices. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccGatherNdForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t indicesDesc, const void *indices, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get output shape of gather_nd. + * @param [in] xDesc description of the tensor from which to gather elements. + * @param [in] indicesDesc description of the tensor of indices. + * @param [output] n dim-size of n-dim. + * @param [output] c dim-size of c-dim. + * @param [output] h dim-size of h-dim. + * @param [output] w dim-size of w-dim. + * @param [output] realDimCnt real dim. + * @return ccStatus_t + */ +ccStatus_t ccGetGatherNdOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t indicesDesc, int32_t *n, + int32_t *c, int32_t *h, int32_t *w, int32_t *realDimCnt); +/** + * @ingroup dnn + * @brief get output shape of realdiv. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] dimCnt dim nums. + * @param [output] dim dim size. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetGatherNdOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t indicesDesc, + int32_t *dimCnt, int32_t *dim, int32_t dimLen); +/** + * @ingroup dnn + * @brief tile tensor by multiples. + * @param [in] alpha reserved. + * @param [in] xDesc description of the tensor which to be tiled. + * @param [in] x data point of the tensor which to be tiled. + * @param [in] multiples tile coefficient of each dim. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccTileForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccIntArray_t *multiples, const void *beta, const ccTensorDescriptor_t outputDesc, + void *output); + +/** + * @ingroup dnn + * @brief get output shape of tile. + * @param [in] xDesc description of the dividend tensor. + * @param [in] multiples multiples of each dim. + * @param [in|out] dimCnt [point to the output dimCnt] + * @param [in|out] dim [arrays to save dims] + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetTileOutputDim(const ccTensorDescriptor_t xDesc, const ccIntArray_t *multiples, int32_t *dimCnt, + int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief get output shape of tile. + * @param [in] xDesc description of the dividend tensor. + * @param [in] multiples multiples of each dim. + * @param [output] n dim-size of n-dim. + * @param [output] c dim-size of c-dim. + * @param [output] h dim-size of h-dim. + * @param [output] w dim-size of w-dim. + * @param [output] realDimCnt real dim. + * @return ccStatus_t + */ +ccStatus_t ccGetTileOutputDim(const ccTensorDescriptor_t xDesc, + // const ccIntArrayDescriptor_t multiples, + const ccIntArray_t *multiples, int32_t *n, int32_t *c, int32_t *h, int32_t *w, + int32_t *realDimCnt); +/** + * @ingroup dnn + * @brief get output shape of realdiv. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] dimCnt dim nums. + * @param [output] dim dim size. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetRealdivOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief realdiv between two tensors. + * @param [in] alpha reserved. + * @param [in] xDesc description of the dividend tensor. + * @param [in] x data point of the dividend tensor. + * @param [in] yDesc description of the divisor tensor. + * @param [in] y data point of the divisor tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccRealdivForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get output shape of realdiv. + * @param [in] xDesc description of the dividend tensor. + * @param [in] yDesc description of the divisor tensor. + * @param [output] n dim-size of n-dim. + * @param [output] c dim-size of c-dim. + * @param [output] h dim-size of h-dim. + * @param [output] w dim-size of w-dim. + * @param [output] realDimCnt real dim. + * @return ccStatus_t + */ +ccStatus_t ccGetRealdivOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *n, + int32_t *c, int32_t *h, int32_t *w, int32_t *realDimCnt); + +/** + * @ingroup dnn + * @brief realdiv between two tensors. + * @param [in] alpha reserved. + * @param [in] xDesc description of the left operator tensor. + * @param [in] x data point of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [in] y data point of the right operator tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccFloordivForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get output shape of realdiv. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] realDimCnt real dim. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetFloordivOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief realdiv between two tensors. + * @param [in] alpha reserved. + * @param [in] xDesc description of the left operator tensor. + * @param [in] x data point of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [in] y data point of the right operator tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccGreaterForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get output shape of realdiv. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] dimCnt dim nums. + * @param [output] dim dim size. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetGreaterOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief realdiv between two tensors. + * @param [in] alpha reserved. + * @param [in] xDesc description of the left operator tensor. + * @param [in] x data point of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [in] y data point of the right operator tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccLessForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get output shape of realdiv. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] dimCnt dim nums. + * @param [output] dim dim size. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetLessOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief get output shape of LogicalOr. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] dimCnt dim nums. + * @param [output] dim dim size. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetLogicalOrOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief get output shape of LogicalXor. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] dimCnt dim nums. + * @param [output] dim dim size. + * @param [in] dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetLogicalXorOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief sqrt forward: + * data type only support bool + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccLogicalNotForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief equal between two tensors. + * @param [in] alpha reserved. + * @param [in] xDesc description of the left operator tensor. + * @param [in] x data point of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [in] y data point of the right operator tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ + +ccStatus_t ccEqualForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief dump data during inference, only for eng ver. + * @param [in] handle cce handle + * @return ccStatus_t + */ +ccStatus_t ccDataDumpForward(ccHandle_t handle, const void *buffer, const uint64_t bufLen, const uint32_t taskIndex); + +/** + * @ingroup dnn + * @brief logicaland between two tensors. + * @param [in] alpha reserved. + * @param [in] xDesc description of the left operator tensor. + * @param [in] x data point of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [in] y data point of the right operator tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccLogicalAndForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief logical or between two tensors. + * @param [in] alpha reserved. + * @param [in] xDesc description of the left operator tensor. + * @param [in] x data point of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [in] y data point of the right operator tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccLogicalOrForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); +/** + * @ingroup dnn + * @brief logical Xor between two tensors(x ^ y = (x | y) & ~(x & y). + * @param [in] alpha reserved. + * @param [in] xDesc description of the left operator tensor. + * @param [in] x data point of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [in] y data point of the right operator tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccLogicalXorForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get output shape of equal. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] dimCnt dim nums. + * @param [output] dim dim size. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetEqualOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); +/** + * @ingroup dnn + * @brief get output shape of logicaland. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] dimCnt dim nums. + * @param [output] dim dim size. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetLogicalAndOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); +/** + * @ingroup dnn + * @brief realdiv between two tensors. + * @param [in] alpha reserved. + * @param [in] xDesc description of the left operator tensor. + * @param [in] x data point of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [in] y data point of the right operator tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccFloormodForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get output shape of realdiv. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] dimCnt dim nums. + * @param [output] dim dim size. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetFloormodOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief compare between two tensors. + * @param [in] alpha reserved. + * @param [in] xDesc description of the left operator tensor. + * @param [in] x data point of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [in] y data point of the right operator tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ +ccStatus_t ccCompareForward(ccHandle_t handle, ccCompareType_t compareType, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const ccTensorDescriptor_t yDesc, + const void *y, const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get output shape of realdiv. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [output] dimCnt dim nums. + * @param [output] dim dim size. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetCompareOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief create descriptor of FillParam + * @param [in|out] fillParamDesc point to descriptor of fill param + * @return ccStatus_t + */ +ccStatus_t ccCreateFillParamDescriptor(ccFillParamDescriptor_t *fillParamDesc); + +/** + * @ingroup dnn + * @brief destroy descriptor of FillParam + * @param [in] *fillParamDesc point to descriptor of fill param + * @return ccStatus_t + */ +ccStatus_t ccDestroyFillParamDescriptor(ccFillParamDescriptor_t *fillParamDesc); + +/** + * @ingroup dnn + * @brief get output shape of broadcat operations. + * @param [in] inputNum input number of the operation tensors. + * @param [in] xDesc[] description of the input operation tensors list. + * @param [output] dimCnt dim-size of output tensor. + * @param [output] dim dim of output tensor. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetMultiNdBroadcastOpOutputDim(const int32_t inputNum, const ccTensorDescriptor_t xDesc[], int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief get output shape of maximultitensor. + * @param [in] inputNum the num of input operator tensors. + * @param [in] xDesc[] description of the input operator tensors list. + * @param [output] dimCnt dim count of output tensor. + * @param [output] dim array of output tensor. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetMaxMultitensorOutputDim(const int32_t inputNum, const ccTensorDescriptor_t xDesc[], int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief get output shape of minmultitensor. + * @param [in] inputNum the num of input operator tensors. + * @param [in] xDesc[] description of the input operator tensors list. + * @param [output] dimCnt dim count of output tensor. + * @param [output] dim array of output tensor. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetMinMultitensorOutputDim(const int32_t inputNum, const ccTensorDescriptor_t xDesc[], int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief MaxMultitensor forward: + * data type only support float float16 and int32 + * data format only support ND + * @param [in] handle cce handle + * @param [in] inputNum input tensor number + * @param [in] alpha common scale factor + * @param [in] xDesc[] descriptor of input tensors list + * @param [in] x[] input data in device memory list + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccMaxMultitensorForward(const ccHandle_t handle, const int32_t inputNum, const void *alpha, + const ccTensorDescriptor_t xDesc[], const void *x[], const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief MinMultitensor forward: + * data type only support float float16 and int32 + * data format only support ND + * @param [in] handle cce handle + * @param [in] inputNum input tensor number + * @param [in] alpha common scale factor + * @param [in] xDesc[] descriptor of input data list + * @param [in] x[] input data in device memory list + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccMinMultitensorForward(const ccHandle_t handle, const int32_t inputNum, const void *alpha, + const ccTensorDescriptor_t xDesc[], const void *x[], const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief create descriptor of StridedSlice + * @param [in|out] stridedSliceDesc point to descriptor of StridedSlice param + * @return ccStatus_t + */ +ccStatus_t ccCreateStridedSliceDescriptor(ccStridedSliceDescriptor_t *stridedSliceDesc); + +/** + * @ingroup dnn + * @brief destroy descriptor of StridedSlice + * @param [in] *stridedSliceDesc point to descriptor of StridedSlice param + * @return ccStatus_t + */ +ccStatus_t ccDestroyStridedSliceDescriptor(ccStridedSliceDescriptor_t *stridedSliceDesc); + +/** + * @ingroup dnn + * @brief init stridedSlice descriptor_t. + * @param [out] stridedSliceDesc struct of stridedslice param + * @param [in] dimCnt dimension of the input tensor + * @param [in] begin slice begin(include) + * @param [in] end slice end index(not include) + * @param [in] strides slice stride + * @return ccStatus_t + */ +ccStatus_t ccSetStridedSliceDescriptor(ccStridedSliceDescriptor_t stridedSliceDesc, int32_t dimCnt, int32_t begin[], + int32_t end[], int32_t strides[]); + +/** + * @ingroup dnn + * @brief create descriptor of StridedSlice + * @param [in|out] stridedSliceDesc point to descriptor of StridedSlice attr + * @return ccStatus_t + */ +ccStatus_t ccCreateStridedSliceAttrsDescriptor(ccStridedSliceAttrsDescriptor_t *attrDesc); + +/** + * @ingroup dnn + * @brief destroy descriptor of StridedSlice + * @param [in] *stridedSliceDesc point to descriptor of StridedSlice attr + * @return ccStatus_t + */ +ccStatus_t ccDestroyStridedSliceAttrsDescriptor(ccStridedSliceAttrsDescriptor_t *attrDesc); + +/** + * @ingroup dnn + * @brief init stridedSlice mask attrs desescriptor. + * @param [out] attrDesc struct of stridedslice mask attrs + * @param [in] beginMask begin mask + * @param [in] endMask end mask + * @param [in] ellipsisMask ellipsis mask + * @param [in] newAxisMask new axis mask + * @param [in] shrinkAxisMask shrink axis mask + * @return ccStatus_t + */ +ccStatus_t ccSetStridedSliceAttrsDescriptor(ccStridedSliceAttrsDescriptor_t attrDesc, int32_t beginMask, + int32_t endMask, int32_t ellipsisMask, int32_t newAxisMask, + int32_t shrinkAxisMask); + +/** + * @ingroup dnn + * @brief Extracts a strided slice of a tensor. + * @param [in] xDesc descriptor of input data + * @param [in] stridedSliceDesc specifies the begin, end, strides of slice + * @param [in] attrDesc reserve for optional attributes. + * @param [out] n point to n size + * @param [out] c point to c size + * @param [out] h point to h size + * @param [out] w point to w size + * @return ccStatus_t + */ +ccStatus_t ccGetStridedSliceOutputDim(const ccTensorDescriptor_t xDesc, + const ccStridedSliceDescriptor_t stridedSliceDesc, + const ccStridedSliceAttrsDescriptor_t attrDesc, int32_t *n, int32_t *c, + int32_t *h, int32_t *w, int32_t *realDimCnt); + +/** + * @ingroup dnn + * @brief Extracts a strided slice of a tensor. + * @param [in] handle cce handle + * @param [in] stridedSliceDesc specifies the begin, end, strides of slice + * @param [in] attrDesc reserve for optional attributes. + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] yDesc descriptor of output data + * @param [in|out] y output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccStridedSliceForward(ccHandle_t handle, const ccStridedSliceDescriptor_t stridedSliceDesc, + const ccStridedSliceAttrsDescriptor_t attrDesc, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t yDesc, void *y); + +/** + * @ + * @brief get out put descrition of slice tensor. + * @param [in] xDesc descriptor of input data + * @param [in] begin begin position of tensor + * @param [in] size size to slice + * @param [out] n point to n size + * @param [out] c point to c size + * @param [out] h point to h size + * @param [out] w point to w size + * @param [out] realDimCnt realdim count + * @return ccStatus_t + */ +ccStatus_t ccGetSliceOutputDim(const ccTensorDescriptor_t xDesc, const ccIntArray_t *begin, const ccIntArray_t *size, + int32_t *n, int32_t *c, int32_t *h, int32_t *w, int32_t *realDimCnt); + +/** + * @ingroup dnn + * @brief slice of a tensor. + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] begin begin position of tensor + * @param [in] size size to slice + * @param [in] beta common scale factor + * @param [in] yDesc descriptor of output data + * @param [in|out] y output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccSliceForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccIntArray_t *begin, const ccIntArray_t *size, const void *beta, + const ccTensorDescriptor_t yDesc, void *y); + +/** + * @ingroup dnn + * @brief gather forward computation + * @param [in] handle cce handle + * @param [in] paramsDesc descriptor of params tensor + * @param [in] params input data in device memory + * @param [in] indicesDesc descriptor of indices tensor + * @param [in] indices indices data in device memory + * @param [in] axis descriptor of roi tensor + * @param [in] alpha reserved + * @param [in] beta reserved + * @param [in] outputDesc descriptor of output tensor + * @param [out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccGatherForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t paramsDesc, + const void *params, const ccTensorDescriptor_t indicesDesc, const void *indices, + const int32_t axis, const void *beta, ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief gather output dim computation, for NC1HWC0 + * @param [in] paramsDesc descriptor of params tensor + * @param [in] indicesDesc descriptor of indices tensor + * @param [in] axis descriptor of roi tensor + * @param [out] n dim of n + * @param [out] c dim of c + * @param [out] h dim of h + * @param [out] w dim of w + * @param [out] realDimCnt real dim count + * @return ccStatus_t + */ +ccStatus_t ccGetGatherOutputDim(const ccTensorDescriptor_t paramsDesc, const ccTensorDescriptor_t indicesDesc, + int32_t axis, int32_t *n, int32_t *c, int32_t *h, int32_t *w, int32_t *realDimCnt); + +/** + * @ingroup dnn + * @brief gather output dim computation + * @param [in] paramsDesc descriptor of params tensor + * @param [in] indicesDesc descriptor of indices tensor + * @param [in] axis descriptor of roi tensor + * @param [out] dimCnt dimcnt of output + * @param [out] dim dim of output + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetGatherOutputDim(const ccTensorDescriptor_t paramsDesc, const ccTensorDescriptor_t indicesDesc, + int32_t axis, int32_t *dimCnt, int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief exp forward computation + * @param [in] handle cce handle + * @param [in] expDesc descriptor of expParam + * @param [in] expParam a ternary array + * @param [in] alpha reserved parameter + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta reserved parameter + * @param [in] yDesc descriptor of output tensor + * @param [out] y output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccExpForward(ccHandle_t handle, const ccExpDescriptor_t expDesc, const void *expParam, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t yDesc, void *y); + +/** + * @ingroup dnn + * @brief expm1 forward: + * data type only support float float16 and double + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccExpm1Forward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief log1p forward: + * data type only support float float16 and double + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccLog1pForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief init descriptor for parameter of exp function + * @param [in|out] powDesc descriptor of tensor + * @param [in] dataType data type in device + * @param [in] paramCnt number of parameters + * @return ccStatus_t + */ +ccStatus_t ccSetExpDescriptor(ccExpDescriptor_t expDesc, ccDataType_t dataType, uint32_t paramCnt); + +/** + * @ingroup dnn + * @brief exp forward computation + * @param [in] handle cce handle + * @param [in] logDesc descriptor of logParam + * @param [in] logParam a ternary array + * @param [in] alpha reserved parameter + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta reserved parameter + * @param [in] yDesc descriptor of output tensor + * @param [in] y output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccLogForward(ccHandle_t handle, const ccLogDescriptor_t logDesc, const void *logParam, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t yDesc, void *y); + +/** + * @ingroup dnn + * @brief init descriptor for parameter of log function + * @param [in|out] logDesc descriptor of tensor + * @param [in] dataType data type in device + * @param [in] paramCnt number of parameters + * @return ccStatus_t + */ +ccStatus_t ccSetLogDescriptor(ccLogDescriptor_t logDesc, ccDataType_t dataType, uint32_t paramCnt); + +/** + * @ingroup dnn + * @brief pow forward computation + * @param [in] handle cce handle + * @param [in] powDesc descriptor of logParam + * @param [in] powParam a ternary array + * @param [in] alpha reserved parameter + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta reserved parameter + * @param [in] yDesc descriptor of input tensor + * @param [in] y input data in device memory + * @param [in] zDesc descriptor of output tensor + * @param [out] z output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccPowForward(ccHandle_t handle, const ccPowDescriptor_t powDesc, const void *powParam, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const ccTensorDescriptor_t yDesc, + const void *y, const void *beta, const ccTensorDescriptor_t zDesc, void *z); + +/** + * @brief init descriptor for parameter of pow function + * @param [in|out] powDesc descriptor of tensor + * @param [in] dataType data type in device + * @param [in] paramCnt number of parameters + * @return ccStatus_t + */ +ccStatus_t ccSetPowDescriptor(ccPowDescriptor_t powDesc, ccDataType_t dataType, uint32_t paramCnt); + +/** + * @ingroup dnn + * @brief non max suppression forward. + * @param [in] handle cce handle + * @param [in] nonmaxParaDesc descriptor of para + * @param [in] nonmaxPara input para in host memory + * @param [in] maxoutputsizex input para in host memory + * @param [in] alpha common scale factor + * @param [in] boxesDesc descriptor of input data boxesDesc + * @param [in] boxes input data boxes in device memory + * @param [in] scoresDesc descriptor of input data boxesDesc + * @param [in] scores input data scores in device memory + * @param [in] workSpaceSizeInBytes workspace size + * @param [in] workSpace input workspace in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccNonMaxSuppressionForward(ccHandle_t handle, const ccNonMaxSuppressionDescriptor_t nonmaxParaDesc, + const void *nonmaxPara, const int *maxoutputsize, const void *alpha, + const ccTensorDescriptor_t boxesDesc, const void *boxes, + const ccTensorDescriptor_t scoresDesc, const void *scores, + const uint32_t workSpaceSizeInBytes, void *workSpace, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); +/** + * @brief init descriptor for parameter of NonMaxSuppression function + * @param [in|out] powDesc descriptor of tensor + * @param [in] dataType data type in device + * @param [in] paramCnt number of parameters + * @return ccStatus_t + */ +ccStatus_t ccSetNonMaxSuppressionDescriptor(ccNonMaxSuppressionDescriptor_t nonMaxSuppressionDesc, + ccDataType_t dataType, uint32_t paramCnt); + +/** + * @ingroup dnn + * @brief get the output dimension info of resizeBilinear op. + * @param [in] xDesc descriptor of input data + * @param [in] resizeBilinearDesc descriptor of resize_bilinear operator + * @param [out] dimCnt + * @param [out] dim[] dim of output + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetResizeBilinearOutputDim(const ccTensorDescriptor_t xDesc, + const ccResizeBilinearDescriptor_t resizeBilinearDesc, int32_t *dimCnt, + int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief get the output dimension info of interp op. + * @param [in] xDesc descriptor of input data + * @param [in] resizeBilinearDesc descriptor of resize_bilinear operator + * @param [out] dimCnt + * @param [out] dim[] dim of output + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetInterpOutputDim(const ccTensorDescriptor_t xDesc, const ccResizeBilinearDescriptor_t resizeBilinearDesc, + int32_t *dimCnt, int32_t dim[], int32_t dimLen); +/** + * @ingroup dnn + * @brief resize bilinear forward for t network. + * @param [in] handle cce handle + * @param [in] resizeBilinearDesc descriptor of resize_bilinear operator + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] yDesc descriptor of output data + * @param [in|out] y output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccResizeBilinearForward(ccHandle_t handle, const ccResizeBilinearDescriptor_t resizeBilinearDesc, + const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief resize bilinear forward for c network. + * @param [in] handle cce handle + * @param [in] resizeBilinearDesc descriptor of resize_bilinear operator + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] yDesc descriptor of output data + * @param [in|out] y output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccInterpForward(ccHandle_t handle, const ccResizeBilinearDescriptor_t resizeBilinearDesc, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief create descriptor of ResizeBilinear + * @param [in|out] resizeBilinearDesc point to descriptor of resizeBilinear attr + * @return ccStatus_t + */ +ccStatus_t ccCreateResizeBilinearDescriptor(ccResizeBilinearDescriptor_t *resizeBilinearDesc); + +/** + * @ingroup dnn + * @brief destroy descriptor of Interp + * @param [in|out] resizeBilinearDesc point to descriptor of resizeBilinear attr + * @return ccStatus_t + */ +ccStatus_t ccDestroyResizeBilinearDescriptor(ccResizeBilinearDescriptor_t *resizeBilinearDesc); + +/** + * @ingroup dnn + * @brief set descriptor of resizeBilinear. + * @param [in|out] resizeBilinearDesc descriptor of resize_bilinear operator + * @param [in] resizeOutputDimMode way to decide output dimensions + * @param [in] alignCorners whether the centers of input and output are aligned + * @param [in] zoom_factor zoom factor + * @param [in] shrink_factor shrink factor + * @param [in] height height of output + * @param [in] width width of output + * @param [in] pad_begin padding at begin of input + * @param [in] pad_end padding at end of input + * @return ccStatus_t + */ +ccStatus_t ccSetResizeBilinearDescriptor(ccResizeBilinearDescriptor_t resizeBilinearDesc, + ccResizeOutputDimMode_t resizeOutputDimMode, bool alignCorners, + int32_t zoom_factor, int32_t shrink_factor, int32_t height, int32_t width, + int32_t pad_begin, int32_t pad_end); + +/** + * @ingroup dnn + * @brief fill forward computation + * @param [in] handle cce handle + * @param [in] fillParamDesc descriptor of fill parameter + * @param [in] alpha reserved + * @param [in] givenDesc descriptor of given tensor + * @param [in] givenData given data in device memory + * @param [in] workspace space for fill algorithm + * @param [in] workSpaceSizeInBytes space size in byte + * @param [in] beta reserved + * @param [in] outputDesc descriptor of output tensor + * @param [out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccFillForward(ccHandle_t handle, const ccFillParamDescriptor_t fillParamDesc, const void *alpha, + const ccTensorDescriptor_t givenDesc, const void *givenData, const void *workspace, + const uint32_t workSpaceSizeInBytes, const void *beta, const ccTensorDescriptor_t outputDesc, + void *output); + +/** + * @ingroup dnn + *[ccGetFillWorkspaceSize] + *@param fillType [fill type] + *@param givenDesc [given tensor descriptor] + *@param xDesc [input tensor descriptor] + *@param sizeInBytes [output size] + *@return ccStatus_t [status] + */ +ccStatus_t ccGetFillWorkspaceSize(const ccFillOpType_t fillType, const ccTensorDescriptor_t xDesc, + uint32_t *sizeInBytes); + +/** + *[ccCast] + *@param handle [cce handler] + *@param alpha [alpha] + *@param xDesc [tensor Description of tensor x] + *@param x [input tensor x] + *@param beta [beta + *@param yDesc [tensor Description of tensor y] + *@param y [output tensor y] + *@return ccStatus_t [status] + */ +ccStatus_t ccCast(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t yDesc, void *y); + +/** + * @ingroup dnn + * @brief round forward: + * data type only support float float16 and int32 + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccRoundForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief rint forward: + * data type only support float float16 + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccRintForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief sqrt forward: + * data type only support float float16 + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccSqrtForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + *[ccCast] + *@param filterSrcInfo [cce filtersrc descriptor] + *@param filterSrc [filterSrc address] + *@param filterDstInfo [cce filterdst descriptor] + *@param filterDst [filterdst address] + *@param group [group] + *@param ySizeInBytes [fraczfilter size] + *@param outputDataType [datatype] + *@return ccStatus_t [status] + */ +ccStatus_t ccTransGroupConvFilterInt8(ccFilterDescriptor_t filterSrcInfo, const void *filterSrc, + ccFilterDescriptor_t filterDstInfo, void *filterDst, uint32_t group, + uint32_t ySizeInBytes, ccDataType_t outputDataType); + +/** + *[ccGetConcatOutputDim] + *@param xDesc[] [input tensor descriptor] + *@param axis [concat axis] + *@param inputNum [input tensor numbers] + *@param dim[] [output dim] + *@param [in| dimlen length of dim + *@return ccStatus_t [status] + */ +ccStatus_t ccGetConcatOutputDim(const ccTensorDescriptor_t xDesc[], int32_t axis, int32_t inputNum, int32_t *dimCnt, + int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief get the output dimension info of reduce. + * @param [in] xDesc descriptor of input tensor + * @param [in] axis The dimensions to reduce + * @param [in] keepDims If true, retains reduced dimensions with length 1. + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetReduceOutputDim(const ccTensorDescriptor_t xDesc, const ccIntArray_t *axis, bool keepDims, + int32_t *dimCnt, int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief reduce sum forward computation + * @param [in] handle cce handle + * @param [in] axis The dimensions to reduce + * @param [in] keepDims If true, retains reduced dimensions with length 1. + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccReduceSumForward(ccHandle_t handle, const ccIntArray_t *axis, bool keepDims, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief reduce max forward computation + * @param [in] handle cce handle + * @param [in] axis The dimensions to reduce + * @param [in] keepDims If true, retains reduced dimensions with length 1. + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccReduceMaxForward(ccHandle_t handle, const ccIntArray_t *axis, bool keepDims, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief reduce min forward computation + * @param [in] handle cce handle + * @param [in] axis The dimensions to reduce + * @param [in] keepDims If true, retains reduced dimensions with length 1. + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccReduceMinForward(ccHandle_t handle, const ccIntArray_t *axis, bool keepDims, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief reduce mean forward computation + * @param [in] handle cce handle + * @param [in] axis The dimensions to reduce + * @param [in] keepDims If true, retains reduced dimensions with length 1. + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccReduceMeanForward(ccHandle_t handle, const ccIntArray_t *axis, bool keepDims, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief reduce prod forward computation + * @param [in] handle cce handle + * @param [in] axis The dimensions to reduce + * @param [in] keepDims If true, retains reduced dimensions with length 1. + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccReduceProdForward(ccHandle_t handle, const ccIntArray_t *axis, bool keepDims, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief reduce all forward computation + * @param [in] handle cce handle + * @param [in] axis The dimensions to reduce + * @param [in] keepDims If true, retains reduced dimensions with length 1. + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccReduceAllForward(ccHandle_t handle, const ccIntArray_t *axis, bool keepDims, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + *@brief print times stats + *@return ccStatus_t [status] + */ +ccStatus_t ccPrintTimeStat(); + +/** + * @ingroup dnn + * @brief reduce abs sum forward computation + * @param [in] handle cce handle + * @param [in] axis The dimensions to reduce + * @param [in] keepDims If true, retains reduced dimensions with length 1. + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccReduceAbsSumForward(ccHandle_t handle, const ccIntArray_t *axis, const bool keepDims, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief reduce square sum forward computation + * @param [in] handle cce handle + * @param [in] axis The dimensions to reduce + * @param [in] keepDims If true, retains reduced dimensions with length 1. + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccReduceSquareSumForward(ccHandle_t handle, const ccIntArray_t *axis, const bool keepDims, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get the output dimension info of crop and resize + * @param [in] imageDesc descriptor of images + * @param [in] boxesDesc descriptor of boxes + * @param [in] boxidxDesc descriptor of boxidx + * @param [in] resizeHeight resize height + * @param [in] resizeWidth resize width + * @param [out] dimCnt dimcnt of output + * @param [out] dim dim of output + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetCropAndResizeOutputDim(const ccTensorDescriptor_t imageDesc, const ccTensorDescriptor_t boxesDesc, + const ccTensorDescriptor_t boxidxDesc, const int32_t resizeHeight, + const int32_t resizeWidth, int32_t *dimCnt, int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief crop and resize forward. + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] imageDesc descriptor of images + * @param [in] image input data in device memory + * @param [in] boxesDesc descriptor of boxes + * @param [in] boxes input data in device memory + * @param [in] boxidxDesc descriptor of boxidx + * @param [in] boxidx input data in device memory + * @param [in] method enum of resize method + * @param [in] extrapolationValue Value used for extrapolation, when applicable + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccCropAndResizeForward(ccHandle_t handle, const ccResizeMethod_t method, const float extrapolationValue, + const void *alpha, const ccTensorDescriptor_t imageDesc, const void *image, + const ccTensorDescriptor_t boxesDesc, const void *boxes, + const ccTensorDescriptor_t boxidxDesc, const void *boxidx, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief select forward computation + * @param [in] handle cce handle + * @param [in] alpha reserved + * @param [in] condDesc descriptor of cond tensor + * @param [in] cond cond data in device memory + * @param [in] xDesc descriptor of x tensor + * @param [in] x x data in device memory + * @param [in] yDesc descriptor of y tensor + * @param [in] y y data in device memory + * @param [in] beta reserved + * @param [in] outputDesc descriptor of output tensor + * @param [out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccSelect(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t condDesc, const void *cond, + const ccTensorDescriptor_t xDesc, const void *x, const ccTensorDescriptor_t yDesc, const void *y, + const void *beta, const ccTensorDescriptor_t outDesc, void *out); + +/** + * @ingroup dnn + * @brief get the output dimension info of where + * @param [in] xDesc descriptor of input tensor + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @return ccStatus_t + */ +ccStatus_t ccGetWhereOutputDim(const ccTensorDescriptor_t xDesc, int32_t *dimCnt, int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief where forward computation + * @param [in] handle cce handle + * @param [in] alpha reserved + * @param [in] condDesc descriptor of cond tensor + * @param [in] cond cond data in device memory + * @param [in] xDesc descriptor of x tensor + * @param [in] x x data in device memory + * @param [in] yDesc descriptor of y tensor + * @param [out] y y data in device memory + * @return ccStatus_t + */ +ccStatus_t ccWhere(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t yDesc, void *y); + +/** + * @ingroup dnn + * @brief reverse forward. + * @param [in] handle cce handle + * @param [in] axis dim that need reverse + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccReverseForward(ccHandle_t handle, const ccIntArray_t *axis, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief floor forward: + * data type only support float float16 + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccFloorForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief ceil forward: + * data type only support float float16 + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccCeilForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get the output dimension info of truncate mod + * @param [in] xDesc descriptor of input tensor + * @param [in] yDesc descriptor of input tensor + * @param [out] dimCnt [dim count of the output tensor] + * @param [out] dim[] [shape of the output tensor] + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetTruncatemodOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, + int32_t *dimCnt, int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief truncate mod forward computation + * @param [in] handle cce handle + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] yDesc descriptor of input tensor + * @param [in] y input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccTruncatemodForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); +/** + * @ingroup dnn + * @brief Spatial Pyramid Pooling + * @param [in] handle cce handle + * @param [in] alpha reserved + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] workspace temp workspace + * @param [in] workspaceSizeInBytes temp workspace size + * @param [in] pyramidHeight pyramid height + * @param [in] poolingMode pooling mode + * @param [in] beta reserved + * @param [in] outputDesc descriptor of output tensor + * @param [out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccSPPForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + void *workspace, const uint32_t workspaceSizeInBytes, const uint32_t pyramidHeight, + const ccPoolingMode_t poolingMode, const void *beta, const ccTensorDescriptor_t outputDesc, + void *output); +/** + * @ingroup dnn + * @brief Get Spatial Pyramid Pooling output dim + * @param [in] xDesc descriptor of input tensor + * @param [in] pyramidHeight pyramid height + * @param [in] dimLen length of dim + * @param [out] dimCnt output tensor dim cnt + * @param [out] dim output tensor dim + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetSPPOutputDim(const ccTensorDescriptor_t xDesc, const uint32_t pyramidHeight, int32_t *dimCnt, + int32_t dim[], const int32_t dimLen); +/** + * @ingroup dnn + * @brief Get Spatial Pyramid Pooling workspace size + * @param [in] xDesc descriptor of input tensor + * @param [in] pyramidHeight pyramid height + * @param [out] workspaceSizeInBytes workspace size + * @return ccStatus_t + */ +ccStatus_t ccGetSPPWorkspaceSize(const ccTensorDescriptor_t xDesc, const uint32_t pyramidHeight, + uint32_t *workspaceSizeInBytes); + +/** + * @ingroup dnn + * @brief BNLL forward computation + * @param [in] handle cce handle + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccBNLLForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief bias forward. + * @param [in] handle cce handle + * @param [in] axis axis + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data x + * @param [in] x input data x in device memory + * @param [in] biasDesc descriptor of input data bias + * @param [in] bias input data bias in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccBiasForward(ccHandle_t handle, const int axis, const void *alpha, const ccTensorDescriptor_t xDesc, + const void *x, const ccTensorDescriptor_t biasDesc, const void *bias, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief threshold forward computation + * @param [in] handle cce handle + * @param [in] threshold threshold + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccThresholdForward(ccHandle_t handle, const void *threshold, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief shufflechannel forward. + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] group number of groups + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +// TODO AICPU: please add shufflechannel custom params and comment +ccStatus_t ccShuffleChannelForward(ccHandle_t handle, const void *alpha, uint32_t group, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief mvn forward. + * @param [in] handle cce handle + * @param [in] acrossChannel across channel. true: across, false: not + * @param [in] normalizeVariance normalizeVariance. true: normalizeVariance, false: not + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccMVNForward(ccHandle_t handle, bool acrossChannel, bool normalizeVariance, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, void *workSpace, uint32_t workSpaceSizeInBytes, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get the workspace size of mvn + * @param [in] xDesc descriptor of input data + * @param [in] acrossChannel across channel. true: across, false: not + * @param [in|out] sizeInBytes Workspace size need for whole computation + */ +ccStatus_t ccGetMVNWorkspaceSize(const ccTensorDescriptor_t xDesc, bool acrossChannel, uint32_t *sizeInBytes); + +/** + * @ingroup dnn + * @brief heatmap2coord forward output is hotspot value and corresponding coordinates + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] coordh calibration high + * @param [in] coordw calibration wide + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccHeatmap2coordForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + int32_t coordh, int32_t coordw, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); +/** + * @ingroup dnn + * @brief get the output dimension info of heatmap2coord + * @param [in] xDesc descriptor of input tensor + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetHeatmap2coordOutputDim(const ccTensorDescriptor_t xDesc, int32_t *dimCnt, int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief swish forward. + * @param [in] handle cce handle + * @param [in] scale param of swish function, y = x / (1 + sigmoid(scale * x)) + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ + +ccStatus_t ccSwishForward(ccHandle_t handle, const float scale, const void *alpha, const ccTensorDescriptor_t xDesc, + const void *x, const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +ccStatus_t ccTeForward(ccHandle_t handle, const void *stubFunc, uint32_t coreDim, const void *args, uint32_t argsSize, + const rtL2Ctrl_t *l2ctrl, int32_t inputNum, const ccTensorDescriptor_t xDesc[], const void *x[], + int32_t outputNum, const ccTensorDescriptor_t yDesc[], void *y[], bool isAiCore); + +#ifndef DAVINCI_LITE +ccStatus_t ccAiCpuCustomizeForward(ccHandle_t handle, aicpu_run_func stubFunc, opTensor_t *xOpDesc[], void *x[], + int32_t inputNum, opTensor_t *yOpDesc[], void *y[], void *op_attr_handle, + int32_t outputNum, const ccTensorDescriptor_t xDesc[], + const ccTensorDescriptor_t yDesc[], const void *op_attr_str, uint32_t op_attr_size); +#endif +/** + * @ingroup dnn + * @brief embedding lookup forward. + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data x + * @param [in] x input data x in device memory + * @param [in] idxDesc descriptor of input data idx + * @param [in] idx input data idx in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccEmbeddingLookupForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, + const void *x, const ccTensorDescriptor_t idxDesc, const void *idx, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup + * @brief embedding lookup forward. + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] inputNum inputNum + * @param [in] xDesc[] descriptor array of input data x + * @param [in] x[] input data x array in device memory + * @param [in] workSpace workSpace addr + * @param [in] workSpaceSizeInBytes workSpace size + * @param [in] idxDesc descriptor of input data idx + * @param [in] idx input data idx in device memory + * @param [in] partitionStrategy partitionStrategy + * @param [in] maxNorm addr of maxNorm + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccEmbeddingLookupForward(ccHandle_t handle, const void *alpha, const int32_t inputNum, + const ccTensorDescriptor_t xDesc[], const void *x[], void *workSpace, + const uint32_t workSpaceSizeInBytes, const ccTensorDescriptor_t idxDesc, + const void *idx, ccPartitionStrategy_t partitionStrategy, const void *maxNorm, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + *[ccGetEmbeddingLookupOutputDim] + *@param inputNum [input tensor numbers] + *@param xDesc[] [input tensor descriptor] + *@param idxDesc [idx tensor descriptor] + *@param dimCnt [output dim count] + *@param dim[] [output dim] + *@param [in| dimlen length of dim + *@return ccStatus_t [status] + */ +ccStatus_t ccGetEmbeddingLookupOutputDim(const int32_t inputNum, const ccTensorDescriptor_t xDesc[], + const ccTensorDescriptor_t idxDesc, int32_t *dimCnt, int32_t dim[], + int32_t dimLen); + +/** + * @ingroup dnn + *[ccGetEmbeddingLookupWorkspaceSize] + *@param inputNum [input tensor numbers] + *@param idxDesc [input tensor descriptor] + *@param isMaxNormExist [isMaxNormExist] + *@param sizeInBytes [output size] + *@return ccStatus_t [status] + */ +ccStatus_t ccGetEmbeddingLookupWorkspaceSize(const int32_t inputNum, const ccTensorDescriptor_t idxDesc, + const bool isMaxNormExist, uint32_t *sizeInBytes); + +/** + * @ingroup dnn + * @brief check if it is the first layer of resnet50 and semecefc + * @param [in] tensorDesc descriptor of input tensor. + * @param [in] convDesc conv descriptor. + * @param [in] filterDesc descriptor of weight tensor. + * @return ccStatus_t + */ +ccStatus_t c04DescParamCheck(const ccTensorDescriptor_t tensorDesc, const ccConvolutionDescriptor_t convDesc, + const ccFilterDescriptor_t filterDesc); + +#ifndef DAVINCI_LITE +/** + * @ingroup dnn + * @brief convolution forward computation + * @param [in] handle cce handle + * @param [in] convDesc descriptor of convolution operator + * @param [in] alpha scaling factors + * @param [in] beta scaling factors + * @param [in] xDesc x descriptor of input tensor + * @param [in] x x data in device memory + * @param [in] dyDesc descriptor of dy + * @param [in] dy dy data in device memory + * @param [in] dwDesc descriptor of dwDesc + * @param [out] dw dw data in device memory + * @param [in] algo algorithm of convolution forward + * @param [in] workSpace temp space, maybe NULL if no need temp space + * @param [in] workSpaceSizeInBytes sizeof workspace + * @return ccStatus_t + */ +ccStatus_t ccConvolutionBackwardFilter(ccHandle_t handle, const ccConvolutionDescriptor_t convDesc, void *alpha, + void *beta, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t dyDesc, const void *dy, + const ccFilterDescriptor_t dwDesc, void *dw, ccConvolutionBwdAlgo_t algo, + void *workSpace, uint32_t workSpaceSizeInBytes); +#endif + +/** + * @ingroup dnn + * @brief get the temp space size of convolution forward computation, maybe no need temp space + * @param [in] handle cce handle + * @param [in] dyDesc descriptor of input tensor dy + * @param [in] convDesc descriptor of convolution operator + * @param [in] xDesc descriptor of input tensor + * @param [in] dwDesc descriptor of filter + * @param [in] algo algorithm of convolution forward + * @param [in|out] sizeInBytes temp space size need for specified algorithm + * @return ccStatus_t + */ +ccStatus_t ccGetConvolutionBackwardFilterWorkspaceSize(ccHandle_t handle, const ccTensorDescriptor_t dyDesc, + const ccConvolutionDescriptor_t convDesc, + const ccTensorDescriptor_t xDesc, + const ccFilterDescriptor_t dwDesc, ccConvolutionBwdAlgo_t algo, + uint32_t *sizeInBytes); + +#ifndef DAVINCI_LITE +ccStatus_t ccBatchNormalizationBackward(ccHandle_t handle, ccBatchNormMode_t mode, const void *alphaDataDiff, + const void *betaDataDiff, const void *alphaParamDiff, const void *betaParamDiff, + const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t dyDesc, const void *dy, + const ccTensorDescriptor_t dxDesc, void *dx, + const ccTensorDescriptor_t bnScaleBiasDiffDesc, const void *bnScale, + void *resultBnScaleDiff, void *resultBnBiasDiff, const void *workSpace, + const uint32_t workSpaceSizeInBytes, double epsilon, const void *SaveMean, + const void *SaveInvVariance); +#endif + +ccStatus_t ccGetBatchNormalizationBackwardWorkspaceSize(ccHandle_t handle, ccBatchNormMode_t mode, + ccTensorDescriptor_t xDesc, ccTensorDescriptor_t dyDesc, + ccTensorDescriptor_t dxDesc, + ccTensorDescriptor_t bnScaleBiasDesc, uint32_t *sizeInBytes); + +#ifndef DAVINCI_LITE +ccStatus_t ccBatchNormalizationForwardTraining(ccHandle_t handle, ccBatchNormMode_t mode, const void *alpha, + const void *beta, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, void *y, + const ccTensorDescriptor_t bnScaleBiasMeanVarDesc, const void *bnScale, + const void *bnBias, double exponentialAverageFactor, + void *resultRunningMean, void *resultRunningVariance, void *workSpace, + uint32_t workSpaceSizeInBytes, double epsilon, void *resultSaveMean, + void *resultSaveInvVariance, const bool isTraining); +#endif + +ccStatus_t ccGetBatchNormalizationForwardTrainingWorkspaceSize(ccHandle_t handle, ccBatchNormMode_t mode, + ccTensorDescriptor_t xDesc, ccTensorDescriptor_t yDesc, + const ccTensorDescriptor_t bnScaleBiasMeanVarDesc, + uint32_t *sizeInBytes); + +/** + * @ingroup dnn + * @brief generate an random normal Tensor use given on/off scale. + * @param [in] handle Stream handle. + * @param [in] alpha reserved. + * @param [in] meanDesc Mean description of one-hot position. + * @param [in] mean Data pointer of mean. + * @param [in] scaleDesc On/off scale description. + * @param [in] scale Data pointer of on/off scale. + * @param [in] seed random seed used to generate random number + * @param [in] seed2 random seed used to generate random number + * @param [in] beta reserved. + * @param [in] outputDesc Description of the generated one-hot tensor. + * @param [output] output Data pointer of output. + * @return ccStatus_t + */ +ccStatus_t ccRandomNormalForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t meanDesc, + const void *mean, const ccTensorDescriptor_t scaleDesc, const void *scale, + const int64_t seed1, const int64_t seed2, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief generate random uniform tensor. + * @param [in] handle Stream handle. + * @param [in] alpha reserved. + * @param [in] minvalDesc Mean description of one-hot position. + * @param [in] minval Data pointer of mean. + * @param [in] maxvalDesc On/off scale description. + * @param [in] maxval Data pointer of on/off scale. + * @param [in] seed random seed used to generate random number + * @param [in] seed2 random seed used to generate random number + * @param [in] beta reserved. + * @param [in] outputDesc Description of the generated one-hot tensor. + * @param [output] output Data pointer of output. + * @return ccStatus_t + */ +ccStatus_t ccRandomUniformForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t minvalDesc, + const void *minval, const ccTensorDescriptor_t maxvalDesc, const void *maxval, + const int64_t seed1, const int64_t seed2, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/**^M + * @ingroup dnn^M\r 10932 + * @brief generate BatchMatMul tensor.^M\r 10933 + * @param [in] handle Stream handle.^M\r 10934 + * @param [in] alpha reserved.^M\r 10935 + * @param [in] xDesc tensorA Desc.^M\r 10936 + * @param [in] x Data pointer of tensorA.^M\r 10937 + * @param [in] yDesc tensorB Desc.^M\r 10938 + * @param [in] y Data pointer of tensorB.^M\r 10939 + * @param [in] beta reserved.^M\r 10940 + * @param [in] adj_x tensorA transpose flag^M\r 10941 + * @param [in] adj_y tensorB transpose flag^M\r 10942 + * @param [in] outpDesc Description of the tensor output .^M\r 10943 + * @param [output] out Data pointer of output.^M\r 10944 + * @return ccStatus_t^M + */ +ccStatus_t ccBatchMatMulForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, const bool adj_x, + const bool adj_y, const ccTensorDescriptor_t outDesc, void *out); + +ccStatus_t ccGetBatchMatMulOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, bool adj_x, + bool adj_y, int32_t *dimCnt, int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief generator conv int8 all offset factor + * @param [in] para the struct for scale and offset of input, filter and output + * @param [in|out] offsetW offset of filter + * @param [in|out] offsetPad offset of input + * @param [in|out] scaledQrq scale computing result of input , filter and output + * @param [in|out] nextoffsetq offset of output + * @return ccStatus_t + */ +ccStatus_t ccGenQuantAllOffsetFactor(const ccQuantAllOffsetPara_t *para, uint8_t &offsetW, uint8_t &offsetPad, + uint16_t &scaledQrq, uint16_t &nextoffsetq); + +/** + * @ingroup dnn + * @brief get conv int8 all offset fracZ size + * @param [in] filterDesc descriptor of filter tensor + * @param [in|out] conv int8 all offset fracZ size + * @param [in] groupNum group conv num + * @return ccStatus_t + */ +ccStatus_t ccSetGroupConvScene(const ccFilterDescriptor_t tensorDesc, ccConvolutionDescriptor_t convDesc); + +ccStatus_t ccGetInt8AllOffsetFilterFracZSizeInBytes(const ccFilterDescriptor_t filterSrcDesc, + const ccFilterDescriptor_t filterDesc, uint32_t &size, + uint32_t groupNum); + +/** + * @ingroup dnn + * @brief transform filter in conv int8 all offset scene + * @param [in] filterSrcInfo descriptor of filter tensor before fracZ transform + * @param [in] filterSrc filter addr before fracZ transform + * @param [in] filterDstInfo descriptor of filter tensor after fracZ transform + * @param [in] filterDst filter addr after fracZ transform + * @param [in] quantPara the struct for scale and offset of input, filter and output + * @param [in] ySizeInBytes filter size after fracZ transform + * @param [in|out] outputDataType output data type + * @param [in] groupNum group conv num + * @return ccStatus_t + */ +ccStatus_t ccTransFilterInt8AllOffset(ccFilterDescriptor_t filterSrcInfo, const void *filterSrc, + ccFilterDescriptor_t filterDstInfo, void *filterDst, + const ccQuantAllOffsetPara_t *quantPara, uint32_t ySizeInBytes, + ccDataType_t outputDataType, uint32_t groupNum); + +/** + * @ingroup dnn + * @brief transform bias in conv int8 all offset scene + * @param [in] filterDesc descriptor of filter tensor + * @param [in] biasDesc descriptor of bias tensor + * @param [in] quantPara the struct for scale and offset of input, filter and output + * @param [in] w filter addr + * @param [in] bias bias addr + * @return ccStatus_t + */ +ccStatus_t ccTransInt8AllOffsetBias(const ccFilterDescriptor_t filterDesc, const ccTensorDescriptor_t biasDesc, + const ccQuantAllOffsetPara_t *quantPara, const void *w, const void *bias); + +/** + * @ingroup dnn + * @get dequantize + * @param [in] handle handle id + * @param [in] alpha alpha addr + * @param [in] xDesc the input Desc descriptor + * @param [in] x x data addr + * @param [in] beta beta data addr + * @param [in] yDesc the output Desc descriptor + * @param [in] y y data addr + * @return ccStatus_t + */ +ccStatus_t ccDequantizeCoreForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, + const void *x, const void *beta, const ccTensorDescriptor_t yDesc, void *y); +/** + * @ingroup dnn + * @get quantize + * @param [in] handle handle id + * @param [in] alpha alpha addr + * @param [in] xDesc the input Desc descriptor + * @param [in] x x data addr + * @param [in] beta beta data addr + * @param [in] yDesc the output Desc descriptor + * @param [in] y y data addr + * @return ccStatus_t + */ +ccStatus_t ccQuantizeCoreForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t yDesc, void *y); + +#ifndef DAVINCI_LITE +ccStatus_t ccActivationBackward(ccHandle_t handle, const ccActivationDescriptor_t activationDesc, const void *alpha, + const ccTensorDescriptor_t dyDesc, const void *dy, const ccTensorDescriptor_t xDesc, + const void *x, const void *beta, const ccTensorDescriptor_t dxDesc, void *dx); +#endif + +ccStatus_t ccL2LossForward(ccHandle_t handle, const ccL2LossDescriptor_t l2lossDesc, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t yDesc, void *y); + +/** + * @ingroup dnn + * @brief get the output dimension info of top k v2 + * @param [in] xDesc descriptor of input tensor x + * @param [in] yDesc descriptor of input tensor y + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetTopKV2OutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t kDesc, const void *k, + const int64_t axis, int32_t *dimCnt, int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief top k v2 forward computation + * @param [in] handle cce handle + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor x + * @param [in] x input data x in device memory + * @param [in] yDesc descriptor of input tensor y + * @param [in] y input data y in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccTopKV2Forward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t kDesc, const void *k, const void *beta, const bool sorted, + const int64_t axis, void *workSpace, const uint32_t workSpaceSizeInBytes, + const ccTensorDescriptor_t outputValuesDesc, void *outputValues, + const ccTensorDescriptor_t outputIndicesDesc, void *outputIndices); + +/** + * @ingroup dnn + * @brief get the workspace size of top k v2 + * @param [in] xDesc descriptor of input tensor x + * @param [in] yDesc descriptor of input tensor y + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] sizeInBytes point to workspace size + * @return ccStatus_t + */ +ccStatus_t ccGetTopKV2ForwardWorkspaceSize(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t kDesc, + const ccTensorDescriptor_t indiceDesc, const void *k, const int64_t axis, + uint32_t *sizeInBytes); + +/** + * @ingroup dnn + * @brief Get unsorted segment reduction output dim + * @param [in] xDesc descriptor of input tensor + * @param [in] segmentIdsDesc descriptor of input segmentIds tensor + * @param [in] segmentsNum output slice num + * @param [out] dimCnt output tensor dim cnt + * @param [out] dim output tensor dim + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetUnsortedSegmentReductionOutputDim(const ccTensorDescriptor_t xDesc, + const ccTensorDescriptor_t segmentIdsDesc, int32_t segmentsNum, + int32_t *dimCnt, int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief reduce all forward computation + * @param [in] handle cce handle + * @param [in] segmentsNum output slice num + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] segmentIdsDesc descriptor of input segmentIds tensor + * @param [in] x input segmentIds data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccUnsortedSegmentSumForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, + const void *x, const ccTensorDescriptor_t segmentIdsDesc, const void *segmentIds, + const int32_t segmentsNum, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief reverse sequence forward computation + * @param [in] handle cce handle + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor x + * @param [in] x input data x in device memory + * @param [in] yDesc descriptor of input tensor y + * @param [in] y input data y in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccReverseSequenceForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t inputDesc, + const void *input, const ccTensorDescriptor_t seqLengthsDesc, + const void *seqLengths, int64_t seqAxis, int64_t batchAxis, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief realdiv between two tensors. + * @param [in] alpha reserved. + * @param [in] xDesc description of the left operator tensor. + * @param [in] x data point of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [in] y data point of the right operator tensor. + * @param [in] beta reserved. + * @param [in] outputDesc description of the output tensor. + * @param [output] output data point of the output tensor. + * @return ccStatus_t + */ + +ccStatus_t ccEqualForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get output shape of realdiv. + * @param [in] xDesc description of the left operator tensor. + * @param [in] yDesc description of the right operator tensor. + * @param [out] dimCnt output tensor dim cnt + * @param [out] dim output tensor dim + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetEqualOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t yDesc, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief invert permutation forward computation + * @param [in] handle cce handle + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccInvertPermutationForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, + const void *x, const void *beta, const ccTensorDescriptor_t outputDesc, + void *output); + +/** + * @ingroup dnn + * @brief get the workspace size of non max suppression + * @param [in] handle descriptor of handle + * @param [in] scoresDesc descriptor of input tensor scoresDesc + * @param [in] boxesDesc descriptor of input tensor boxesDesc + * @param [in|out] sizeInBytes point to workspace size + * @return ccStatus_t + */ +ccStatus_t ccGetNonMaxSuppressionWorkspaceSize(ccHandle_t handle, const ccTensorDescriptor_t scoresDesc, + const ccTensorDescriptor_t boxesDesc, uint32_t *sizeInBytes); + +/** + * @ingroup dnn + * @brief get the output dim of non max suppression + * @param [in] scoresDesc descriptor of input tensor scoresDesc + * @param [in] maxOutPutSize the max size of output + * @param [in|out] dimCnt point to the count of dim + * @param [in|out] dim[] the array of output dim + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetNonMaxSuppressionOutputDim(const ccTensorDescriptor_t scoresDesc, const int32_t maxOutPutSize, + int32_t *dimCnt, int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief multinomial forward. + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] numSamples number of independent samples to draw for each row slice + * @param [in] seed1 sed to create a random seed for the distribution + * @param [in] seed2 sed to create a random seed for the distribution + * @param [in] workSpace work space for inter access + * @param [in] workSpaceSizeInBytes work space size + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccMultinomialForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + int32_t numSamples, int64_t seed1, int64_t seed2, void *workSpace, + uint32_t workSpaceSizeInBytes, const void *beta, const ccTensorDescriptor_t outputDesc, + void *output); +/** + * @ingroup dnn + * @brief get output dim of generated one-hot tensor. + * @param [in] indicesDesc Indices description of one-hot position. + * @param [in] depth On/off value description. + * @param [in] axis Data pointer of on/off value. + * @param [output] dimCnt Description of the generated one-hot tensor. + * @param [output] dim Data pointer of output. + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetOneHotOutputDim(const ccTensorDescriptor_t indicesDesc, int32_t depth, int32_t axis, int32_t *dimCnt, + int32_t *dim, int32_t dimLen); + +/** + * @ingroup dnn + * @brief generate an one-hot Tensor use given on/off value. + * @param [in] handle Stream handle. + * @param [in] alpha reserved. + * @param [in] indicesDesc Indices description of one-hot position. + * @param [in] indices Data pointer of indices. + * @param [in] onDesc On value description. + * @param [in] on Data pointer of on value. + * @param [in] offDesc Off value description. + * @param [in] off Data pointer of off value. + * @param [in] depth On/off value description. + * @param [in] axis Data pointer of on/off value. + * @param [in] beta reserved. + * @param [in] outputDesc Description of the generated one-hot tensor. + * @param [output] output Data pointer of output. + * @return ccStatus_t + */ +ccStatus_t ccOneHotForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t indicesDesc, + const void *indices, const ccTensorDescriptor_t onDesc, const void *on, + const ccTensorDescriptor_t offDesc, const void *off, const int32_t depth, const int32_t axis, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); +/** + * @ingroup dnn + * @brief get the workspaceSize of multinomial + * @param [in] xDesc descriptor of input tensor + * @param [in] numSamples number sample + * @param [out] sizeInBytes wor space size of byte + * @return ccStatus_t + */ +ccStatus_t ccGetMultinomialWorkspaceSize(const ccTensorDescriptor_t xDesc, uint32_t *sizeInBytes); +/** + * @ingroup dnn + * @brief get the output dimension info of multinomial + * @param [in] xDesc descriptor of input tensor + * @param [in] numSample number of independent samples to draw for each row slice + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetMultinomialOutputDim(const ccTensorDescriptor_t xDesc, int32_t numSample, int32_t *dimCnt, + int32_t dim[], int32_t dimLen); +/** + * @ingroup dnn + * @brief get the output dimension info of BiasAddBackward + * @param [in] dyDesc descriptor of input tensor + * @param [in] out] n outputTensor [N]CHW + * @param [in|out] c outputTensor N[C]HW + * @param [in|out] h outputTensor NC[H]W + * @param [in|out] w outputTensor NCH[W] + * @return ccStatus_t + */ +ccStatus_t ccGetBiasAddBackwardOutputDim(const ccTensorDescriptor_t dyDesc, int32_t *n, int32_t *c, int32_t *h, + int32_t *w); + +/** + * @ingroup dnn + * @brief biasadd backward. + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] dyDesc descriptor of input data + * @param [in] dy input data in device memory + * @param [in] beta common scale factor + * @param [in] dbDesc descriptor of output data + * @param [in|out] db output data in device memory + * @return ccStatus_t + */ +#ifndef DAVINCI_LITE +ccStatus_t ccBiasAddBackward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t dyDesc, const void *dy, + const void *beta, const ccTensorDescriptor_t dbDesc, void *db); + +ccStatus_t ccMaxPoolWithArgmaxForward(ccHandle_t handle, const ccPoolingDescriptor_t poolingDesc, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t yDesc, void *y, const ccTensorDescriptor_t argMaskDesc, + void *argMask); +#endif + +ccStatus_t ccCreatePoolingMaskDescriptor(ccTensorDescriptor_t *poolingMaskDesc); + +ccStatus_t ccDestroyPoolingMaskDescriptor(ccTensorDescriptor_t *poolingMaskDesc); + +ccStatus_t ccSetPoolingMaskTensorDescriptor(ccTensorDescriptor_t poolingMaskDesc, ccTensorFormat_t format, + ccDataType_t dataType, int32_t n, int32_t c, int32_t h, int32_t w, + int32_t windowH, int32_t windowW); + +ccStatus_t ccGetPoolingMaskTensorSizeInBytes(ccTensorDescriptor_t poolingMaskDesc, uint32_t *size); + +/** + * @ingroup dnn + * @brief get the mask output dimension info of maxpooling training forward + * @param [in] pooling descriptor of convolution operator + * @param [in] xDesc descriptor of input tensor + * @param [in|out] n point to batch size + * @param [in|out] c point to channels + * @param [in|out] h point to height of feature map + * @param [in|out] w point to width of feature map + * @param [in|out] windowH point to height of window + * @param [in|out] windowW point to width of windowW + * @return ccStatus_t + */ +ccStatus_t ccGetPoolingMaskDim(const ccPoolingDescriptor_t poolingDesc, const ccTensorDescriptor_t xDesc, int32_t *n, + int32_t *c, int32_t *h, int32_t *w, int32_t *windowH, int32_t *windowW); + +#ifndef DAVINCI_LITE +ccStatus_t ccSoftmaxCrossEntropyLoss(ccHandle_t handle, ccSoftmaxAlgo_t algo, ccSoftmaxMode_t mode, + ccCrossEntropyMode_t ceMode, const void *alpha, const void *scale, + const ccTensorDescriptor_t logitsDesc, const void *logits, + const ccTensorDescriptor_t labelsDesc, const void *labels, const void *labelSmooth, + const void *beta, const ccTensorDescriptor_t lossDesc, void *loss); + +ccStatus_t ccSoftmaxCrossEntropyDx(ccHandle_t handle, ccSoftmaxAlgo_t algo, ccSoftmaxMode_t mode, + ccCrossEntropyMode_t ceMode, const void *alpha, const void *scale, + const ccTensorDescriptor_t logitsDesc, const void *logits, + const ccTensorDescriptor_t labelsDesc, const void *labels, const void *labelSmooth, + const void *beta, const ccTensorDescriptor_t dxDesc, void *dx); + +ccStatus_t ccAvgPoolingBackward(ccHandle_t handle, const ccPoolingDescriptor_t poolingDesc, const void *alpha, + const ccTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const ccTensorDescriptor_t dxDesc, const void *dx); + +ccStatus_t ccTrainingAssignOp(ccHandle_t handle, const ccAssignOpMode_t assignOpDesc, const void *alpha, + const void *beta, const ccTensorDescriptor_t aDesc, void *a, + const ccTensorDescriptor_t bDesc, const void *b); + +/** + * @ingroup dnn + * @brief momentum optimizer for variable update + * @param [in] handle cce handle + * @param [in] inputDesc descriptor of input tensor: gradient,accumulation,variable + * @param [in] gradient gradient input + * @param [in|out] accumulation accumulation input and updated output + * @param [in|out] variable variable input and updated output + * @param [in] algo indicate whether need FP16 output + * @param [in] momentum scaler to control accumulation + * @param [in] learningRate scaler + * @param [in] lossScaleReciprocal scaler + * @param [in] workSpace additional memory address + * @param [in] workSpaceSizeInBytes additional memory size + * @param [out] variableUpdatedFP16Desc descriptor of FP16 output tensor: variableUpdatedFP16 + * @param [out] variableUpdatedFP16 variableUpdatedFP16 + * @return ccStatus_t + */ +ccStatus_t ccApplyMomentum(ccHandle_t handle, const ccTensorDescriptor_t inputDesc, const void *gradient, + void *accumulation, void *variable, const ccMomentumAlgo_t algo, const void *momentum, + const void *learningRate, const void *lossScaleReciprocal, void *workSpace, + const uint32_t workSpaceSizeInBytes, const ccTensorDescriptor_t variableUpdatedFP16Desc, + void *variableUpdatedFP16); + +ccStatus_t ccSsdClassifyLossTrain(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t labelDesc, + const void *label, const ccTensorDescriptor_t greaterConstDesc, + const void *greaterConst, const ccTensorDescriptor_t subConstDesc, + const void *subConst, const ccTensorDescriptor_t sparseDesc, const void *sparse, + const void *beta, const ccTensorDescriptor_t castoutDesc, const void *castout, + const ccTensorDescriptor_t muloutDesc, const void *mulout); + +#endif + +/** + * @ingroup dnn + * @brief get the workspace size of applymomentum + * @param [in] inputDesc descriptor of input tensor + * @return ccStatus_t + */ +ccStatus_t ccGetApplyMomentumWorkspaceSize(const ccTensorDescriptor_t inputDesc, uint32_t *sizeInBytes); +#ifndef DAVINCI_LITE +ccStatus_t ccHwck2FracZ(ccHandle_t handle, const ccFilterDescriptor_t xDesc, const void *x, + const ccFilterDescriptor_t yDesc, void *y); + +ccStatus_t ccFracZ2Hwck(ccHandle_t handle, const ccFilterDescriptor_t xDesc, const void *x, + const ccFilterDescriptor_t yDesc, void *y); +ccStatus_t ccAddNForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const int32_t inputNum, + const void *x[], const void *beta, void *workSpace, uint32_t workSpaceSizeInBytes, + const ccTensorDescriptor_t yDesc, void *y); +#endif +ccStatus_t ccGetAddNForwardWorkspaceSize(ccHandle_t handle, const ccTensorDescriptor_t xDesc, const int32_t inputNum, + const ccTensorDescriptor_t yDesc, uint32_t *sizeInBytes); +ccStatus_t ccGetAddNForwardOutputDim(const ccTensorDescriptor_t xDesc, int32_t *dimCnt, int32_t *dim, int32_t dimLen); +ccStatus_t ccAddTrainForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t wDesc, const void *w, const void *beta, void *workSpace, + uint32_t workSpaceSizeInBytes, const ccTensorDescriptor_t yDesc, void *y); +ccStatus_t ccGetAddTrainForwardWorkspaceSize(ccHandle_t handle, const ccTensorDescriptor_t xDesc, + const ccTensorDescriptor_t wDesc, const ccTensorDescriptor_t yDesc, + uint32_t *sizeInBytes); +ccStatus_t ccGetAddTrainForwardOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t wDesc, + int32_t *dimCnt, int32_t dim[], int32_t dimLen); +ccStatus_t ccMulTrainForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t wDesc, const void *w, const void *beta, void *workSpace, + uint32_t workSpaceSizeInBytes, const ccTensorDescriptor_t yDesc, void *y); +ccStatus_t ccGetMulTrainForwardWorkspaceSize(ccHandle_t handle, const ccTensorDescriptor_t xDesc, + const ccTensorDescriptor_t wDesc, const ccTensorDescriptor_t yDesc, + uint32_t *sizeInBytes); +ccStatus_t ccGetMulTrainForwardOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t wDesc, + int32_t *dimCnt, int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief get workspace size + * @param [in] xDesc descriptor of input tensor + * @param [in|out] sizeInBytes workspace size + * @return ccStatus_t + */ +ccStatus_t ccGetRandomShuffleWorkspaceSize(const ccTensorDescriptor_t xDesc, uint32_t *sizeInBytes); + +/** + * @ingroup dnn + * @brief random shuffle forward computation + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] workspace temporary space + * @param [in] workspaceSizeInBytes temporary space size + * @param [in] seed random seed used to generate random number + * @param [in] seed2 random seed used to generate random number + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccRandomShuffleForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + void *workspace, const uint32_t workspaceSizeInBytes, const int64_t seed1, + const int64_t seed2, const void *beta, const ccTensorDescriptor_t outputDesc, + void *output); +/** + * @ingroup dnn + * @brief sin forward: + * data type only support float float16 double + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] input input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccSinForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *input, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief cos forward: + * data type only support float float16 double + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] input input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccCosForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *input, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief tan forward: + * data type only support float float16 double + * data format only support ND + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] input input data in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccTanForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *input, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief get the output dimension info of unstack + * @param [in] xDesc descriptor of input tensor + * @param [in] axis the axis to unstack along + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetUnstackOutputDim(const ccTensorDescriptor_t xDesc, int32_t axis, int32_t *dimCnt, int32_t dim[], + int32_t dimLen); + +/** + * @ingroup dnn + * @brief unstack forward. + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data + * @param [in] x input data in device memory + * @param [in] num the length of the dimension axis + * @param [in] axis the axis to unstack along + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ + +ccStatus_t ccUnstackForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + int32_t num, int32_t axis, const void *beta, const ccTensorDescriptor_t outputDesc, + void *output[]); + +ccStatus_t ccResizeNearestNeighborCpuForward(ccHandle_t handle, const ccResizeNearestNeighborDescriptor_t resizeDesc, + const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); +/** + * @ingroup dnn + * @brief get the output dimension info of resize nearest neighbor + * @param [in] resizeDesc descriptor of resize + * @param [in] xDesc descriptor of input tensor + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetResizeNearestNeighborOutputDim(const ccResizeNearestNeighborDescriptor_t resizeDesc, + const ccTensorDescriptor_t xDesc, int32_t *dimCnt, int32_t dim[], + int32_t dimLen); + +/** + * @ingroup dnn + * @brief create descriptor of ResizeNearestNeighbor + * @param [in|out] resizeDesc point to descriptor of ResizeNearestNeighbor attr + * @return ccStatus_t + */ +ccStatus_t ccCreateResizeNearestNeighborDescriptor(ccResizeNearestNeighborDescriptor_t *resizeDesc); + +/** + * @ingroup dnn + * @brief destroy descriptor of ResizeNearestNeighbor + * @param [in|out] resizeDesc point to descriptor of ResizeNearestNeighbor attr + * @return ccStatus_t + */ +ccStatus_t ccDestroyResizeNearestNeighborDescriptor(ccResizeNearestNeighborDescriptor_t *resizeDesc); + +/** + * @ingroup dnn + * @brief set descriptor of ResizeNearestNeighbor. + * @param [in|out] resizeDesc descriptor of resize nearest neighbor operator + * @param [in] alignCorners whether the centers of input and output are aligned + * @param [in] height height of output + * @param [in] width width of output + * @return ccStatus_t + */ +ccStatus_t ccSetResizeNearestNeighborDescriptor(ccResizeNearestNeighborDescriptor_t resizeDesc, bool alignCorners, + int32_t height, int32_t width); + +/** + * @ingroup dnn + * [ccGetPadV2OutputDim] + * @brief get the output dimension info of pad + * @param [in] xDesc descriptor of input tensor x + * @param [in] padDesc descriptor of input paddings + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetPadV2OutputDim(const ccTensorDescriptor_t xDesc, const ccPadV2Descriptor_t padDesc, int32_t *dimCnt, + int32_t dim[], int32_t dimLen); + +ccStatus_t ccPadV2CpuForward(ccHandle_t handle, const ccPadV2Descriptor_t padDesc, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief create descriptor of parameters for padv2 function + * @param [in] point to descriptor of parameters for padv2 function + * @return ccStatus_t + */ +ccStatus_t ccCreatePadV2Descriptor(ccPadV2Descriptor_t *padDesc); + +/** + * @ingroup dnn + * @brief destroy descriptor of parameters for padv2 function + * @param [in] point to descriptor of parameters for padv2 function + * @return ccStatus_t + */ +ccStatus_t ccDestroyPadV2Descriptor(ccPadV2Descriptor_t *padDesc); + +/** + * @brief init descriptor for parameter of padv2 function + * @param [in|out] padDesc descriptor of pad + * @param [in] padShapeCnt padshape count + * @param [in] padShapeLow padshape low + * @param [in] padShapeHigh padshape high + * @param [in] padMode pad mode + * @param [in] padValue pad value ptr + * @param [in] padValueType pad value data type + * @return ccStatus_t + */ +ccStatus_t ccSetPadV2Descriptor(ccPadV2Descriptor_t padDesc, const int32_t padShapeCnt, const int32_t padShapeLow[], + const int32_t padShapeHigh[], const ccPadMode_t padMode, const void *padValue, + const ccDataType_t padValueType); +/** + * @ingroup dnn + * @brief create descriptor of batchToSpace + * @param [in|out] batchToSpaceDesc point to descriptor of batchToSpace + * @return ccStatus_t + */ +ccStatus_t ccCreateBatchToSpaceDescriptor(ccBatchToSpaceDescriptor_t *batchToSpaceDesc); + +/** + * @ingroup dnn + * @brief set batchToSpaceDesc + * @param [in|out] batchToSpaceDesc descriptor of batchToSpace + * @param [in] blockShape blockShape of batchToSpace + * @param [in] crops crops of batchToSpace + * @param [in] blockShapeLength blockShapeLength of batchToSpace + * @return ccStatus_t + */ +ccStatus_t ccSetBatchToSpaceDescriptor(ccBatchToSpaceDescriptor_t paramsDesc, const int32_t *blockShape, + const int32_t *crops, const int32_t blockShapeLength); + +/** + * @ingroup dnn + * @brief get batchToSpaceDesc + * @param [in|out] batchToSpaceDesc descriptor of batchToSpace + * @param [in] blockShape blockShape of batchToSpace + * @param [in] crops crops of batchToSpace + * @param [in] blockShapeLength blockShapeLength of batchToSpace + * @return ccStatus_t + */ +ccStatus_t ccGetBatchToSpaceDescriptor(const ccBatchToSpaceDescriptor_t paramsDesc, int32_t *blockShape, int32_t *crops, + int32_t *blockShapeLength); + +/** + * @ingroup dnn + * @brief destroy descriptor of batchToSpace + * @param [in] *batchToSpaceDesc descriptor of batchToSpace + * @return ccStatus_t + */ +ccStatus_t ccDestroyBatchToSpaceDescriptor(ccBatchToSpaceDescriptor_t *batchToSpaceDesc); + +/** + * @ingroup dnn + * @brief get the output dimension info of batch to space + * @param [in] xDesc descriptor of input tensor + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in| dimlen length of dim + * @return ccStatus_t + */ + +ccStatus_t ccGetBatchToSpaceOutputDim(const ccTensorDescriptor_t xDesc, + const ccBatchToSpaceDescriptor_t batchToSpaceDesc, int32_t *dimCnt, int32_t dim[], + int32_t dimLen); + +/** + * @ingroup dnn + * @brief batch to space forward computation + * @param [in] handle cce handle + * @param [in] paramsDesc descriptor of input params + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ + +ccStatus_t ccBatchToSpaceForward(ccHandle_t handle, const ccBatchToSpaceDescriptor_t paramsDesc, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief create descriptor of spaceToBatch + * @param [in|out] spaceToBatchDesc point to descriptor of spaceToBatch + * @return ccStatus_t + */ +ccStatus_t ccCreateSpaceToBatchDescriptor(ccSpaceToBatchDescriptor_t *spaceToBatchDesc); + +/** + * @ingroup dnn + * @brief set spaceToBatchDesc + * @param [in|out] spaceToBatchDesc descriptor of spaceToBatch + * @param [in] blockShape blockShape of spaceToBatch + * @param [in] paddings paddings of spaceToBatch + * @param [in] blockShapeLength blockShapeLength of spaceToBatch + * @return ccStatus_t + */ +ccStatus_t ccSetSpaceToBatchDescriptor(ccSpaceToBatchDescriptor_t paramsDesc, const int32_t *blockShape, + const int32_t *paddings, const int32_t blockShapeLength); + +/** + * @ingroup dnn + * @brief get spaceToBatchDesc + * @param [in|out] spaceToBatchDesc descriptor of spaceToBatch + * @param [in] blockShape blockShape of spaceToBatch + * @param [in] paddings paddings of spaceToBatch + * @param [in] blockShapeLength blockShapeLength of spaceToBatch + * @return ccStatus_t + */ +ccStatus_t ccGetSpaceToBatchDescriptor(const ccSpaceToBatchDescriptor_t paramsDesc, int32_t *blockShape, + int32_t *paddings, int32_t *blockShapeLength); + +/** + * @ingroup dnn + * @brief destroy descriptor of spaceToBatch + * @param [in] *spaceToBatchDesc descriptor of spaceToBatch + * @return ccStatus_t + */ +ccStatus_t ccDestroySpaceToBatchDescriptor(ccSpaceToBatchDescriptor_t *spaceToBatchDesc); + +/** + * @ingroup dnn + * @brief get the output dimension info of space to batch + * @param [in] xDesc descriptor of input tensor + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in| dimlen length of dim + * @return ccStatus_t + */ + +ccStatus_t ccGetSpaceToBatchOutputDim(const ccTensorDescriptor_t xDesc, + const ccSpaceToBatchDescriptor_t spaceToBatchDesc, int32_t *dimCnt, int32_t dim[], + int32_t dimLen); + +/** + * @ingroup dnn + * @brief space to batch forward computation + * @param [in] handle cce handle + * @param [in] paramsDesc descriptor of input params + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ + +ccStatus_t ccSpaceToBatchForward(ccHandle_t handle, const ccSpaceToBatchDescriptor_t paramsDesc, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +ccStatus_t ccTransFilterDesc2TensorDesc(ccFilterDescriptor_t wDesc, ccTensorDescriptor_t tensorDesc); + +/* + * @brief get the output dimension info of extractImagePatches + * @param [in] xDesc descriptor of input tensor x + * @param [in] ksizes ksizes array + * @param [in] strides strides array + * @param [in] rates rates array + * @param [in] padding padding type + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @return ccStatus_t + */ +ccStatus_t ccGetExtractImagePatchesOutputDim(const ccTensorDescriptor_t xDesc, const ccIntArray_t *ksizes, + const ccIntArray_t *strides, const ccIntArray_t *rates, + const ccExtractImagePatchesPadType_t padding, int32_t *dimCnt, + int32_t dim[], const int32_t dimLen); + +/** + * @ingroup dnn + * @brief cum forward. + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data, dimCnt:1~8 + * @param [in] x input data in device memory + * @param [in] axisDesc scale factor, dimCnt:0 + * @param [in] axis which axis to cum calc, device memory + * @param [in] beta common scale factor + * @param [in] opType calc type, eg. sum, prod.... + * @param [in] exclusive cum flag, true or false + * @param [in] reverse cum flag, true or false + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccCumForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t axisDesc, const void *axis, const void *beta, const CumOpType opType, + const bool exclusive, const bool reverse, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @ingroup dnn + * @brief ExtractImagePatches forward. + * @param [in] handle cce handle + * @param [in] ksizes ksizes array + * @param [in] strides strides array + * @param [in] rates rates array + * @param [in] padding padding type + * @param [in] alpha common scale factor + * @param [in] xDesc descriptor of input data x + * @param [in] x input data x in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccExtractImagePatchesForward(ccHandle_t handle, const ccIntArray_t *ksizes, const ccIntArray_t *strides, + const ccIntArray_t *rates, const ccExtractImagePatchesPadType_t padding, + const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const void *beta, const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @brief get argmax output dim info + * @param [in] argDesc argmaxmin descriptor + * @param [in] xDesc descriptor of input tensor + * @param [in|out] dimCnt output dim count + * @param [in|out] dim output dim + * @param [in| dimlen length of dim + * @return ccStatus_t + */ +ccStatus_t ccGetArgMaxOutputDim(const ccArgmaxminDescriptor_t argDesc, const ccTensorDescriptor_t xDesc, + int32_t *dimCnt, int32_t dim[], int32_t dimLen); + +/** + * @ingroup dnn + * @brief argmax forward computation + * @param [in] handle cce handle + * @param [in] argDesc argmaxmin descriptor + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] workSpace workspace pointer + * @param [in] workSpaceSizeInBytes workspace size in bytes + * @param [in] beta bias factors + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccArgMaxForward(ccHandle_t handle, const ccArgmaxminDescriptor_t argDesc, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, void *workSpace, + const uint32_t workSpaceSizeInBytes, const void *beta, const ccTensorDescriptor_t outputDesc, + void *output); + +/** + * @ingroup dnn + * @brief get the output dimension info of argmaxmin + * @param [in] argDesc descriptor of tagCcArgmaxmin + * @param [in] xDesc descriptor of input tensor + * @param [in|out] sizeInBytes workspace size + * @return ccStatus_t + */ +ccStatus_t ccGetArgMaxWorkspaceSize(const ccArgmaxminDescriptor_t argDesc, const ccTensorDescriptor_t xDesc, + uint32_t *sizeInBytes); + +/** + * @ingroup dnn + * @brief create descriptor of Argmaxmin + * @param [in|out] resizeDesc point to descriptor of Argmaxmin attr + * @return ccStatus_t + */ +ccStatus_t ccCreateArgmaxminDescriptor(ccArgmaxminDescriptor_t *argDesc); + +/** + * @ingroup dnn + * @brief destroy descriptor of Interp + * @param [in|out] resizeDesc point to descriptor of Argmaxmin attr + * @return ccStatus_t + */ +ccStatus_t ccDestroyArgmaxminDescriptor(ccArgmaxminDescriptor_t *argDesc); + +/** + * @ingroup dnn + * @brief destroy descriptor of Interp + * @param [in|out] argDesc descriptor of tagCcArgmaxmin + * @param [in] axisType + * @param [in] outMaxVal whether to return the maximum value + * @param [in] topK number that returns the maximum index or maximum value + * @param [in] axis Describes which axis of the input Tensor to reduce across + * @param [in] keepDims whether to keep reduced dim + * @param [in] reduceSize the num of elements to be reduce to get topK elements, reduceSize=-1 means the total num + * of elements in axis dimension + * @param [in] reduceStride the stride for reduce operation, reduceStride=1 means the layout of target data is + * continuous + * @return ccStatus_t + */ +ccStatus_t ccSetArgmaxminDescriptor(ccArgmaxminDescriptor_t argDesc, int32_t axisType, bool outMaxVal, int64_t topK, + int64_t axis, bool keepDims, int64_t reduceSize = -1, int64_t reduceDStride = 1); + +ccStatus_t ccArgMinForward(ccHandle_t handle, const ccArgmaxminDescriptor_t argDesc, const void *alpha, + const ccTensorDescriptor_t xDesc, const void *x, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +ccStatus_t ccGetArgMinOutputDim(const ccArgmaxminDescriptor_t argDesc, const ccTensorDescriptor_t xDesc, + int32_t *dimCnt, int32_t dim[], const int32_t dimLen); +/** + * @ingroup dnn + * @brief lsh projection forward computation + * @param [in] handle cce handle + * @param [in] alpha scaling factors + * @param [in] hashDesc descriptor of input tensor hashDesc + * @param [in] hash input data hash in device memory + * @param [in] weightDesc descriptor of input tensor weightDesc + * @param [in] weight input data weight in device memory + * @param [in] inputDesc descriptor of input tensor inputDesc + * @param [in] lookup input data lookup in device memory + * @param [in] type 1:SPARSE 2.DENSE + * @param [in] beta bias factors + * @param [in] workSpace workSpace data in device memory + * @param [in] workSpaceSizeInBytes workSpace length + * @param [in] outputDesc descriptor of output tensor + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccLshProjectionForward(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t hashDesc, + const void *hash, const ccTensorDescriptor_t weightDesc, const void *weight, + const ccTensorDescriptor_t inputDesc, const void *input, const LSHProjectionType type, + const void *beta, void *workSpace, const uint32_t workSpaceSizeInBytes, + const ccTensorDescriptor_t outputDesc, void *output); +/** + * @ingroup dnn + * @brief get the workspace size of lsh projection + * @param [in] inputDesc descriptor of input tensor input + * @param [in] hashDataType data type of hash + * @param [in|out] sizeInBytes workspace size + * @return ccStatus_t + */ +ccStatus_t ccGetLshProjectionForwardWorkspaceSize(const ccTensorDescriptor_t inputDesc, const ccDataType_t hashDataType, + uint32_t *sizeInBytes); +/** + * @ingroup dnn + * @brief get the output dimension info of LshProjection, + * @param [in] hashDesc descriptor of hash + * @param [in] type type of mode + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in] dimLen dim length + * @return ccStatus_t + */ +ccStatus_t ccGetLshProjectionOutputDim(const ccTensorDescriptor_t hashDesc, const LSHProjectionType type, + int32_t *dimCnt, int32_t dim[], const int32_t dimLen); +/** + * @ingroup dnn + * @brief get the weight dimension info of LshProjection, + * @param [in] inputDesc descriptor of input + * @param [in|out] dimCnt point to the weight dimCnt + * @param [in|out] dim arrays to save dims + * @param [in] dimLen dim length + * @return ccStatus_t + */ +ccStatus_t ccGetLshProjectionWeightDim(const ccTensorDescriptor_t inputDesc, int32_t *dimCnt, int32_t dim[], + const int32_t dimLen); + +/** + * @ingroup dnn + * @brief init descriptor for parameter of upsample function + * @param [in] handle cce handle + * @param [in] upsamplePara input para in host memory + * @param [in] alpha common scale factor + * @param [in] bottomDesc descriptor of input data bottomDesc + * @param [in] bottom input data bottom in device memory + * @param [in] bottomMaskDesc descriptor of input data bottomMaskDesc + * @param [in] bottomMask input data bottomMask in device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor of output data + * @param [in|out] output output data in device memory + * @return ccStatus_t + */ +ccStatus_t ccUpsampleForward(ccHandle_t handle, const ccUpsampleParaDescriptor_t upsamplePara, const void *alpha, + const ccTensorDescriptor_t bottomDesc, const void *bottom, + const ccTensorDescriptor_t bottomMaskDesc, const void *bottomMask, const void *beta, + const ccTensorDescriptor_t outputDesc, void *output); + +/** + * @brief creat descriptor for parameter of usample function + * @param [in|out] upsampleDesc descriptor of upsamplepara + * @return ccStatus_t + */ +ccStatus_t ccCreateUpsampleDescriptor(ccUpsampleParaDescriptor_t *upsampleDesc); + +/** + * @brief destroy descriptor for parameter of upsample function + * @param [in|out] upsampleDesc descriptor of upsamplepara + * @return ccStatus_t + */ +ccStatus_t ccDestroyUpsampleDescriptor(ccUpsampleParaDescriptor_t *upsampleDesc); + +/** + * @brief set descriptor for parameter of upsample function + * @param [in|out] upsampleDesc descriptor of upsamplepara + * @param [in] scale the scale of height and width + * @param [in] scaleHeight the scale of height + * @param [in] scaleWidth the scale of Width + * @param [in] upsampleHeight the height of output + * @param [in] upsampleWidth the width of output + * @param [in] padOutHeight pad value height + * @param [in] padOutWidth pad value width + * @return ccStatus_t + */ +ccStatus_t ccSetUpsampleDescriptor(ccUpsampleParaDescriptor_t upsampleDesc, const int32_t scale, + const int32_t scaleHeight, const int32_t scaleWidth, const int32_t upsampleHeight, + const int32_t upsampleWidth, const bool padOutHeight, const bool padOutWidth); +/** + * @ingroup dnn + * @brief get the output dimension info of upsample + * @param [in] upsamplePara para of upsample + * @param [in] bottomDesc descriptor of input bottom tensor + * @param [in|out] dimCnt point to the output dimCnt + * @param [in|out] dim arrays to save dims + * @param [in] dimLen the len of dim array + * @return ccStatus_t + */ +ccStatus_t ccGetUpsampleOutputDim(const ccUpsampleParaDescriptor_t upsamplePara, const ccTensorDescriptor_t bottomDesc, + int32_t *dimCnt, int32_t dim[], const int32_t dimLen); + +#ifndef DAVINCI_LITE +ccStatus_t ccMatmul(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t wDesc, const void *w, const ccTensorDescriptor_t biasDesc, + const void *bias, const ccFullConnectFwdAlgo_t algo, void *workSpace, + const uint32_t workSpaceSizeInBytes, const void *beta, const ccTensorDescriptor_t yDesc, void *y, + const bool transposeA, const bool transposeB); +ccStatus_t ccGetMatmulOutputDim(const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t wDesc, int32_t *n, + int32_t *c, int32_t *h, int32_t *w, bool transposeA, bool transposeB); +ccStatus_t ccGetMatmulWorkspaceSize(ccHandle_t handle, const ccFullConnectFwdAlgo_t algo, + const ccTensorDescriptor_t xDesc, const ccTensorDescriptor_t wDesc, + const ccTensorDescriptor_t yDesc, uint32_t *sizeInBytes, bool transposeA, + bool transposeB); +#endif + +/** + * @ingroup dnn + * @brief gather_v2 function + * @param [in] handle cce handle + * @param [in] alpha common scale factor + * @param [in] paramsDesc descriptor + * @param [in] params device memory + * @param [in] indicesDesc descriptor + * @param [in] indices device memory + * @param [in] axisDesc descriptor + * @param [in] axis device memory + * @param [in] beta common scale factor + * @param [in] outputDesc descriptor + * @param [in|out] output device memory + * @return ccStatus_t + */ +ccStatus_t ccGatherV2(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t paramsDesc, const void *params, + const ccTensorDescriptor_t indicesDesc, const void *indices, const ccTensorDescriptor_t axisDesc, + const void *axis, const void *beta, const ccTensorDescriptor_t outputDesc, const void *output); + +/** + * @ingroup dnn + * @brief memory_clear function + * @param [in] handle cce handle + * @param [in] addrSpaceSizeInBytes addr space size + * @param [in|out] addr device memory + * @return ccStatus_t + */ +ccStatus_t ccMemoryClear(ccHandle_t handle, const uint64_t addrSpaceSizeInBytes, const void *addr); + +/** + * @ingroup dnn + * @brief check input is overflow + * @param [in] handle cce handle + * @param [in] alpha scaling factors + * @param [in] xDesc descriptor of input tensor + * @param [in] x input data in device memory + * @param [in] yDesc descriptor of output tensor + * @param [in|out] y output data in device memory + * @param [in] beta scaling factors + * @return ccStatus_t + */ +ccStatus_t ccIsFinite(ccHandle_t handle, const void *alpha, const ccTensorDescriptor_t xDesc, const void *x, + const ccTensorDescriptor_t yDesc, const void *y, const void *beta); +}; // namespace cce + +#endif // DNN_OP_H__ diff --git a/metadef/third_party/fwkacllib/inc/cce/dnn_struct.hpp b/metadef/third_party/fwkacllib/inc/cce/dnn_struct.hpp new file mode 100644 index 00000000..96566074 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/dnn_struct.hpp @@ -0,0 +1,23 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DNN_STRUCT_HPP__ +#define DNN_STRUCT_HPP__ + +#include "dnn.h" +#include "dnn_struct_base.hpp" + +#endif // DNN_STRUCT_HPP__ diff --git a/metadef/third_party/fwkacllib/inc/cce/dnn_struct_base.hpp b/metadef/third_party/fwkacllib/inc/cce/dnn_struct_base.hpp new file mode 100644 index 00000000..dd75e9ea --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/dnn_struct_base.hpp @@ -0,0 +1,894 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DNN_STRUCT_BASE_HPP__ +#define DNN_STRUCT_BASE_HPP__ + +#include "cce/cce_def.hpp" + +namespace cce { + +/** + * @ingroup dnn + * @brief max number of dimensions + */ +#define CC_DIM_MAX (8) + +/** + * @ingroup dnn + * @brief max number of dimensions when use NC1HWC0 format + */ +#define CC_REALDIM_MAX (4) + +/** + * @ingroup dnn + * @brief max input count of MscnnBoxOutput + */ +#define CC_MAX_INPUT_CNT (10) + +/** + * @ingroup dnn + * @brief image dimensions of aipp input + */ +#define CC_AIPP_IMG_DIM (2) + +/** + * @ingroup dnn + * @brief image channel number of aipp input + */ +#define CC_AIPP_IMG_CHN_NUM (4) + +/** + * @ingroup dnn + * @brief element number of aipp color space convertion matrix + */ +#define CC_AIPP_CSC_MATRIX_DIM (9) + +/** + * @ingroup dnn + * @brief element number of aipp color space convertion bias + */ +#define CC_AIPP_CSC_BIAS_DIM (3) + +/** + * @ingroup dnn + * @brief parameter number of op exp/log/pow + */ +#define PARAM_CNT_THREE (3) + +/** + * @ingroup dnn + * @brief parameter number of op nonmaxsuppression + */ +#define PARAM_CNT_TWO (2) +#define DIMCNT_NUMBER_ONE (1) +#define DIMCNT_NUMBER_TWO (2) +#define DIMCNT_NUMBER_FOUR (4) + +#define COMMON_FORMAT_NCHW_N_INDEX (0) +#define COMMON_FORMAT_NCHW_C_INDEX (1) +#define COMMON_FORMAT_NCHW_H_INDEX (2) +#define COMMON_FORMAT_NCHW_W_INDEX (3) + +/** + * @ingroup dnn + * @brief parameter number of op upsample + */ +#define UPSAMPLE_SCAL_DEFAULT_TWO (2) +#define UPSAMPLE_ILLEGAL_VALUE_1 (1) + +/** + * @ingroup dnn + * @brief struct define of StridedSlice required params. + */ + +typedef struct tagCcStridedSlice { + uint32_t dimCnt; + int32_t begin[CC_DIM_MAX]; + int32_t end[CC_DIM_MAX]; + int32_t strides[CC_DIM_MAX]; +} ccStridedSlice_t; + +/** + * @ingroup dnn + * @brief struct define of Strided_slice attrs + */ +typedef struct tagCcStridedSliceAttrs { + uint32_t beginMask; + uint32_t endMask; + uint32_t ellipsisMask; + uint32_t newAxisMask; + uint32_t shrinkAxisMask; +} ccStridedSliceAttrs_t; + +/** + * @ingroup dnn + * @brief params of batchToSpace + */ +typedef struct tagCcBatchToSpace { + int32_t blockShapeLength; + int32_t blockShape[CC_DIM_MAX]; + int32_t crops[2 * CC_DIM_MAX]; +} ccBatchToSpace_t; + +/** + * @ingroup dnn + * @brief params of spaceToBatch + */ +typedef struct tagCcSpaceToBatch { + int32_t blockShapeLength; + int32_t blockShape[CC_DIM_MAX]; + int32_t paddings[2 * CC_DIM_MAX]; +} ccSpaceToBatch_t; + +/** + * @ingroup dnn + * @brief struct define of tensor + */ +typedef struct tagCcTensor { + ccTensorFormat_t format; + ccDataType_t dataType; + int32_t dimCnt; + int32_t realDimCnt; + uint32_t dataSize; + int32_t dim[CC_DIM_MAX]; + int32_t stride[CC_DIM_MAX]; + ccVecQuantizePara_t vecQuantizePara; +} ccTensor_t; + +/** + * @ingroup dnn + * @brief struct define of filter tensor + */ +typedef struct tagCcFilter { + ccTensorFormat_t format; + ccDataType_t dataType; + int32_t dimCnt; + uint32_t dataSize; + int32_t dim[CC_DIM_MAX]; +} ccFilter_t; + +/** + * @ingroup dnn + * @brief struct define of convolution operator + */ +typedef struct tagCcConvolution { + ccConvolutionMode_t mode; + ccPaddingMode_t padMode; + int32_t dimCnt; + int32_t padding[2 * (CC_DIM_MAX - 2)]; + int32_t filterStride[CC_DIM_MAX - 2]; + int32_t dilation[CC_DIM_MAX - 2]; + int32_t group; + ccQuantizeDescriptor_t quantInfo; + ccConvolutionAipp_t aippInfo; + int32_t adj[CC_DIM_MAX - 2]; + int32_t targetShape[CC_DIM_MAX - 2]; + int32_t beforePadding[2 * (CC_DIM_MAX - 2)]; // pad before conv + uint32_t reluFlag; + int64_t concatBatchSize; +} ccConvolution_t; + +#define ccCorrelation_t ccConvolution_t +typedef struct tagCcFullConnection_t { + ccQuantizeDescriptor_t quantInfo; + uint32_t infoTabSize; + const void *infoTab; + bool reluFlag; + ccFullConnectFwdAlgo_t algo; +} ccFullConnection_t; + +typedef struct tagCcConcatFour2Five_t { + uint32_t branchNum; // how many branch for box or class + uint32_t classNum; // box branch's classNum is four, class branch's classNum is class number +} ccConcatFour2Five_t; + +typedef struct tagCcTransdata_t { + uint64_t scaleQAddr; + uint8_t scaleQValueMode; + uint64_t offsetQAddr; + uint8_t quantAlgo; + uint8_t quantize8bitFlag; +} ccTransdata_t; +/** + * @ingroup dnn + * @brief struct define of pooling operator + */ +typedef struct tagCcPooling { + ccPoolingMode_t mode; + ccPaddingMode_t padMode; + ccNanPropagation_t maxpoolingNanOpt; + int32_t dimCnt; + int32_t windowDim[CC_DIM_MAX - 2]; + int32_t padding[CC_DIM_MAX - 2]; + int32_t stride[CC_DIM_MAX - 2]; + int32_t dataMode; + int32_t ceilMode; + ccQuantizeDescriptor_t quantInfo; + ccPooingFwdAlgo_t algo; +} ccPooling_t; + +/** + * @ingroup dnn + * @brief struct define of activation operator + */ +typedef struct tagCcActivation { + ccActivationMode_t mode; + ccNanPropagation_t reluNanOpt; + double coef; /* ceiling for clipped RELU, alpha for ELU */ + ccActivationPara_u activationPara; +} ccActivation_t; + +/** + * @ingroup dnn + * @brief struct define of svdf operator + */ +typedef struct tagCcSvdf { + ccTensorFormat_t format; + ccDataType_t dataType; + uint32_t batches; + uint32_t features; + uint32_t rank; + uint32_t inputSize; + uint32_t memorySize; +} ccSvdf_t; + +/** + * @ingroup dnn + * @brief struct define of svdf operator + */ +typedef struct tagCcHashTableLookup { + ccTensorFormat_t format; + ccDataType_t lookupType; + ccDataType_t keyType; + ccDataType_t valueType; + ccDataType_t outputType; + ccDataType_t hitsType; + uint32_t lookups; + uint32_t keys; + uint32_t rows; + uint32_t features; + uint16_t valueScale; + uint16_t outputScale; + uint16_t valueOffset; + uint16_t outputOffset; +} ccHashTableLookup_t; + +/** + * @ingroup dnn + * @brief struct define of prelu operator + */ +typedef struct tagCcPRelu { + ccNanPropagation_t reluNanOpt; + int32_t slopeCount; + bool channelShared; +} ccPRelu_t; + +/** + * @ingroup dnn + * @brief struct define of crop operator + */ +typedef struct tagCcCrop { + int32_t startAxis; + int32_t offset[CC_DIM_MAX]; + int32_t offsetCnt; +} ccCrop_t; + +/** + * @ingroup dnn + * @brief struct define of SpatialTransformer operator + */ +typedef struct tagCcSpatialTransformer { + ccSamplerType_t samplerType; + ccDataType_t dataType; + int32_t dimCnt; + uint64_t dim[CC_DIM_MAX]; + uint64_t alignCorner; +} ccSpatialTransformer_t; + +/** + * @ingroup dnn + * @brief struct define of ShiftTransformer operator + */ +typedef struct tagCcShiftTransformer { + ccSamplerType_t samplerType; + double xPreDefined; + double yPreDefined; + bool xShift; + bool yShift; + int32_t gridH; + int32_t gridW; +} ccShiftTransformer_t; + +/** + * @ingroup dnn + * @brief struct define of FasterRcnnProposal operator + */ +typedef struct tagCcFasterRcnnProposal { + int32_t preNMStopK; + int32_t postNMStopK; + float nmsTresh; + float minSize; + float featStride; + float baseSize; + int32_t ratioCnt; + int32_t scaleCnt; + float *ratio; + float *scale; + int32_t imgH; + int32_t imgW; +} ccFasterRcnnProposal_t; + +/** + * @ingroup dnn + * @brief struct define of LRN operator + */ +typedef struct tagCcLRN { + ccLRNMode_t lrnMode; + int32_t lrnN; + double lrnAlpha; + double lrnBeta; + double lrnK; +} ccLRN_t; + +/** + * @ingroup dnn + * @brief struct define of instanceNorm + */ +typedef struct tagCcInstancenorm { + ccInstanceNormMode_t mode; + double epsilon; +} ccInstancenorm_t; + +/** + * @ingroup dnn + * @brief struct define of assignOp operator + */ +typedef struct tagCcAssignOp { + ccAssignOpMode_t assignOpMode; +} ccAssignOp_t; + +/** + * @ingroup dnn + * @brief struct define of arcSinCos operator + */ +typedef struct tagCcArcSinCos { + ccArcSinCosMode_t arcSinCosMode; +} ccArcSinCos_t; + +/** + * @ingroup dnn + * @brief struct define of Detectpostprocess operator + */ +typedef struct tagCcDetectpostprocess { + int32_t numClasses; + float confThreshold; + float nmsThreshold; + int32_t outTopK; + float bboxRegWeightsDx; + float bboxRegWeightsDy; + float bboxRegWeightsDw; + float bboxRegWeightsDh; +} ccDetectpostprocess_t; +/** + * @ingroup dnn + * @brief struct define of FasterRcnnDetectionOutput operator + */ +typedef struct tagCcFasterRcnnDetectionOutput { + int32_t numClasses; + float nmsThreshold; + float postConfThreshold; + int32_t imgH; + int32_t imgW; + int32_t batchSize; +} ccFasterRcnnDetectionOutput_t; + +/** + * @ingroup dnn + * @brief struct define of SsdDetectionOutput operator + */ +typedef struct tagCcSsdDetectionOutput { + int32_t numClasses; + int32_t backgroundLabelId; + double preConfThreshold; + int32_t preTopK; + double nmsThreshold; + double nmsEta; + ccBoxCodeType_t codeType; + int32_t outTopK; + bool shareLocation; + bool varianceEncodedInTarget; + uint32_t boxTypeNum; + float var[4]; + uint32_t variance_num; +} ccSsdDetectionOutput_t; + +/** + * @ingroup dnn + * @brief struct define of RefinedetDetectionOutput operator + */ +typedef struct tagCcRefinedetDetectionOutput { + int32_t numClasses; + int32_t backgroundLabelId; + double preConfThreshold; + int32_t preTopK; + double nmsThreshold; + double nmsEta; + ccBoxCodeType_t codeType; + int32_t outTopK; + bool shareLocation; + bool varianceEncodedInTarget; + uint32_t boxTypeNum; + float var[4]; + uint32_t variance_num; + double objectness_score; +} ccRefinedetDetectionOutput_t; + +/** + * @ingroup dnn + * @brief struct define of MsrGenerateRpnProposals operator + */ +typedef struct tagCcMsrGenerateRpnProposals { + int32_t preNmsTopK; + int32_t postNmsTopK; + float nmsThreshold; + float rpnMiniSize; + int32_t imgH; + int32_t imgW; + uint32_t boxTypeNum; + float scoreThreshold; +} ccMsrGenerateRpnProposals_t; + +/** + * @ingroup dnn + * @brief struct define of RetinaPostprocessor operator + */ +typedef struct tagCcRetinaPostprocessor { + int32_t numClasses; + int32_t maxDetections; + float nmsThreshold; + float scoreThreshold; + int32_t imgH; + int32_t imgW; + uint32_t boxTypeNum; + float mean[4]; + int32_t meanNum; + float std[4]; + int32_t stdNum; + int32_t outputNum; + bool ocrFlag; +} ccRetinaPostprocessor_t; + +/** + * @ingroup dnn + * @brief struct define of GenerateSsdAnchors operator + */ +typedef struct tagCcGenerateSsdAnchors { + int32_t featureMapShapeList[20]; + uint32_t featureMapShapeListSize; + int32_t boxSpecsNum[10]; + uint32_t boxSpecsNumSize; + float scales[10]; + uint32_t scalesNum; + float aspectRatios[10]; + uint32_t aspectRatiosNum; + int32_t baseAnchorSize[2]; + uint32_t baseAnchorSizeNum; + int32_t anchorStride[2]; + uint32_t anchorStrideNum; + int32_t anchorOffset[2]; + uint32_t anchorOffsetNum; + bool reduceBoxesInLowestLayer; + float minScale; + float maxScale; + int32_t imgH; + int32_t imgW; +} ccGenerateSsdAnchors_t; + +/** + * @ingroup dnn + * @brief struct define of MscnnBoxOutput operator + */ +typedef struct tagCcMscnnBoxOutput { + double fgThreshold; + double nmsThreshold; + ccNmsType_t nmsType; + int32_t fieldH[CC_MAX_INPUT_CNT]; + int32_t fieldW[CC_MAX_INPUT_CNT]; + int32_t downsampleRate[CC_MAX_INPUT_CNT]; + int32_t defaultBoxCnt; + double fieldWhr; + double fieldXyr; + int32_t maxNmsNum; + int32_t maxPostNmsNum; + double minSize; +} ccMscnnBoxOutput_t; + +/** + * @ingroup dnn + * @brief struct define of NMS operator + */ +typedef struct tagCcNms { + int32_t numClasses; + int32_t backgroundLabelId; + double preConfThreshold; + int32_t preTopK; + double nmsThreshold; + double nmsEta; + int32_t postTopK; + int32_t outTopK; + double postConfThreshold; + bool shareLocation; +} ccNms_t; + +/** + * @ingroup dnn + * @brief struct define of NMS/MultiClassNMS operator + */ +typedef struct tagCcMultiClassNms { + uint64_t numClasses; + float objThreshold; + float nmsThreshold; + float clsThreshold; + bool normal; + uint64_t coorType; +} ccCcMultiClassNms_t; + +/** + * @ingroup dnn + * @brief struct define of YoloDetectionOutput operator + */ +typedef struct tagCcYoloDetectionOutput { + ccYoloVersion_t yoloVersion; + uint32_t netH; + uint32_t netW; + uint32_t postTopK; + uint32_t classes; + float nmsThreshold; + float iouThreDecay; + float coorScaleFactor; + bool relative; + float objThreshold; + float clsThreshold; + uint32_t biasNum; + float *bias; +} ccYoloDetectionOutput_t; + +/** + * @ingroup dnn + * @brief struct define of GetRegionBox operator + */ +#ifndef CC_MAX_YOLO_BIAS_NUM +#define CC_MAX_YOLO_BIAS_NUM (16) +#endif + +typedef struct tagCcGetRegionBox { + uint32_t biasNum; + uint32_t H; + uint32_t W; + float bias[CC_MAX_YOLO_BIAS_NUM]; +} ccGetRegionBox_t; + +/** + * @ingroup dnn + * @brief struct define of CorrectBoxes operator + */ +typedef struct tagCorrectBoxes { + uint32_t netW; + uint32_t netH; + bool relative; +} ccCorrectBoxes_t; + +/** + * @ingroup dnn + * @brief struct define of ClsProb operator + */ +typedef struct tagClsProb { + float objThreshold; +} ccClsProb_t; + +/** + * @ingroup dnn + * @brief struct define of SsdPriorBox operator + */ +typedef struct tagCcSsdPriorBox { + ccBoxCodeType_t codeType; + double *minSize; + int32_t minSizeNum; + double *maxSize; + int32_t maxSizeNum; + double *aspectRatio; + int32_t aspectRatioNum; + double *variance; + int32_t varianceNum; + int32_t imgH; + int32_t imgW; + double stepH; + double stepW; + double offset; + bool flip; + bool clip; +} ccSsdPriorBox_t; + +/** + * @ingroup dnn + * @brief struct define of Yolo2Region operator + */ +typedef struct tagCcYolo2Region { + ccSoftmaxTree_t softmaxTree; + bool softmax; + bool background; + bool treeSoftmax; +} ccYolo2Region_t; + +/** + * @ingroup dnn + * @brief struct define of YoloRegion operator + */ +typedef struct tagCcYoloRegion { + ccSoftmaxTree_t softmaxTree; + bool softmax; + bool background; + bool treeSoftmax; + int32_t classes; + int32_t coords; + int32_t boxes; + ccYoloVersion_t yoloV; +} ccYoloRegion_t; + +/** + * @ingroup dnn + * @brief struct define of power operator + */ +typedef struct tagCcPower { + float scale; + float shift; + float power; +} ccPower_t; + +/** + * @ingroup dnn + * @brief struct define of exp operator + */ +typedef struct tagCcExp { + ccDataType_t dataType; + uint32_t paramCnt; +} ccExp_t; + +/** + * @ingroup dnn + * @brief struct define of exp operator + */ +typedef struct tagCcLog { + ccDataType_t dataType; + uint32_t paramCnt; +} ccLog_t; + +/** + * @ingroup dnn + * @brief struct define of pow operator + */ +typedef struct tagCcPow { + ccDataType_t dataType; + uint32_t paramCnt; +} ccPow_t; + +/** + * @ingroup dnn + * @brief struct define of padv2 operator + */ +typedef struct tagCcPadV2 { + ccPadMode_t padMode; + void *padValue; + ccDataType_t padValueType; + int32_t padDimCnt; + int32_t padShapeLow[CC_DIM_MAX]; + int32_t padShapeHigh[CC_DIM_MAX]; +} ccPadV2_t; + +/** + * @ingroup dnn + * @brief struct define of psROIPooling operator + */ +typedef struct tagCcPsRoiPooling { + ccPoolingMode_t poolingMode; + int32_t pooledH; + int32_t pooledW; + float spatialScale; + float padRatio; + int32_t groupSize; + int32_t outputDim; +} ccPsRoiPooling_t; + +/** + * @ingroup dnn + * @brief struct define of RoIAlign operator + */ +typedef struct tagCcRoiAlign { + int32_t pooledH; + int32_t pooledW; + float spatialScale; + int32_t samplingRatio; +} ccRoiAlign_t; + +/** + * @ingroup dnn + * @brief struct define of RoiInterpPooling operator + */ +typedef struct tagCcRoiInterpPooling { + int32_t pooledH; + int32_t pooledW; + int32_t poolKernelH; + int32_t poolKernelW; + int32_t pooledTailH; + int32_t pooledTailW; + float spatialScaleH; + float spatialScaleW; +} ccRoiInterpPooling_t; + +/** + * @ingroup dnn + * @brief struct define of DetectionFull3DOutput operator + */ +typedef struct tagCcDetectionFull3DOutput { + int32_t imageWidth; + int32_t imageHeight; + int32_t numAngleBins; + float trcMarginRatioX; + float trcMarginRatioY; + int32_t pitchRangeD; + int32_t pitchPresetD; + float mountHeight; + int32_t visiblenessBins; + float meanVisibleness; + bool discreteVisibleness; +} ccDetectionFull3DOutput_t; + +/** + * @ingroup dnn + * @brief struct define of MsrFastRcnnPredictions operator + */ +typedef struct tagMsrFastRcnnPredictions { + int32_t numClasses; // num of classes + float scoreThreshold; // the threshold of the score + double nmsThreshold; // the threshold of nms + int32_t postTopK; + int32_t outTopK; + int32_t imgH; // the height of image + int32_t imgW; // the width of image +} ccMsrFastRcnnPredictions_t; + +typedef struct tagCcResizeBilinear { + ccResizeOutputDimMode_t resizeOutputDimMode; + bool alignCorners; + int32_t zoom_factor; + int32_t shrink_factor; + int32_t height; + int32_t width; + int32_t pad_begin; + int32_t pad_end; +} ccResizeBilinear_t; + +typedef struct tagCcResizeNearestNeighbor { + bool alignCorners; + int32_t height; + int32_t width; +} ccResizeNearestNeighbor_t; + +typedef struct tagCcEltwise { + ccQuantize_t *quantInfo; + bool reluFlag; +} ccEltwise_t; + +typedef struct tagCcBatchNorm { + bool reluFlag; +} ccBatchNorm_t; + +typedef struct tagCcPad { + ccPadMode_t padMode; + float padValue; + int32_t htoppad; // padLow[0] + int32_t hbottompad; // padHigh[0] + int32_t wleftpad; // padLow[1] + int32_t wrightpad; // padHigh[1] +} ccPad_t; + +typedef struct tagCcSubCondition { + uint32_t BaseCondValue[4]; + ccCMPType_t condType[4]; + ccResultType_t resultType; +} ccSubCondition; + +typedef struct tagCcShapeClassifyCond { + uint32_t subConditionNum; + ccResultType_t resultType; + uint32_t true_value; + ccSubCondition subCond[2]; +} ccShapeClassifyCond; + +#ifndef CC_SHAPE_CLASSIFY_CONDITION_NUM +#define CC_SHAPE_CLASSIFY_CONDITION_NUM (8) +#endif + +typedef struct tagCcShapeClassify { + uint32_t shapeClassifyConditionNum; + uint32_t defaultValue; + ccShapeClassifyCond shapeClassifyCond[CC_SHAPE_CLASSIFY_CONDITION_NUM]; +} ccShapeClassify_t; + +/** + * @ingroup dnn + * @bref struct define of square operator + */ +typedef struct tagCcSquare { + ccSquareMode_t mode; +} ccSquare_t; + +/* + * @ingroup dnn + * @brief operation of segment reduction + */ +typedef enum { + CC_SEGMENT_REDUCTION_OP_SUM = 0, /**< sum */ + CC_SEGMENT_REDUCTION_OP_INVALID +} ccSegmentReductionOpType_t; + +typedef struct tagCcFillParam { + // The filler type. + ccFillOpType_t fillType; + ccDataType_t valueDatatype; + const void *value; // the value in constant fill + const void *min; // the min value in uniform fill + const void *max; // the max value in uniform fill + const void *mean; // the mean value in Gaussian fill + const void *std; // the std value in Gaussian fill + // the seed used to generate data in Gaussian and uniform fill + int64_t seed1; + int64_t seed2; +} ccFillParam_t; + +typedef struct tagNonMaxSuppression { + ccDataType_t dataType; + uint32_t paraCount; +} ccNonMaxSuppression_t; + +typedef struct tagCcArgmaxmin { + int32_t axisType; + bool outMaxVal; + int64_t topK; + int64_t reduceSize; + int64_t reduceStride; + int64_t axis; + bool keepDims; +} ccArgmaxmin_t; + +typedef struct tagUpsamplePara { + int32_t scale; + int32_t scaleHeight; + int32_t scaleWidth; + int32_t upsampleHeight; + int32_t upsampleWidth; + bool padOutHeight; + bool padOutWidth; +} ccUpsamplePara_t; + +typedef struct tagCcConcatFive2Four_t { + ccTransForLossMode_t mode; + uint32_t classNum; +} ccConcatFive2Four_t; + +}; // namespace cce +#endif // DNN_STRUCT_BASE_HPP__ diff --git a/metadef/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h b/metadef/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h new file mode 100644 index 00000000..50b39d91 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h @@ -0,0 +1,129 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef FWK_ADPT_STRUCT_H__ +#define FWK_ADPT_STRUCT_H__ + +#include + +namespace aicpu { +namespace FWKAdapter { + +// API RETURN CODE +enum FWKAdptAPIRetCode { + FWK_ADPT_SUCCESS = 0, // success + FWK_ADPT_NOT_INIT = 1, // not init + FWK_ADPT_ALLOC_FAILED = 2, // allocate memory failed + FWK_ADPT_PARAM_INVALID = 3, // invalid input param + FWK_ADPT_PARAM_PARSE_FAILED = 4, // parase input param failed + FWK_ADPT_NATIVE_ERROR = 5, // error code + FWK_ADPT_NOT_SUPPORT_OPTYPE = 6, // unsupport operate type + FWK_ADPT_INTERNAL_ERROR = 7, // adpter internal error + FWK_ADPT_NOT_SUPPORT_DATATYPE = 8, // unsupport input/output data type + FWK_ADPT_KERNEL_ALREADY_RUNING = 9, // kernel already runing, not support parallel run + FWK_ADPT_SESSION_NOT_EXIST = 10, // session id not exist + FWK_ADPT_SESSION_ALREADY_EXIST = 11, // session id alread exist for create session + FWK_ADPT_NATIVE_END_OF_SEQUENCE = 12, // end of sequence + FWK_ADPT_EXTEND_TYPE_NOT_EXIST = 13, // extend info type not exist + FWK_ADPT_UNKNOWN_ERROR = 99 // unknown error code +}; + +// FWKAdapter operate type +// Notice: add new operate type need check with OMM, and make sure append to the end line. +enum FWKOperateType { + FWK_ADPT_SESSION_CREATE = 0, + FWK_ADPT_KERNEL_RUN, + FWK_ADPT_KERNEL_DESTROY, + FWK_ADPT_SESSION_DESTROY, + FWK_ADPT_SINGLE_OP_RUN, + FWK_ADPT_KERNEL_RUN_NO_SESS, +}; + +// Extend Info type for task +enum FWKTaskExtInfoType { + FWK_ADPT_EXT_SHAPE_TYPE = 0, + FWK_ADPT_EXT_INPUT_SHAPE, + FWK_ADPT_EXT_OUTPUT_SHAPE, + FWK_ADPT_EXT_UPDATE_ADDR, + FWK_ADPT_EXT_OP_NAME, + FWK_ADPT_EXT_SESSION_INFO, + FWK_ADPT_EXT_INVALID +}; + +enum FWKExtUpdateAddrType { + FWK_ADPT_UPDATE_NULL = 0, + FWK_ADPT_UPDATE_INPUT, + FWK_ADPT_UPDATE_OUTPUT, + FWK_ADPT_UPDATE_INPUT_OUTPUT +}; + +#pragma pack(push, 1) +// API Parameter Structure +struct StrFWKKernel { + FWKOperateType opType; + uint64_t sessionID; // unique + + uint64_t stepIDAddr; // step id addr + uint64_t kernelID; // run kernel id, unique in session + uint64_t nodeDefLen; // nodeDef protobuf len + uint64_t nodeDefBuf; // NodeDef protobuf offset addr, need convert to void* + uint64_t funDefLibLen; // FunctionDefLibrary protobuf len + uint64_t funDefLibBuf; // FunctionDefLibrary protobuf addr which use in NodeDef, need convert to void* + + uint64_t inputOutputLen; // InputOutput shap protobuf len + uint64_t inputOutputBuf; // InputOutput shap protobuf addr, need convert to void* + uint64_t workspaceBaseAddr; // Workspace base addr, need convert to void* + uint64_t inputOutputAddr; // InputOutput addr, need convert to void* + + uint64_t extInfoLen; // extend info total length + uint64_t extInfoAddr; // extend info addr, ExtInfo structure +}; +#pragma pack(pop) + +typedef StrFWKKernel FWKOperateParam; + +// Extent info ShapeAndType +const uint32_t kMaxShapeDims = 8; +#pragma pack(push, 1) +struct ShapeAndType { + int32_t type; + int64_t dims[kMaxShapeDims]; +}; +#pragma pack(pop) + +// Extend info structure for extInfoAddr +const uint32_t kExtInfoHeadSize = 8; + +#pragma pack(push, 1) +struct ExtInfo { + int32_t infoType; // extend type + uint32_t infoLen; // length for infoMsg + char infoMsg[0]; // extend value +}; +#pragma pack(pop) + +#pragma pack(push, 1) +struct ResultSummary { + uint64_t shape_data_ptr; // shape data addr, need convert to void* + uint64_t shape_data_size; // num of dims + uint64_t raw_data_ptr; // raw data addr, need convert to void* + uint64_t raw_data_size; // size of raw data +}; +#pragma pack(pop) +} // end namespace FWKAdapter +} // namespace aicpu + +#endif // FWK_ADPT_STRUCT_H__ diff --git a/metadef/third_party/fwkacllib/inc/cce/l2fusion_struct.hpp b/metadef/third_party/fwkacllib/inc/cce/l2fusion_struct.hpp new file mode 100644 index 00000000..fa5a95c9 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/l2fusion_struct.hpp @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef L2FUSION_STRUCT_HPP_ +#define L2FUSION_STRUCT_HPP_ + +#include +#include +#include "runtime/kernel.h" + +#define L2_DYNAMIC_SPLIT_NUM + +using namespace std; + +namespace fusion { + +typedef struct tagL2Data { + uint32_t l2Index; + uint64_t l2Addr; + uint64_t l2PageNum; +} L2Data_t; + +typedef std::map L2DataMap_t; // the key is ddr addr +typedef std::pair L2DataPair_t; // the key is ddr addr + +typedef struct TagTaskL2Info { + string nodeName; + rtL2Ctrl_t l2ctrl; + + L2DataMap_t input; + L2DataMap_t output; + uint32_t isUsed; +} TaskL2Info_t; + +typedef std::map TaskL2InfoMap_t; // the key is nodeId +typedef std::pair TaskL2InfoPair_t; // the key is nodeId + +typedef std::map TaskL2InfoFEMap_t; // the key is nodeName +typedef std::pair TaskL2InfoFEPair_t; // the key is nodeName + +} // namespace fusion + +#endif // L2FUSION_STRUCT_HPP_ diff --git a/metadef/third_party/fwkacllib/inc/cce/optimizer/fusion_engine.h b/metadef/third_party/fwkacllib/inc/cce/optimizer/fusion_engine.h new file mode 100644 index 00000000..299998e3 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/optimizer/fusion_engine.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef FUSION_ENGINE_HPP_ +#define FUSION_ENGINE_HPP_ + +#include "cce/cce.h" +#include "graph/compute_graph.h" +#include "proto/task.pb.h" + +#include +#include + +using namespace domi; +using namespace std; + +namespace fusion { +enum { + FUSION_STATUS_SUCCESS = 0, + FUSION_STATUS_FAIL = 1, +}; + +typedef struct { + uint64_t weightSize; + uint64_t memorySize; + uint8_t *dataMemBase; + uint8_t *weightMemBase; + uint32_t l2Enable; // 1 //1 - enable l2 buffer allocation, 0 - disable l2 buffer allocation + uint32_t fusionEnable; // 1 // 1 - enable buffer fusion, 0 - disable buffer fusion +} ModelRes; + +static const std::string SCOPE_ID_ATTR = "fusion_scope"; +static const std::string L2FUSION_DYNAMIC_CONVERGE_OP = "l2fusion_dynamic_converge_op"; +static const std::string L2FUSION_DYNAMIC_SPLIT_NUM = "l2fusion_dynamic_split_num"; +static const std::string FUSION_VIRTUAL_OP = "fusion_virtual_op"; +static const std::string FUSION_MULTI_BATCH_STRIDE = "fusion_multi_bathc_stride"; + +#define TVM_TYPE 1 + +typedef std::map> kScopeNodeMap_t; +typedef std::pair> kScopeNodePair_t; + +uint32_t BufferFusion(ge::ComputeGraphPtr origGraph, ge::ComputeGraphPtr fusionGraph, bool enable_l2dynamic = true); +uint32_t BufferFusionTrain(ge::ComputeGraphPtr origGraph, ge::ComputeGraphPtr fusionGraph); +uint32_t GraphFusion(ge::ComputeGraphPtr origGraph, ge::ComputeGraphPtr fusionGraph); +uint32_t FusionTaskBuild(cce::ccHandle_t ccHandle, ge::ComputeGraphPtr fusionGraph, ge::Buffer &buffer, + ModelRes &modelRes, std::vector &task_def_list_); +void FusionTaskBuildComplete(std::vector cchandleList); +uint32_t GraphFusionTrain(ge::ComputeGraphPtr origGraph, ge::ComputeGraphPtr fusionGraph); +} // namespace fusion + +#endif // FUSION_ENGINE_HPP_ diff --git a/metadef/third_party/fwkacllib/inc/cce/taskdown_api.h b/metadef/third_party/fwkacllib/inc/cce/taskdown_api.h new file mode 100644 index 00000000..2323aaa7 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/taskdown_api.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TASKDOWN_API_H_ +#define TASKDOWN_API_H_ + +#include +#include +#include "cce/cce.h" +#include "l2fusion_struct.hpp" +#include "taskdown_common.hpp" + +namespace cce { + +#define CC_FUSION_OP_MAX 32 + +typedef struct tagOpAddrsInfo { + void *addrPos; + uintptr_t addrData; +} ccOpAddrsInfo; + +#ifdef __cplusplus +extern "C" { +#endif + +ccStatus_t ccUpdateKernelArgs(ccOpContext &opContext, uint64_t dataBaseAddr, uint64_t weightBaseAddr, + uint64_t variableBaseAddr, void *argsAddr, uint64_t argsSize, void *l2ctrlAddr); + +#ifdef __cplusplus +} +#endif + +ccStatus_t ccGetKernelArgsAddrs(ccOpContext &opContext, void *argsAddr, uint64_t argsSize, void *l2ctrlAddr, + std::vector &opAddrsInfo); + +ccStatus_t ccSetKernelArgs(std::vector &dateInfo); + +ccStatus_t ccGetKernelTypeByOpId(uint32_t opId, ccKernelType &kernelType); + +} // namespace cce +#endif // TASKDOWN_API_H_ diff --git a/metadef/third_party/fwkacllib/inc/cce/taskdown_common.hpp b/metadef/third_party/fwkacllib/inc/cce/taskdown_common.hpp new file mode 100644 index 00000000..3ecea523 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/cce/taskdown_common.hpp @@ -0,0 +1,107 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TASKDOWN_COMMON_H_ +#define TASKDOWN_COMMON_H_ + +#include +#include "cce/cce_def.hpp" +#include "common/attr_list.hpp" +#include "l2fusion_struct.hpp" + +namespace cce { + +#define CC_FUSION_OP_MAX 32 + +typedef enum tagccKernelType { + CCE_AI_CORE = 0, /* cce aicore */ + CCE_AI_CPU = 1, /* cce aicpu */ + TE = 2, /* te operator*/ + CUSTOMIZED = 3, /* customized operator */ + TE_AI_CORE = 4, /* te aicore operator*/ + TE_AI_CPU = 5, /* te aicpu operator */ + AI_CPU = 6, /* aicpu */ + CUST_AI_CPU = 7, /* custom aicpu*/ + INVALID = 8, /* unknown kernel type */ +} ccKernelType; + +typedef struct tagOpContext { + ccKernelType kernelType; + uint32_t opId; + uint32_t kernelFuncId; + uint32_t opIndex; + uint32_t opCount; + uint32_t opIndex2[CC_FUSION_OP_MAX]; + bool isFlowtable; + uint16_t *argsOffset; + uint32_t argsCount; + uint64_t genDataBaseAddr; + uint64_t genDataBaseSize; + uint64_t genWeightBaseAddr; + uint64_t genWeightBaseSize; + uint64_t genVariableBaseAddr; + uint64_t genVariableBaseSize; + uint64_t l2ctrlSize; +} ccOpContext; + +typedef struct tagOpReadCount { + bool isEnable; + std::map tensorRc; +} ccOpReadCount; + +typedef enum tagTaskDownKernelIdMode { + CC_TASKDOWN_RESERVED = 0, + CC_TASKDOWN_ROIPOOLING, + CC_TASKDOWN_ROIPOOLING_PERF, + CC_TASKDOWN_ROIALIGN, + CC_TASKDOWN_ROIALIGN_PERF, + CC_TASKDOWN_FC, + CC_TASKDOWN_FC_COMPRESS, + CC_TASKDOWN_SOFTMAX_LOWEST, + CC_TASKDOWN_ROIALIGN_FP16, + CC_TASKDOWN_RESIZE_NEAREST_NEIGHBOR, + CC_TASKDOWN_RESIZE_NEAREST_NEIGHBOR_COMMON, +} ccTaskDownKernelIdMode_t; + +ccStatus_t GetStream(ccHandle_t handle, rtStream_t *streamId); + +ccStatus_t ccClearOpMap(ccHandle_t handle); + +ccStatus_t ccSetKernelOpMap(ccHandle_t handle); + +ccStatus_t ccSetKernelContext(ccHandle_t handle, uint32_t opId, AttrList &attrList, bool isFlowtable, + ccKernelType kernelType, void *pgraph); + +ccStatus_t ccGetKernelContext(rtStream_t streamId, ccOpContext &opContext); + +ccStatus_t ccGetKernelTypeByOpId(uint32_t opId, ccKernelType &kernelType); + +ccStatus_t ccSetStreamL2Map(ccHandle_t handle, fusion::TaskL2InfoMap_t &l2AllocRes); + +ccStatus_t ccGetStreamL2Map(rtStream_t streamId, uint32_t opIndex, fusion::TaskL2Info_t *&l2Data); + +ccStatus_t ccSetOpIndex(ccHandle_t handle, uint32_t opIndex); + +ccStatus_t ccGetOpIndex(ccHandle_t handle, uint32_t &opIndex); + +ccStatus_t ccGetOpIndexByStream(rtStream_t streamId, uint32_t &opIndex); + +ccStatus_t ccClearStreamL2Map(ccHandle_t handle); + +ccStatus_t ccGetKernelReadCount(rtStream_t streamId, ccOpReadCount &rc); + +} // namespace cce +#endif // TASKDOWN_COMMON_H_ diff --git a/metadef/third_party/fwkacllib/inc/mmpa/mmpa_api.h b/metadef/third_party/fwkacllib/inc/mmpa/mmpa_api.h new file mode 100644 index 00000000..38a689ee --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/mmpa/mmpa_api.h @@ -0,0 +1,142 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef _MMPA_API_H_ +#define _MMPA_API_H_ + +#define LINUX 0 +#define WIN 1 + +#if(OS_TYPE == LINUX) //lint !e553 + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#ifdef FUNC_VISIBILITY +#define MMPA_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define MMPA_FUNC_VISIBILITY +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "securec.h" + +#include "./sub_inc/mmpa_typedef_linux.h" +#include "./sub_inc/mmpa_linux.h" + +#endif + + +#if(OS_TYPE == WIN) //lint !e553 + +#ifdef FUNC_VISIBILITY +#define MMPA_FUNC_VISIBILITY _declspec(dllexport) +#else +#define MMPA_FUNC_VISIBILITY +#endif + +#include +#include +#include "Windows.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "shlwapi.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#include "securec.h" + +#include "sub_inc/mmpa_typedef_win.h" +#include "sub_inc/mmpa_win.h" + +#pragma comment(lib, "ws2_32.lib") +#pragma comment(lib, "mswsock.lib") +#pragma comment(lib, "Kernel32.lib") +#pragma comment(lib, "shlwapi.lib") +#pragma comment(lib, "wbemuuid.lib") +#pragma comment(lib, "Iphlpapi.lib") +#endif + +#endif // MMPA_API_H_ + diff --git a/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h b/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h new file mode 100644 index 00000000..ad48f70b --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h @@ -0,0 +1,559 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MMPA_LINUX_MMPA_LINUX_H +#define MMPA_LINUX_MMPA_LINUX_H + +#ifdef __cplusplus +#if __cplusplus +extern "C" { +#endif // __cpluscplus +#endif // __cpluscplus + +#define MMPA_MACINFO_DEFAULT_SIZE 18 +#define MMPA_CPUDESC_DEFAULT_SIZE 64 + +typedef pthread_t mmThread; +typedef pthread_mutex_t mmMutex_t; +typedef pthread_cond_t mmCond; +typedef pthread_mutex_t mmMutexFC; +typedef pthread_rwlock_t mmRWLock_t; +typedef signed int mmProcess; +typedef int mmPollHandle; +typedef int mmPipeHandle; +typedef int mmFileHandle; +typedef int mmComPletionKey; +typedef int mmCompletionHandle; +typedef int mmErrorMsg; +typedef int mmFd_t; + +typedef VOID *mmExitCode; +typedef key_t mmKey_t; +typedef int mmMsgid; +typedef struct dirent mmDirent; +typedef struct dirent mmDirent2; +typedef struct shmid_ds mmshmId_ds; +typedef int (*mmFilter)(const mmDirent *entry); +typedef int (*mmFilter2)(const mmDirent2 *entry); +typedef int (*mmSort)(const mmDirent **a, const mmDirent **b); +typedef int (*mmSort2)(const mmDirent2 **a, const mmDirent2 **b); +typedef size_t mmSize_t; +typedef off_t mmOfft_t; +typedef pid_t mmPid_t; +typedef long MM_LONG; + +typedef VOID *(*userProcFunc)(VOID *pulArg); + +typedef struct { + userProcFunc procFunc; // Callback function pointer + VOID *pulArg; // Callback function parameters +} mmUserBlock_t; + +typedef struct { + const char *dli_fname; + void *dli_fbase; + const char *dli_sname; + void *dli_saddr; + size_t dli_size; /* ELF only */ + int dli_bind; /* ELF only */ + int dli_type; +} mmDlInfo; + +typedef struct { + int wSecond; // Seconds. [0-60] (1 leap second) + int wMinute; // Minutes. [0-59] + int wHour; // Hours. [0-23] + int wDay; // Day. [1-31] + int wMonth; // Month. [1-12] + int wYear; // Year + int wDayOfWeek; // Day of week. [0-6] + int tm_yday; // Days in year.[0-365] + int tm_isdst; // DST. [-1/0/1] + long int wMilliseconds; // milliseconds +} mmSystemTime_t; + +typedef sem_t mmSem_t; +typedef struct sockaddr mmSockAddr; +typedef socklen_t mmSocklen_t; +typedef int mmSockHandle; +typedef timer_t mmTimer; +typedef pthread_key_t mmThreadKey; + +typedef int mmOverLap; + +typedef ssize_t mmSsize_t; +typedef size_t mmSize; // size + +typedef struct { + UINT32 createFlag; + INT32 oaFlag; +} mmCreateFlag; + +typedef struct { + VOID *sendBuf; + INT32 sendLen; +} mmIovSegment; +typedef struct in_addr mmInAddr; + +typedef struct { + VOID *inbuf; + INT32 inbufLen; + VOID *outbuf; + INT32 outbufLen; + mmOverLap *oa; +} mmIoctlBuf; + +typedef int mmAtomicType; +typedef int mmAtomicType64; + +typedef enum { + pollTypeRead = 1, // pipe read + pollTypeRecv, // socket recv + pollTypeIoctl, // ioctl +} mmPollType; + +typedef struct { + mmPollHandle handle; // The file descriptor or handle of poll is required + mmPollType pollType; // Operation type requiring poll + // read or recv or ioctl + INT32 ioctlCode; // IOCTL operation code, dedicated to IOCTL + mmComPletionKey completionKey; // The default value is blank, which is used in windows + // The data used to receive the difference between which handle is readable +} mmPollfd; + +typedef struct { + VOID *priv; // User defined private content + mmPollHandle bufHandle; // Value of handle corresponding to buf + mmPollType bufType; // Data types polled to + VOID *buf; // Data used in poll + UINT32 bufLen; // Data length used in poll + UINT32 bufRes; // Actual return length +} mmPollData, *pmmPollData; + +typedef VOID (*mmPollBack)(pmmPollData); + +typedef struct { + INT32 tz_minuteswest; // How many minutes is it different from Greenwich + INT32 tz_dsttime; // type of DST correction +} mmTimezone; + +typedef struct { + LONG tv_sec; + LONG tv_usec; +} mmTimeval; + +typedef struct { + MM_LONG tv_sec; + MM_LONG tv_nsec; +} mmTimespec; + +typedef struct { + ULONGLONG totalSize; + ULONGLONG freeSize; + ULONGLONG availSize; +} mmDiskSize; + +#define mmTLS __thread +typedef struct stat mmStat_t; +typedef struct stat64 mmStat64_t; +typedef mode_t mmMode_t; + +typedef struct option mmStructOption; + +typedef struct { + char addr[MMPA_MACINFO_DEFAULT_SIZE]; // ex:aa-bb-cc-dd-ee-ff\0 +} mmMacInfo; + +typedef struct { + char **argv; + INT32 argvCount; + char **envp; + INT32 envpCount; +} mmArgvEnv; + +typedef struct { + char arch[MMPA_CPUDESC_DEFAULT_SIZE]; + char manufacturer[MMPA_CPUDESC_DEFAULT_SIZE]; // vendor + char version[MMPA_CPUDESC_DEFAULT_SIZE]; // modelname + INT32 frequency; // cpu frequency + INT32 maxFrequency; // max speed + INT32 ncores; // cpu cores + INT32 nthreads; // cpu thread count + INT32 ncounts; // logical cpu nums +} mmCpuDesc; + +typedef mode_t MODE; + +typedef struct { + INT32 detachFlag; // Determine whether to set separation property 0, not to separate 1 + INT32 priorityFlag; // Determine whether to set priority 0 and not set 1 + INT32 priority; // Priority value range to be set 1-99 + INT32 policyFlag; // Set scheduling policy or not 0 do not set 1 setting + INT32 policy; // Scheduling policy value value + // MMPA_THREAD_SCHED_RR + // MMPA_THREAD_SCHED_OTHER + // MMPA_THREAD_SCHED_FIFO + INT32 stackFlag; // Set stack size or not: 0 does not set 1 setting + UINT32 stackSize; // The stack size unit bytes to be set cannot be less than MMPA_THREAD_STACK_MIN +} mmThreadAttr; + +#ifdef __ANDROID__ +#define S_IREAD S_IRUSR +#define S_IWRITE S_IWUSR +#endif + +#define mm_no_argument no_argument +#define mm_required_argument required_argument +#define mm_optional_argument optional_argument + +#define M_FILE_RDONLY O_RDONLY +#define M_FILE_WRONLY O_WRONLY +#define M_FILE_RDWR O_RDWR +#define M_FILE_CREAT O_CREAT + +#define M_RDONLY O_RDONLY +#define M_WRONLY O_WRONLY +#define M_RDWR O_RDWR +#define M_CREAT O_CREAT +#define M_BINARY O_RDONLY +#define M_TRUNC O_TRUNC +#define M_IRWXU S_IRWXU +#define M_APPEND O_APPEND + +#define M_IN_CREATE IN_CREATE +#define M_IN_CLOSE_WRITE IN_CLOSE_WRITE +#define M_IN_IGNORED IN_IGNORED + +#define M_OUT_CREATE IN_CREATE +#define M_OUT_CLOSE_WRITE IN_CLOSE_WRITE +#define M_OUT_IGNORED IN_IGNORED +#define M_OUT_ISDIR IN_ISDIR + +#define M_IREAD S_IREAD +#define M_IRUSR S_IRUSR +#define M_IWRITE S_IWRITE +#define M_IWUSR S_IWUSR +#define M_IXUSR S_IXUSR +#define FDSIZE 64 +#define M_MSG_CREAT IPC_CREAT +#define M_MSG_EXCL (IPC_CREAT | IPC_EXCL) +#define M_MSG_NOWAIT IPC_NOWAIT + +#define M_WAIT_NOHANG WNOHANG // Non blocking waiting +#define M_WAIT_UNTRACED \ + WUNTRACED // If the subprocess enters the suspended state, it will return immediately + // But the end state of the subprocess is ignored +#define M_UMASK_USRREAD S_IRUSR +#define M_UMASK_GRPREAD S_IRGRP +#define M_UMASK_OTHREAD S_IROTH + +#define M_UMASK_USRWRITE S_IWUSR +#define M_UMASK_GRPWRITE S_IWGRP +#define M_UMASK_OTHWRITE S_IWOTH + +#define M_UMASK_USREXEC S_IXUSR +#define M_UMASK_GRPEXEC S_IXGRP +#define M_UMASK_OTHEXEC S_IXOTH + +#define mmConstructor(x) __attribute__((constructor)) VOID x() +#define mmDestructor(x) __attribute__((destructor)) VOID x() + +#define MMPA_NO_ARGUMENT 0 +#define MMPA_REQUIRED_ARGUMENT 1 +#define MMPA_OPTIONAL_ARGUMENT 2 + +#define MMPA_MAX_PATH PATH_MAX +#define M_NAME_MAX MAX_FNAME + +#define M_F_OK F_OK +#define M_R_OK R_OK +#define M_W_OK W_OK + +#define MM_DT_DIR DT_DIR +#define MM_DT_REG DT_REG + +#define MMPA_STDIN STDIN_FILENO +#define MMPA_STDOUT STDOUT_FILENO +#define MMPA_STDERR STDERR_FILENO + +#define MMPA_RTLD_NOW RTLD_NOW +#define MMPA_RTLD_GLOBAL RTLD_GLOBAL +#define MMPA_RTLD_LAZY RTLD_LAZY +#define MMPA_RTLD_NODELETE RTLD_NODELETE + +#define MMPA_DL_EXT_NAME ".so" + +MMPA_FUNC_VISIBILITY INT32 mmCreateTask(mmThread *threadHandle, mmUserBlock_t *funcBlock); +MMPA_FUNC_VISIBILITY INT32 mmJoinTask(mmThread *threadHandle); +MMPA_FUNC_VISIBILITY INT32 mmMutexInit(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexLock(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexTryLock(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexUnLock(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexDestroy(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondInit(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmCondLockInit(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondLock(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondUnLock(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondLockDestroy(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmRWLockInit(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockRDLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockTryRDLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockWRLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockTryWRLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRDLockUnLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmWRLockUnLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockDestroy(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmCondWait(mmCond *cond, mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondTimedWait(mmCond *cond, mmMutexFC *mutex, UINT32 milliSecond); +MMPA_FUNC_VISIBILITY INT32 mmCondNotify(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmCondNotifyAll(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmCondDestroy(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmGetPid(); +MMPA_FUNC_VISIBILITY INT32 mmGetTid(); +MMPA_FUNC_VISIBILITY INT32 mmGetPidHandle(mmProcess *processHandle); +MMPA_FUNC_VISIBILITY INT32 mmGetLocalTime(mmSystemTime_t *sysTime); +MMPA_FUNC_VISIBILITY INT32 mmGetSystemTime(mmSystemTime_t *sysTime); + +MMPA_FUNC_VISIBILITY INT32 mmSemInit(mmSem_t *sem, UINT32 value); +MMPA_FUNC_VISIBILITY INT32 mmSemWait(mmSem_t *sem); +MMPA_FUNC_VISIBILITY INT32 mmSemPost(mmSem_t *sem); +MMPA_FUNC_VISIBILITY INT32 mmSemDestroy(mmSem_t *sem); +MMPA_FUNC_VISIBILITY INT32 mmOpen(const CHAR *pathName, INT32 flags); +MMPA_FUNC_VISIBILITY INT32 mmOpen2(const CHAR *pathName, INT32 flags, MODE mode); +MMPA_FUNC_VISIBILITY FILE *mmPopen(CHAR *command, CHAR *type); +MMPA_FUNC_VISIBILITY INT32 mmClose(INT32 fd); +MMPA_FUNC_VISIBILITY INT32 mmPclose(FILE *stream); +MMPA_FUNC_VISIBILITY mmSsize_t mmWrite(INT32 fd, VOID *buf, UINT32 bufLen); +MMPA_FUNC_VISIBILITY mmSsize_t mmRead(INT32 fd, VOID *buf, UINT32 bufLen); +MMPA_FUNC_VISIBILITY mmSockHandle mmSocket(INT32 sockFamily, INT32 type, INT32 protocol); +MMPA_FUNC_VISIBILITY INT32 mmBind(mmSockHandle sockFd, mmSockAddr *addr, mmSocklen_t addrLen); +MMPA_FUNC_VISIBILITY INT32 mmListen(mmSockHandle sockFd, INT32 backLog); +MMPA_FUNC_VISIBILITY mmSockHandle mmAccept(mmSockHandle sockFd, mmSockAddr *addr, mmSocklen_t *addrLen); +MMPA_FUNC_VISIBILITY INT32 mmConnect(mmSockHandle sockFd, mmSockAddr *addr, mmSocklen_t addrLen); +MMPA_FUNC_VISIBILITY INT32 mmCloseSocket(mmSockHandle sockFd); +MMPA_FUNC_VISIBILITY mmSsize_t mmSocketSend(mmSockHandle sockFd, VOID *sendBuf, INT32 sendLen, INT32 sendFlag); +MMPA_FUNC_VISIBILITY mmSsize_t mmSocketRecv(mmSockHandle sockFd, VOID *recvBuf, INT32 recvLen, INT32 recvFlag); +MMPA_FUNC_VISIBILITY INT32 mmSocketSendTo(mmSockHandle sockFd, + VOID *sendMsg, + INT32 sendLen, + UINT32 sendFlag, + const mmSockAddr* addr, + INT32 tolen); +MMPA_FUNC_VISIBILITY mmSsize_t mmSocketRecvFrom(mmSockHandle sockFd, + VOID *recvBuf, + mmSize recvLen, + UINT32 recvFlag, + mmSockAddr* addr, + mmSocklen_t *FromLen); +MMPA_FUNC_VISIBILITY INT32 mmSAStartup(); +MMPA_FUNC_VISIBILITY INT32 mmSACleanup(); +MMPA_FUNC_VISIBILITY VOID *mmDlopen(const CHAR *fileName, INT32 mode); +MMPA_FUNC_VISIBILITY INT32 mmDladdr(VOID *addr, mmDlInfo *info); +MMPA_FUNC_VISIBILITY VOID *mmDlsym(VOID *handle, const CHAR *funcName); +MMPA_FUNC_VISIBILITY INT32 mmDlclose(VOID *handle); +MMPA_FUNC_VISIBILITY CHAR *mmDlerror(); +MMPA_FUNC_VISIBILITY INT32 mmCreateAndSetTimer(mmTimer *timerHandle, + mmUserBlock_t *timerBlock, + UINT milliSecond, + UINT period); +MMPA_FUNC_VISIBILITY INT32 mmDeleteTimer(mmTimer timerHandle); +MMPA_FUNC_VISIBILITY INT32 mmStatGet(const CHAR *path, mmStat_t *buffer); +MMPA_FUNC_VISIBILITY INT32 mmStat64Get(const CHAR *path, mmStat64_t *buffer); +MMPA_FUNC_VISIBILITY INT32 mmFStatGet(INT32 fd, mmStat_t *buffer); +MMPA_FUNC_VISIBILITY INT32 mmMkdir(const CHAR *pathName, mmMode_t mode); +MMPA_FUNC_VISIBILITY INT32 mmSleep(UINT32 milliSecond); + +MMPA_FUNC_VISIBILITY INT32 mmCreateTaskWithAttr(mmThread *threadHandle, mmUserBlock_t *funcBlock); +MMPA_FUNC_VISIBILITY INT32 mmGetProcessPrio(mmProcess pid); +MMPA_FUNC_VISIBILITY INT32 mmSetProcessPrio(mmProcess pid, INT32 processPrio); +MMPA_FUNC_VISIBILITY INT32 mmGetThreadPrio(mmThread *threadHandle); +MMPA_FUNC_VISIBILITY INT32 mmSetThreadPrio(mmThread *threadHandle, INT32 threadPrio); +MMPA_FUNC_VISIBILITY INT32 mmAccess(const CHAR *pathName); +MMPA_FUNC_VISIBILITY INT32 mmAccess2(const CHAR *pathName, INT32 mode); +MMPA_FUNC_VISIBILITY INT32 mmRmdir(const CHAR *pathName); + +MMPA_FUNC_VISIBILITY INT32 mmIoctl(mmProcess fd, INT32 ioctlCode, mmIoctlBuf *bufPtr); +MMPA_FUNC_VISIBILITY INT32 mmSemTimedWait(mmSem_t *sem, INT32 timeout); +MMPA_FUNC_VISIBILITY mmSsize_t mmWritev(mmProcess fd, mmIovSegment *iov, INT32 iovcnt); +MMPA_FUNC_VISIBILITY VOID mmMb(); +MMPA_FUNC_VISIBILITY INT32 mmInetAton(const CHAR *addrStr, mmInAddr *addr); + +MMPA_FUNC_VISIBILITY mmProcess mmOpenFile(const CHAR *fileName, UINT32 access, mmCreateFlag fileFlag); +MMPA_FUNC_VISIBILITY mmSsize_t mmReadFile(mmProcess fileId, VOID *buffer, INT32 len); +MMPA_FUNC_VISIBILITY mmSsize_t mmWriteFile(mmProcess fileId, VOID *buffer, INT32 len); +MMPA_FUNC_VISIBILITY INT32 mmCloseFile(mmProcess fileId); + +MMPA_FUNC_VISIBILITY mmAtomicType mmSetData(mmAtomicType *ptr, mmAtomicType value); +MMPA_FUNC_VISIBILITY mmAtomicType mmValueInc(mmAtomicType *ptr, mmAtomicType value); +MMPA_FUNC_VISIBILITY mmAtomicType mmValueSub(mmAtomicType *ptr, mmAtomicType value); +MMPA_FUNC_VISIBILITY mmAtomicType64 mmSetData64(mmAtomicType64 *ptr, mmAtomicType64 value); +MMPA_FUNC_VISIBILITY mmAtomicType64 mmValueInc64(mmAtomicType64 *ptr, mmAtomicType64 value); +MMPA_FUNC_VISIBILITY mmAtomicType64 mmValueSub64(mmAtomicType64 *ptr, mmAtomicType64 value); +MMPA_FUNC_VISIBILITY INT32 mmCreateTaskWithDetach(mmThread *threadHandle, mmUserBlock_t *funcBlock); + +// The following 3 interfaces are to be deleted +MMPA_FUNC_VISIBILITY INT32 mmCreateNamedPipe(mmPipeHandle pipe[], CHAR *pipeName[], INT32 waitMode); +MMPA_FUNC_VISIBILITY INT32 mmOpenNamePipe(mmPipeHandle pipe[], CHAR *pipeName[], INT32 waitMode); +MMPA_FUNC_VISIBILITY VOID mmCloseNamedPipe(mmPipeHandle namedPipe[]); + +MMPA_FUNC_VISIBILITY INT32 mmCreatePipe(mmPipeHandle pipe[], CHAR *pipeName[], UINT32 pipeCount, INT32 waitMode); +MMPA_FUNC_VISIBILITY INT32 mmOpenPipe(mmPipeHandle pipe[], CHAR *pipeName[], UINT32 pipeCount, INT32 waitMode); +MMPA_FUNC_VISIBILITY VOID mmClosePipe(mmPipeHandle pipe[], UINT32 pipeCount); + +// Poll related interface +MMPA_FUNC_VISIBILITY mmCompletionHandle mmCreateCompletionPort(); +MMPA_FUNC_VISIBILITY VOID mmCloseCompletionPort(mmCompletionHandle handle); +MMPA_FUNC_VISIBILITY INT32 mmPoll(mmPollfd *fds, + INT32 fdCount, + INT32 timeout, + mmCompletionHandle handleIOCP, + pmmPollData polledData, + mmPollBack pollBack); +MMPA_FUNC_VISIBILITY INT32 mmGetErrorCode(); +MMPA_FUNC_VISIBILITY CHAR *mmGetErrorFormatMessage(mmErrorMsg errnum, CHAR *buf, mmSize size); +MMPA_FUNC_VISIBILITY INT32 mmGetTimeOfDay(mmTimeval *timeVal, mmTimezone *timeZone); +MMPA_FUNC_VISIBILITY mmTimespec mmGetTickCount(); +MMPA_FUNC_VISIBILITY INT32 mmGetRealPath(CHAR *path, CHAR *realPath); +MMPA_FUNC_VISIBILITY INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen); + +MMPA_FUNC_VISIBILITY INT32 mmDup2(INT32 oldFd, INT32 newFd); + +MMPA_FUNC_VISIBILITY INT32 mmDup(INT32 fd); + +MMPA_FUNC_VISIBILITY INT32 mmUnlink(const CHAR *filename); + +MMPA_FUNC_VISIBILITY INT32 mmChmod(const CHAR *filename, INT32 mode); + +MMPA_FUNC_VISIBILITY INT32 mmFileno(FILE *stream); + +MMPA_FUNC_VISIBILITY INT32 mmScandir(const CHAR *path, mmDirent ***entryList, mmFilter filterFunc, mmSort sort); +MMPA_FUNC_VISIBILITY INT32 mmScandir2(const CHAR *path, mmDirent2 ***entryList, mmFilter2 filterFunc, mmSort2 sort); + +MMPA_FUNC_VISIBILITY VOID mmScandirFree(mmDirent **entryList, INT32 count); +MMPA_FUNC_VISIBILITY VOID mmScandirFree2(mmDirent2 **entryList, INT32 count); + +MMPA_FUNC_VISIBILITY mmMsgid mmMsgCreate(mmKey_t key, INT32 msgFlag); + +MMPA_FUNC_VISIBILITY mmMsgid mmMsgOpen(mmKey_t key, INT32 msgFlag); + +MMPA_FUNC_VISIBILITY INT32 mmMsgSnd(mmMsgid msqid, VOID *buf, INT32 bufLen, INT32 msgFlag); + +MMPA_FUNC_VISIBILITY INT32 mmMsgRcv(mmMsgid msqid, VOID *buf, INT32 bufLen, INT32 msgFlag); + +MMPA_FUNC_VISIBILITY INT32 mmMsgClose(mmMsgid msqid); + +MMPA_FUNC_VISIBILITY INT32 mmLocalTimeR(const time_t *timep, struct tm *result); + +MMPA_FUNC_VISIBILITY INT32 mmGetOptErr(); +MMPA_FUNC_VISIBILITY VOID mmSetOptErr(INT32 mmOptErr); +MMPA_FUNC_VISIBILITY INT32 mmGetOptInd(); +MMPA_FUNC_VISIBILITY VOID mmSetOptInd(INT32 mmOptInd); +MMPA_FUNC_VISIBILITY INT32 mmGetOptOpt(); +MMPA_FUNC_VISIBILITY VOID mmSetOpOpt(INT32 mmOptOpt); +MMPA_FUNC_VISIBILITY CHAR *mmGetOptArg(); +MMPA_FUNC_VISIBILITY VOID mmSetOptArg(CHAR *mmOptArg); +MMPA_FUNC_VISIBILITY INT32 mmGetOpt(INT32 argc, char *const *argv, const char *opts); +MMPA_FUNC_VISIBILITY INT32 mmGetOptLong(INT32 argc, + char *const *argv, + const char *opts, + const mmStructOption *longOpts, + INT32 *longIndex); + +MMPA_FUNC_VISIBILITY LONG mmLseek(INT32 fd, INT64 offset, INT32 seekFlag); +MMPA_FUNC_VISIBILITY INT32 mmFtruncate(mmProcess fd, UINT32 length); + +MMPA_FUNC_VISIBILITY INT32 mmTlsCreate(mmThreadKey *key, VOID (*destructor)(VOID *)); +MMPA_FUNC_VISIBILITY INT32 mmTlsSet(mmThreadKey key, const VOID *value); +MMPA_FUNC_VISIBILITY VOID *mmTlsGet(mmThreadKey key); +MMPA_FUNC_VISIBILITY INT32 mmTlsDelete(mmThreadKey key); +MMPA_FUNC_VISIBILITY INT32 mmGetOsType(); + +MMPA_FUNC_VISIBILITY INT32 mmFsync(mmProcess fd); +MMPA_FUNC_VISIBILITY INT32 mmFsync2(INT32 fd); +MMPA_FUNC_VISIBILITY INT32 mmChdir(const CHAR *path); +MMPA_FUNC_VISIBILITY INT32 mmUmask(INT32 pmode); +MMPA_FUNC_VISIBILITY INT32 mmThreadKill(mmThread id); +MMPA_FUNC_VISIBILITY INT32 mmWaitPid(mmProcess pid, INT32 *status, INT32 options); + +MMPA_FUNC_VISIBILITY INT32 mmGetCwd(CHAR *buffer, INT32 maxLen); +MMPA_FUNC_VISIBILITY INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len); +MMPA_FUNC_VISIBILITY INT32 mmSetEnv(const CHAR *name, const CHAR *value, INT32 overwrite); +MMPA_FUNC_VISIBILITY CHAR *mmStrTokR(CHAR *str, const CHAR *delim, CHAR **saveptr); +MMPA_FUNC_VISIBILITY CHAR *mmDirName(CHAR *path); +MMPA_FUNC_VISIBILITY CHAR *mmBaseName(CHAR *path); +MMPA_FUNC_VISIBILITY INT32 mmGetDiskFreeSpace(const char *path, mmDiskSize *diskSize); + +/* + * Function: set the thread name created by mmcreatetask + * Input: pstThreadHandle: thread ID + * name: thread name, the actual length of name must be < MMPA_THREADNAME_SIZE + * The input parameter error returns EN_INVALID_PARAM, the execution success returns EN_OK, and the + * execution failure returns EN_ERROR + */ +MMPA_FUNC_VISIBILITY INT32 mmSetThreadName(mmThread *threadHandle, const CHAR *name); + +/* + * Function: get thread name + * Input: pstThreadHandle: thread ID + * size: Cache length of thread name + * name:User allocated cache for thread name, Cache length must be >= MMPA_THREADNAME_SIZE + * The input parameter error returns EN_INVALID_PARAM, the execution success returns EN_OK, and the + * execution failure returns EN_ERROR + */ +MMPA_FUNC_VISIBILITY INT32 mmGetThreadName(mmThread *threadHandle, CHAR *name, INT32 size); +/* + * Function:Set the thread name of the currently executing thread - call inside the thread body + * Input:name:Thread name to be set + * The input parameter error returns EN_INVALID_PARAM, the execution success returns EN_OK, and the + * execution failure returns EN_ERROR + */ +MMPA_FUNC_VISIBILITY INT32 mmSetCurrentThreadName(const CHAR *name); +/* + * Function:Get the thread name of the currently executing thread - in body call + * Input:name:The name of the thread to get, and the cache is allocated by the user,size>=MMPA_THREADNAME_SIZE + * The input parameter error returns EN_INVALID_PARAM, the execution success returns EN_OK, and the + * execution failure returns EN_ERROR + */ +MMPA_FUNC_VISIBILITY INT32 mmGetCurrentThreadName(CHAR *name, INT32 size); +MMPA_FUNC_VISIBILITY INT32 mmGetFileSize(const CHAR *fileName, ULONGLONG *length); +MMPA_FUNC_VISIBILITY INT32 mmIsDir(const CHAR *fileName); +MMPA_FUNC_VISIBILITY INT32 mmGetOsName(CHAR *name, INT32 nameSize); +MMPA_FUNC_VISIBILITY INT32 mmGetOsVersion(CHAR *versionInfo, INT32 versionLength); +MMPA_FUNC_VISIBILITY INT32 mmGetMac(mmMacInfo **list, INT32 *count); +MMPA_FUNC_VISIBILITY INT32 mmGetMacFree(mmMacInfo *list, INT32 count); +MMPA_FUNC_VISIBILITY INT32 mmGetCpuInfo(mmCpuDesc **cpuInfo, INT32 *count); +MMPA_FUNC_VISIBILITY INT32 mmCpuInfoFree(mmCpuDesc *cpuInfo, INT32 count); +MMPA_FUNC_VISIBILITY INT32 mmCreateProcess(const CHAR *fileName, + const mmArgvEnv *env, + const char *stdoutRedirectFile, + mmProcess *id); + +MMPA_FUNC_VISIBILITY INT32 mmCreateTaskWithThreadAttr(mmThread *threadHandle, + const mmUserBlock_t *funcBlock, + const mmThreadAttr *threadAttr); +MMPA_FUNC_VISIBILITY mmFileHandle mmShmOpen(const CHAR *name, INT32 oflag, mmMode_t mode); +MMPA_FUNC_VISIBILITY INT32 mmShmUnlink(const CHAR *name); +MMPA_FUNC_VISIBILITY VOID *mmMmap(mmFd_t fd, mmSize_t size, mmOfft_t offset, mmFd_t *extra, INT32 prot, INT32 flags); +MMPA_FUNC_VISIBILITY INT32 mmMunMap(VOID *data, mmSize_t size, mmFd_t *extra); +#define MMPA_DLL_API + +#ifdef __cplusplus +#if __cplusplus +} +#endif /* __cpluscplus */ +#endif // __cpluscplus + +#endif // MMPA_LINUX_MMPA_LINUX_H_ diff --git a/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_typedef_linux.h b/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_typedef_linux.h new file mode 100644 index 00000000..9df5b9ce --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_typedef_linux.h @@ -0,0 +1,98 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MMPA_TYPEDEF_LINUX_H +#define MMPA_TYPEDEF_LINUX_H + +#ifdef __cplusplus +#if __cplusplus +extern "C" { +#endif // __cpluscplus +#endif // __cpluscplus + +#ifndef FALSE +#define FALSE 0 +#endif + +#ifndef TRUE +#define TRUE 1 +#endif + +typedef unsigned char UINT8; +typedef signed char INT8; +typedef unsigned short UINT16; +typedef signed short INT16; +typedef unsigned int UINT32; +typedef signed int INT32; +typedef unsigned long long UINT64; +typedef signed long long INT64; +typedef float FLOAT; +typedef double DOUBLE; +typedef void VOID; +typedef unsigned char UCHAR; +typedef char CHAR; +typedef unsigned short USHORT; +typedef short SHORT; +typedef unsigned int UINT; +typedef int INT; +typedef unsigned long ULONG; +typedef unsigned long long ULONGLONG; + +typedef long LONG; + +#define HANDLE_INVALID_VALUE (-1) +#define MMPA_MEM_MAX_LEN (0x7fffffff) +#define MMPA_PROCESS_ERROR (0x7fffffff) +#define PATH_SIZE 256 +#define MAX_IOVEC_SIZE 32 +#define MMPA_MAX_SLEEP_MILLSECOND 4294967 +#define MAX_PIPE_COUNT 2 +#define MMPA_PIPE_COUNT 2 +#define MMPA_THREADNAME_SIZE 16 +#define MMPA_MIN_OS_NAME_SIZE 64 +#define MMPA_MIN_OS_VERSION_SIZE 128 + +#define MMPA_ONE_THOUSAND 1000 +#define MMPA_ONE_BILLION 1000000000 +#define MMPA_COMPUTER_BEGIN_YEAR 1900 +#define MMPA_ZERO 0 +#define MMPA_MAX_THREAD_PIO 99 +#define MMPA_MIN_THREAD_PIO 1 +#define MMPA_DEFAULT_PIPE_PERMISSION 0777 +#define MMPA_DEFAULT_MSG_TYPE 1 + +#define MMPA_THREAD_SCHED_RR SCHED_RR +#define MMPA_THREAD_SCHED_FIFO SCHED_FIFO +#define MMPA_THREAD_SCHED_OTHER SCHED_OTHER +#define MMPA_THREAD_MIN_STACK_SIZE PTHREAD_STACK_MIN + +#define MM_MUTEX_INITIALIZER PTHREAD_MUTEX_INITIALIZER + +#define MMPA_MAX_NI 19 +#define MMPA_MIN_NI (-20) + +#define EN_OK 0 +#define EN_ERR 1 +#define EN_ERROR (-1) +#define EN_INVALID_PARAM (-2) +#define EN_TIMEOUT (-3) + +#ifdef __cplusplus +#if __cplusplus +} +#endif // __cpluscplus +#endif // __cpluscplus +#endif // MMPA_TYPEDEF_LINUX_H_ diff --git a/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_typedef_win.h b/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_typedef_win.h new file mode 100644 index 00000000..8200bea6 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_typedef_win.h @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MMPA_TYPEDEF_WIN_H +#define MMPA_TYPEDEF_WIN_H + +#ifdef __cplusplus +#if __cplusplus +extern "C" { +#endif // __cpluscplus +#endif // __cpluscplus + +#ifndef FALSE +#define FALSE 0 +#endif + +#ifndef TRUE +#define TRUE 1 +#endif + +#define EN_OK 0 +#define EN_ERR 1 +#define EN_ERROR (-1) +#define EN_INVALID_PARAM (-2) +#define EN_TIMEOUT (-3) + +#define HANDLE_INVALID_VALUE (-1) +#define INVALID_SOCKET_HANDLE INVALID_SOCKET +#define MMPA_MEM_MAX_LEN (0x7fffffff) +#define MMPA_PROCESS_ERROR (0x7fffffff) + +#define MMPA_ONE_THOUSAND 1000 +#define MMPA_COMPUTER_BEGIN_YEAR 1900 +#define SUMMER_TIME_OR_NOT (-1) +#define MMPA_ZERO 0 +#define MMPA_VALUE_ONE 1 +#define MMPA_SOCKET_MAIN_EDITION 2 +#define MMPA_SOCKET_SECOND_EDITION 0 +#define MMPA_PIPE_BUF_SIZE 1024 +#define MMPA_MAX_SCANDIR_COUNT 1024 +#define MAX_IOVEC_SIZE 32 +#define MMPA_PIPE_COUNT 2 +#define MMPA_THREADNAME_SIZE 16 +#define MMPA_MIN_OS_NAME_SIZE (MAX_COMPUTERNAME_LENGTH + 1) +#define MMPA_MIN_OS_VERSION_SIZE 64 + +#define MMPA_MAX_NI 19 +#define MMPA_MIDDLE_NI 5 +#define MMPA_LOW_NI (-5) +#define MMPA_MIN_NI (-20) +#define MMPA_MAX_FILE 128 + +#define MMPA_MAX_THREAD_PIO 99 +#define MMPA_MIDDLE_THREAD_PIO 66 +#define MMPA_LOW_THREAD_PIO 33 +#define MMPA_MIN_THREAD_PIO 1 + +#define MMPA_THREAD_SCHED_RR 0 +#define MMPA_THREAD_SCHED_FIFO 0 +#define MMPA_THREAD_SCHED_OTHER 0 +#define MMPA_THREAD_MIN_STACK_SIZE 0 + +#define MM_MUTEX_INITIALIZER NULL + +#ifdef __cplusplus +#if __cplusplus +} +#endif // __cpluscplus +#endif // __cpluscplus +#endif // _MMPA_TYPEDEF_WIN_H_ diff --git a/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h b/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h new file mode 100644 index 00000000..cecdd4a7 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h @@ -0,0 +1,565 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MMPA_WIN_MMPA_WIN_H +#define MMPA_WIN_MMPA_WIN_H +#ifdef __cplusplus +#if __cplusplus +extern "C" { +#endif // __cpluscplus +#endif // __cpluscplus +#ifdef MMPA_DLL +#define MMPA_DLL_API __declspec(dllexport) +#else +#define MMPA_DLL_API __declspec(dllimport) +#endif + +#define MMPA_MACINFO_DEFAULT_SIZE 18 +#define MMPA_CPUDESC_DEFAULT_SIZE 64 + +#pragma section(".CRT$XCU", long, read) +#pragma section(".CRT$XPU", long, read) + +typedef HANDLE mmMutex_t; +typedef HANDLE mmThread; +typedef HANDLE mmProcess; +typedef HANDLE mmPollHandle; +typedef HANDLE mmPipeHandle; +typedef HANDLE mmFileHandle; +typedef HANDLE mmCompletionHandle; +typedef HANDLE mmFd_t; +typedef CRITICAL_SECTION mmMutexFC; +typedef CONDITION_VARIABLE mmCond; + +typedef VOID *(*userProcFunc)(VOID *pulArg); +typedef struct { + userProcFunc procFunc; + VOID *pulArg; +} mmUserBlock_t; + +typedef DWORD mmThreadKey; +typedef SYSTEMTIME mmSystemTime_t; + +typedef HANDLE mmSem_t; +typedef SOCKET mmSockHandle; +typedef SRWLOCK mmRWLock_t; +typedef struct sockaddr mmSockAddr; +typedef int mmSocklen_t; +typedef int mmSemTimeout_t; +typedef long mmAtomicType; +typedef long long mmAtomicType64; +typedef DWORD mmExitCode; +typedef DWORD mmErrorMsg; +typedef int mmKey_t; +typedef HANDLE mmMsgid; +typedef long int mmOfft_t; +typedef int mmPid_t; + +typedef INT32 mmSsize_t; +typedef int mmSize; // size +typedef size_t mmSize_t; +typedef VOID mmshmId_ds; +typedef long long MM_LONG; + +typedef enum { + DT_DIR = FILE_ATTRIBUTE_DIRECTORY, +} mmDtype; + +typedef struct { + unsigned char d_type; + char d_name[MAX_PATH]; // file name +} mmDirent; + +typedef struct { + unsigned long d_type; + char d_name[MAX_PATH]; // file name +} mmDirent2; + +typedef int (*mmFilter)(const mmDirent *entry); +typedef int (*mmFilter2)(const mmDirent2 *entry); +typedef int (*mmSort)(const mmDirent **a, const mmDirent **b); +typedef int (*mmSort2)(const mmDirent2 **a, const mmDirent2 **b); + +typedef struct { + VOID *sendBuf; + INT32 sendLen; +} mmIovSegment; +typedef PVOID mmInAddr; + +typedef enum { + pollTypeRead = 1, // pipeline reading + pollTypeRecv, // socket receive + pollTypeIoctl, // ioctl read +} mmPollType; + +typedef struct { + HANDLE completionHandle; + mmPollType overlapType; + OVERLAPPED oa; +} mmComPletionKey, *pmmComPletionKey; + +typedef struct { + VOID *priv; // User defined private content + mmPollHandle bufHandle; // Value of handle corresponding to buf + mmPollType bufType; // Data types polled to + VOID *buf; + UINT32 bufLen; + UINT32 bufRes; +} mmPollData, *pmmPollData; + +typedef VOID (*mmPollBack)(pmmPollData); +typedef struct { + mmPollHandle handle; // The file descriptor or handle of poll is required + mmPollType pollType; // Operation type requiring poll,read or recv or ioctl + INT32 ioctlCode; // IOCTL operation code, dedicated to IOCTL + mmComPletionKey completionKey; // The default value is blank, which will be used in windows to receive the data with + // different handle +} mmPollfd; + +typedef struct { + OVERLAPPED oa; + HANDLE completionHandle; + WSABUF DataBuf; +} PRE_IO_DATA, *PPRE_IO_DATA; + +typedef OVERLAPPED mmOverLap; + +typedef struct { + UINT32 createFlag; + INT32 oaFlag; // Overlap operation is supported if it is not 0 +} mmCreateFlag; + +typedef struct { + VOID *inbuf; + INT32 inbufLen; + VOID *outbuf; + INT32 outbufLen; + mmOverLap *oa; +} mmIoctlBuf; + +typedef struct { + HANDLE timerQueue; + HANDLE timerHandle; +} mmTimerHandle; + +typedef struct { + LONG tv_sec; + LONG tv_usec; +} mmTimeval; + +typedef struct { + INT32 tz_minuteswest; // How many minutes is it different from Greenwich + INT32 tz_dsttime; // DST correction type +} mmTimezone; + +typedef struct { + MM_LONG tv_sec; + MM_LONG tv_nsec; +} mmTimespec; + +typedef mmTimerHandle mmTimer; + +#define mmTLS __declspec(thread) + +typedef struct stat mmStat_t; +typedef struct _stat64 mmStat64_t; +typedef int mmMode_t; + +typedef int MODE; + +typedef struct { + const char *name; + int has_arg; + int *flag; + int val; +} mmStructOption; + +typedef struct { + ULONGLONG totalSize; + ULONGLONG freeSize; + ULONGLONG availSize; +} mmDiskSize; + +typedef struct { + const char *dli_fname; + void *dli_fbase; + const char *dli_sname; + void *dli_saddr; + size_t dli_size; /* ELF only */ + int dli_bind; /* ELF only */ + int dli_type; +} mmDlInfo; + +typedef struct { + char addr[MMPA_MACINFO_DEFAULT_SIZE]; // ex:aa-bb-cc-dd-ee-ff\0 +} mmMacInfo; + +typedef struct { + char arch[MMPA_CPUDESC_DEFAULT_SIZE]; + char manufacturer[MMPA_CPUDESC_DEFAULT_SIZE]; // vendor + char version[MMPA_CPUDESC_DEFAULT_SIZE]; // modelname + INT32 frequency; // cpu frequency + INT32 maxFrequency; // max speed + INT32 ncores; // cpu cores + INT32 nthreads; // cpu thread count + INT32 ncounts; // logical cpu nums +} mmCpuDesc; + +typedef struct { + char **argv; + INT32 argvCount; + char **envp; + INT32 envpCount; +} mmArgvEnv; + +// Windows currently does not support properties other than thread separation properties +typedef struct { + INT32 detachFlag; // Thread detach property: 0 do not detach 1 detach + INT32 priorityFlag; + INT32 priority; + INT32 policyFlag; + INT32 policy; + INT32 stackFlag; + UINT32 stackSize; +} mmThreadAttr; + +typedef VOID (*mmPf)(VOID); + +#define mm_no_argument 0 +#define mm_required_argument 1 +#define mm_optional_argument 2 + +#define M_FILE_RDONLY GENERIC_READ +#define M_FILE_WRONLY GENERIC_WRITE +#define M_FILE_RDWR (GENERIC_READ | GENERIC_WRITE) +#define M_FILE_CREAT OPEN_ALWAYS + +#define M_RDONLY _O_RDONLY +#define M_WRONLY _O_WRONLY +#define M_RDWR _O_RDWR +#define M_IRWXU _O_RDWR +#define M_CREAT _O_CREAT +#define M_BINARY _O_BINARY +#define M_TRUNC _O_TRUNC +#define M_APPEND _O_APPEND + +#define M_IREAD _S_IREAD +#define M_IRUSR _S_IREAD +#define M_IWRITE _S_IWRITE +#define M_IWUSR _S_IWRITE +#define M_IXUSR 0 + +#define M_IN_CREATE FILE_NOTIFY_CHANGE_FILE_NAME | FILE_NOTIFY_CHANGE_DIR_NAME +#define M_IN_CLOSE_WRITE FILE_NOTIFY_CHANGE_LAST_WRITE +#define M_IN_IGNORED FILE_NOTIFY_CHANGE_FILE_NAME | FILE_NOTIFY_CHANGE_DIR_NAME + +#define M_OUT_CREATE 0x00000100 +#define M_OUT_CLOSE_WRITE 0x00000008 +#define M_OUT_IGNORED 0x00008000 +#define M_OUT_ISDIR 0x40000000 + +#define M_MSG_CREAT 1 +#define M_MSG_EXCL 2 +#define M_MSG_NOWAIT 3 + +#define M_WAIT_NOHANG 1 +#define M_WAIT_UNTRACED 2 + +#define M_UMASK_USRREAD _S_IREAD +#define M_UMASK_GRPREAD _S_IREAD +#define M_UMASK_OTHREAD _S_IREAD + +#define M_UMASK_USRWRITE _S_IWRITE +#define M_UMASK_GRPWRITE _S_IWRITE +#define M_UMASK_OTHWRITE _S_IWRITE + +#define M_UMASK_USREXEC 0 +#define M_UMASK_GRPEXEC 0 +#define M_UMASK_OTHEXEC 0 + +#define DT_UNKNOWN 0 +#define DT_FIFO 1 +#define DT_CHR 2 +#define DT_BLK 6 +#define DT_REG 8 +#define DT_LNK 10 +#define DT_SOCK 12 +#define DT_WHT 14 +#define MM_DT_DIR 16 +#define MM_DT_REG 32 + +#define mmConstructor(x) __declspec(allocate(".CRT$XCU")) mmPf con = x +#define mmDestructor(x) __declspec(allocate(".CRT$XPU")) mmPf de = x + +#define MMPA_PRINT_ERROR ((opterr) && (*options != ':')) +#define MMPA_FLAG_PERMUTE 0x01 // permute non-options to the end of argv +#define MMPA_FLAG_ALLARGS 0x02 // treat non-options as args to option "-1" +#define MMPA_FLAG_LONGONLY 0x04 // operate as getopt_long_only +// return values +#define MMPA_BADCH (INT32)'?' +#define MMPA_BADARG ((*options == ':') ? (INT32)':' : (INT32)'?') +#define MMPA_INORDER (INT32)1 + +#define MMPA_NO_ARGUMENT 0 +#define MMPA_REQUIRED_ARGUMENT 1 +#define MMPA_OPTIONAL_ARGUMENT 2 + +#define MMPA_EMSG "" +#define MMPA_MAX_PATH MAX_PATH +#define M_NAME_MAX _MAX_FNAME + +#define M_F_OK 0 +#define M_W_OK 2 +#define M_R_OK 4 + +#define MMPA_STDIN stdin +#define MMPA_STDOUT stdout +#define MMPA_STDERR stderr + +#define MMPA_RTLD_NOW 0 +#define MMPA_RTLD_GLOBAL 0 +#define MMPA_RTLD_LAZY 0 +#define MMPA_RTLD_NODELETE 0 + +#define MMPA_DL_EXT_NAME ".dll" + +#define __attribute__(v) + +MMPA_FUNC_VISIBILITY INT32 mmCreateTask(mmThread *threadHandle, mmUserBlock_t *funcBlock); +MMPA_FUNC_VISIBILITY INT32 mmJoinTask(mmThread *threadHandle); +MMPA_FUNC_VISIBILITY INT32 mmMutexInit(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexLock(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexTryLock(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexUnLock(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexDestroy(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondInit(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmCondLockInit(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondLock(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondUnLock(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondLockDestroy(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmRWLockInit(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockRDLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockTryRDLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockWRLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockTryWRLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRDLockUnLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmWRLockUnLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockDestroy(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmCondWait(mmCond *cond, mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondTimedWait(mmCond *cond, mmMutexFC *mutex, UINT32 milliSecond); + +MMPA_FUNC_VISIBILITY INT32 mmCondNotify(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmCondNotifyAll(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmCondDestroy(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmGetPid(VOID); +MMPA_FUNC_VISIBILITY INT32 mmGetTid(VOID); +MMPA_FUNC_VISIBILITY INT32 mmGetPidHandle(mmProcess *processHandle); +MMPA_FUNC_VISIBILITY INT32 mmGetLocalTime(mmSystemTime_t *sysTime); +MMPA_FUNC_VISIBILITY INT32 mmGetSystemTime(mmSystemTime_t *sysTime); +MMPA_FUNC_VISIBILITY INT32 mmSemInit(mmSem_t *sem, UINT32 value); +MMPA_FUNC_VISIBILITY INT32 mmSemWait(mmSem_t *sem); +MMPA_FUNC_VISIBILITY INT32 mmSemPost(mmSem_t *sem); +MMPA_FUNC_VISIBILITY INT32 mmSemDestroy(mmSem_t *sem); +MMPA_FUNC_VISIBILITY INT32 mmOpen(const CHAR *pathName, INT32 flags); +MMPA_FUNC_VISIBILITY INT32 mmOpen2(const CHAR *pathName, INT32 flags, MODE mode); +MMPA_FUNC_VISIBILITY FILE *mmPopen(CHAR *command, CHAR *type); +MMPA_FUNC_VISIBILITY INT32 mmClose(INT32 fd); +MMPA_FUNC_VISIBILITY INT32 mmPclose(FILE *stream); +MMPA_FUNC_VISIBILITY mmSsize_t mmWrite(INT32 fd, VOID *buf, UINT32 bufLen); +MMPA_FUNC_VISIBILITY mmSsize_t mmRead(INT32 fd, VOID *buf, UINT32 bufLen); +MMPA_FUNC_VISIBILITY mmSockHandle mmSocket(INT32 sockFamily, INT32 type, INT32 protocol); +MMPA_FUNC_VISIBILITY INT32 mmBind(mmSockHandle sockFd, mmSockAddr *addr, mmSocklen_t addrLen); +MMPA_FUNC_VISIBILITY INT32 mmListen(mmSockHandle sockFd, INT32 backLog); +MMPA_FUNC_VISIBILITY mmSockHandle mmAccept(mmSockHandle sockFd, mmSockAddr *addr, mmSocklen_t *addrLen); +MMPA_FUNC_VISIBILITY INT32 mmConnect(mmSockHandle sockFd, mmSockAddr *addr, mmSocklen_t addrLen); +MMPA_FUNC_VISIBILITY INT32 mmCloseSocket(mmSockHandle sockFd); +MMPA_FUNC_VISIBILITY mmSsize_t mmSocketRecv(mmSockHandle sockFd, VOID *recvBuf, INT32 recvLen, INT32 recvFlag); +MMPA_FUNC_VISIBILITY mmSsize_t mmSocketSend(mmSockHandle sockFd, VOID *sendBuf, INT32 sendLen, INT32 sendFlag); +MMPA_FUNC_VISIBILITY INT32 mmSocketSendTo(mmSockHandle sockFd, + VOID *sendMsg, + INT32 sendLen, + UINT32 sendFlag, + const mmSockAddr* addr, + INT32 tolen); +MMPA_FUNC_VISIBILITY mmSsize_t mmSocketRecvFrom(mmSockHandle sockFd, + VOID *recvBuf, + mmSize recvLen, + UINT32 recvFlag, + mmSockAddr* addr, + mmSocklen_t *FromLen); +MMPA_FUNC_VISIBILITY INT32 mmSAStartup(VOID); +MMPA_FUNC_VISIBILITY INT32 mmSACleanup(VOID); +MMPA_FUNC_VISIBILITY VOID *mmDlopen(const CHAR *fileName, INT mode); +MMPA_FUNC_VISIBILITY INT32 mmDladdr(VOID *addr, mmDlInfo *info); +MMPA_FUNC_VISIBILITY VOID *mmDlsym(VOID *handle, const CHAR *fileName); +MMPA_FUNC_VISIBILITY INT32 mmDlclose(VOID *handle); +MMPA_FUNC_VISIBILITY CHAR *mmDlerror(VOID); +MMPA_FUNC_VISIBILITY INT32 + mmCreateAndSetTimer(mmTimer *timerHandle, mmUserBlock_t *timerBlock, UINT milliSecond, UINT period); +MMPA_FUNC_VISIBILITY INT32 mmDeleteTimer(mmTimer timerHandle); +MMPA_FUNC_VISIBILITY INT32 mmStatGet(const CHAR *path, mmStat_t *buffer); +MMPA_FUNC_VISIBILITY INT32 mmStat64Get(const CHAR *path, mmStat64_t *buffer); +MMPA_FUNC_VISIBILITY INT32 mmFStatGet(INT32 fd, mmStat_t *buffer); +MMPA_FUNC_VISIBILITY INT32 mmMkdir(const CHAR *pathName, mmMode_t mode); +MMPA_FUNC_VISIBILITY INT32 mmSleep(UINT32 milliSecond); +MMPA_FUNC_VISIBILITY INT32 mmCreateTaskWithAttr(mmThread *threadHandle, mmUserBlock_t *funcBlock); +MMPA_FUNC_VISIBILITY INT32 mmGetProcessPrio(mmProcess pid); +MMPA_FUNC_VISIBILITY INT32 mmSetProcessPrio(mmProcess pid, INT32 processPrio); +MMPA_FUNC_VISIBILITY INT32 mmGetThreadPrio(mmThread *threadHandle); +MMPA_FUNC_VISIBILITY INT32 mmSetThreadPrio(mmThread *threadHandle, INT32 threadPrio); +MMPA_FUNC_VISIBILITY INT32 mmAccess(const CHAR *pathName); +MMPA_FUNC_VISIBILITY INT32 mmAccess2(const CHAR *pathName, INT32 mode); +MMPA_FUNC_VISIBILITY INT32 mmRmdir(const CHAR *pathName); + +MMPA_FUNC_VISIBILITY INT32 mmIoctl(mmProcess fd, INT32 ioctlCode, mmIoctlBuf *bufPtr); +MMPA_FUNC_VISIBILITY INT32 mmSemTimedWait(mmSem_t *sem, INT32 timeout); +MMPA_FUNC_VISIBILITY mmSsize_t mmWritev(mmSockHandle fd, mmIovSegment *iov, INT32 iovcnt); +MMPA_FUNC_VISIBILITY VOID mmMb(); +MMPA_FUNC_VISIBILITY INT32 mmInetAton(const CHAR *addrStr, mmInAddr *addr); + +MMPA_FUNC_VISIBILITY mmProcess mmOpenFile(const CHAR *fileName, UINT32 access, mmCreateFlag fileFlag); +MMPA_FUNC_VISIBILITY mmSsize_t mmReadFile(mmProcess fileId, VOID *buffer, INT32 len); +MMPA_FUNC_VISIBILITY mmSsize_t mmWriteFile(mmProcess fileId, VOID *buffer, INT32 len); +MMPA_FUNC_VISIBILITY INT32 mmCloseFile(mmProcess fileId); + +MMPA_FUNC_VISIBILITY mmAtomicType mmSetData(mmAtomicType *ptr, mmAtomicType value); +MMPA_FUNC_VISIBILITY mmAtomicType mmValueInc(mmAtomicType *ptr, mmAtomicType value); +MMPA_FUNC_VISIBILITY mmAtomicType mmValueSub(mmAtomicType *ptr, mmAtomicType value); +MMPA_FUNC_VISIBILITY mmAtomicType64 mmSetData64(mmAtomicType64 *ptr, mmAtomicType64 value); +MMPA_FUNC_VISIBILITY mmAtomicType64 mmValueInc64(mmAtomicType64 *ptr, mmAtomicType64 value); +MMPA_FUNC_VISIBILITY mmAtomicType64 mmValueSub64(mmAtomicType64 *ptr, mmAtomicType64 value); +MMPA_FUNC_VISIBILITY INT32 mmCreateTaskWithDetach(mmThread *threadHandle, mmUserBlock_t *funcBlock); + +MMPA_FUNC_VISIBILITY INT32 mmCreateNamedPipe(mmPipeHandle pipe[], CHAR *pipeName[], INT32 waitMode); +MMPA_FUNC_VISIBILITY INT32 mmOpenNamePipe(mmPipeHandle pipe[], CHAR *pipeName[], INT32 waitMode); +MMPA_FUNC_VISIBILITY VOID mmCloseNamedPipe(mmPipeHandle namedPipe[]); + +MMPA_FUNC_VISIBILITY INT32 mmCreatePipe(mmPipeHandle pipe[], CHAR *pipeName[], UINT32 pipeCount, INT32 waitMode); +MMPA_FUNC_VISIBILITY INT32 mmOpenPipe(mmPipeHandle pipe[], CHAR *pipeName[], UINT32 pipeCount, INT32 waitMode); +MMPA_FUNC_VISIBILITY VOID mmClosePipe(mmPipeHandle pipe[], UINT32 pipeCount); + +MMPA_FUNC_VISIBILITY mmCompletionHandle mmCreateCompletionPort(); +MMPA_FUNC_VISIBILITY VOID mmCloseCompletionPort(mmCompletionHandle handle); +MMPA_FUNC_VISIBILITY INT32 mmPoll(mmPollfd *fds, INT32 fdCount, INT32 timeout, mmCompletionHandle handleIOCP, + pmmPollData polledData, mmPollBack pollBack); + +MMPA_FUNC_VISIBILITY INT32 mmGetErrorCode(); +MMPA_FUNC_VISIBILITY CHAR *mmGetErrorFormatMessage(mmErrorMsg errnum, CHAR *buf, mmSize size); +MMPA_FUNC_VISIBILITY INT32 mmGetTimeOfDay(mmTimeval *timeVal, mmTimezone *timeZone); +MMPA_FUNC_VISIBILITY mmTimespec mmGetTickCount(); +MMPA_FUNC_VISIBILITY INT32 mmGetRealPath(CHAR *path, CHAR *realPath); + +MMPA_FUNC_VISIBILITY INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen); + +MMPA_FUNC_VISIBILITY INT32 mmDup2(INT32 oldFd, INT32 newFd); +MMPA_FUNC_VISIBILITY INT32 mmDup(INT32 fd); +MMPA_FUNC_VISIBILITY INT32 mmUnlink(const CHAR *filename); +MMPA_FUNC_VISIBILITY INT32 mmChmod(const CHAR *filename, INT32 mode); +MMPA_FUNC_VISIBILITY INT32 mmFileno(FILE *stream); +MMPA_FUNC_VISIBILITY INT32 mmScandir(const CHAR *path, mmDirent ***entryList, mmFilter filterFunc, mmSort sort); +MMPA_FUNC_VISIBILITY INT32 mmScandir2(const CHAR *path, mmDirent2 ***entryList, mmFilter2 filterFunc, mmSort2 sort); +MMPA_FUNC_VISIBILITY VOID mmScandirFree(mmDirent **entryList, INT32 count); +MMPA_FUNC_VISIBILITY VOID mmScandirFree2(mmDirent2 **entryList, INT32 count); + +MMPA_FUNC_VISIBILITY mmMsgid mmMsgCreate(mmKey_t key, INT32 msgFlag); +MMPA_FUNC_VISIBILITY mmMsgid mmMsgOpen(mmKey_t key, INT32 msgFlag); +MMPA_FUNC_VISIBILITY INT32 mmMsgRcv(mmMsgid msqid, VOID *buf, INT32 bufLen, INT32 msgFlag); +MMPA_FUNC_VISIBILITY INT32 mmMsgSnd(mmMsgid msqid, VOID *buf, INT32 bufLen, INT32 msgFlag); + +MMPA_FUNC_VISIBILITY INT32 mmMsgClose(mmMsgid msqid); + +MMPA_FUNC_VISIBILITY INT32 mmLocalTimeR(const time_t *timep, struct tm *result); +MMPA_FUNC_VISIBILITY INT32 mmGetOptErr(); +MMPA_FUNC_VISIBILITY VOID mmSetOptErr(INT32 mmOptErr); +MMPA_FUNC_VISIBILITY INT32 mmGetOptInd(); +MMPA_FUNC_VISIBILITY VOID mmSetOptInd(INT32 mmOptInd); +MMPA_FUNC_VISIBILITY INT32 mmGetOptOpt(); +MMPA_FUNC_VISIBILITY VOID mmSetOpOpt(INT32 mmOptOpt); +MMPA_FUNC_VISIBILITY CHAR *mmGetOptArg(); +MMPA_FUNC_VISIBILITY VOID mmSetOptArg(CHAR *mmOptArg); +MMPA_FUNC_VISIBILITY INT32 mmGetOpt(INT32 argc, char *const *argv, const char *opts); +MMPA_FUNC_VISIBILITY INT32 + mmGetOptLong(INT32 argc, CHAR *const *argv, const CHAR *opts, const mmStructOption *longopts, INT32 *longindex); + +MMPA_FUNC_VISIBILITY LONG mmLseek(INT32 fd, INT64 offset, INT32 seekFlag); +MMPA_FUNC_VISIBILITY INT32 mmFtruncate(mmProcess fd, UINT32 length); + +MMPA_FUNC_VISIBILITY INT32 mmTlsCreate(mmThreadKey *key, VOID (*destructor)(VOID *)); +MMPA_FUNC_VISIBILITY INT32 mmTlsSet(mmThreadKey key, const VOID *value); +MMPA_FUNC_VISIBILITY VOID *mmTlsGet(mmThreadKey key); +MMPA_FUNC_VISIBILITY INT32 mmTlsDelete(mmThreadKey key); +MMPA_FUNC_VISIBILITY INT32 mmGetOsType(); + +MMPA_FUNC_VISIBILITY INT32 mmFsync(mmProcess fd); +MMPA_FUNC_VISIBILITY INT32 mmFsync2(INT32 fd); +MMPA_FUNC_VISIBILITY INT32 mmChdir(const CHAR *path); +MMPA_FUNC_VISIBILITY INT32 mmUmask(INT32 pmode); +MMPA_FUNC_VISIBILITY INT32 mmWaitPid(mmProcess pid, INT32 *status, INT32 options); + +MMPA_FUNC_VISIBILITY INT32 mmGetCwd(CHAR *buffer, INT32 maxLen); +MMPA_FUNC_VISIBILITY CHAR *mmStrTokR(CHAR *str, const CHAR *delim, CHAR **saveptr); + +MMPA_FUNC_VISIBILITY INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len); +MMPA_FUNC_VISIBILITY INT32 mmSetEnv(const CHAR *name, const CHAR *value, INT32 overwrite); +MMPA_FUNC_VISIBILITY CHAR *mmDirName(CHAR *path); +MMPA_FUNC_VISIBILITY CHAR *mmBaseName(CHAR *path); +MMPA_FUNC_VISIBILITY INT32 mmGetDiskFreeSpace(const char *path, mmDiskSize *diskSize); + +MMPA_FUNC_VISIBILITY INT32 mmSetThreadName(mmThread *threadHandle, const CHAR *name); +MMPA_FUNC_VISIBILITY INT32 mmGetThreadName(mmThread *threadHandle, CHAR *name, INT32 size); + +/* + * Function: set the thread name of the currently executing thread - internal call of thread, which is not supported + * under Windows temporarily, and is null. + * Input: name: the thread name to be set + * The input parameter error returns EN_INVALID_PARAM, the execution success returns EN_OK, and the + * execution failure returns EN_ERROR + */ +MMPA_FUNC_VISIBILITY INT32 mmSetCurrentThreadName(const CHAR *name); + +/* + * Function: Get the thread name of the currently executing thread - thread body call, not supported under windows, null + * implementation. + * Input:name:The name of the thread to get, and the cache is allocated by the user,size>=MMPA_THREADNAME_SIZE. + * The input parameter error returns EN_INVALID_PARAM, the execution success returns + * EN_OK, and the execution failure returns EN_ERROR + */ +MMPA_FUNC_VISIBILITY INT32 mmGetCurrentThreadName(CHAR *name, INT32 size); + +MMPA_FUNC_VISIBILITY INT32 mmGetFileSize(const CHAR *fileName, ULONGLONG *length); +MMPA_FUNC_VISIBILITY INT32 mmIsDir(const CHAR *fileName); +MMPA_FUNC_VISIBILITY INT32 mmGetOsName(CHAR *name, INT32 nameSize); +MMPA_FUNC_VISIBILITY INT32 mmGetOsVersion(CHAR *versionInfo, INT32 versionLength); +MMPA_FUNC_VISIBILITY INT32 mmGetMac(mmMacInfo **list, INT32 *count); +MMPA_FUNC_VISIBILITY INT32 mmGetMacFree(mmMacInfo *list, INT32 count); +MMPA_FUNC_VISIBILITY INT32 mmGetCpuInfo(mmCpuDesc **cpuInfo, INT32 *count); +MMPA_FUNC_VISIBILITY INT32 mmCpuInfoFree(mmCpuDesc *cpuInfo, INT32 count); +MMPA_FUNC_VISIBILITY INT32 + mmCreateProcess(const CHAR *fileName, const mmArgvEnv *env, const char *stdoutRedirectFile, mmProcess *id); + +MMPA_FUNC_VISIBILITY INT32 + mmCreateTaskWithThreadAttr(mmThread *threadHandle, const mmUserBlock_t *funcBlock, const mmThreadAttr *threadAttr); +MMPA_FUNC_VISIBILITY mmFileHandle mmShmOpen(const CHAR *name, INT32 oflag, mmMode_t mode); +MMPA_FUNC_VISIBILITY INT32 mmShmUnlink(const CHAR *name); +MMPA_FUNC_VISIBILITY VOID *mmMmap(mmFd_t fd, mmSize_t size, mmOfft_t offset, mmFd_t *extra, INT32 prot, INT32 flags); +MMPA_FUNC_VISIBILITY INT32 mmMunMap(VOID *data, mmSize_t size, mmFd_t *extra); +#ifdef __cplusplus +#if __cplusplus +} +#endif /* __cpluscplus */ +#endif // __cpluscplus + +#endif // MMPA_WIN_MMPA_WIN_H_ diff --git a/metadef/third_party/fwkacllib/inc/ops/array_ops.h b/metadef/third_party/fwkacllib/inc/ops/array_ops.h new file mode 100644 index 00000000..691b51f6 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/ops/array_ops.h @@ -0,0 +1,1158 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file array_ops.h + * \brief + */ +#ifndef OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ +#define OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ + +#include "graph/operator_reg.h" +#include "graph/operator.h" + +namespace ge { + +/** +*@brief Applies lower_bound(sorted_search_values, values) along each row. \n + +*@par Inputs: +*The input sorted_x and values can be one-dimensional vector. Inputs include: +* @li sorted_x:A `Tensor`. 2-D Tensor where each row is ordered. +* @li values:A `Tensor`. Must have the same type as `sorted_x`. \n + +*@par Attributes: +*@li out_type:An optional `DType` from: `int32, int64`. +Defaults to `int32`. \n + +*@par Outputs: +*y: A `Tensor` of type `out_type`. \n + +*@attention Constraints: +*The implementation for LowerBound on Ascend uses AI CPU, with bad performance. \n + +*@par Quantization supported or not +*Not supported +*@par Quantized inference supported or not +*Supported +*@par L2 convergence supported or not +*@par Multiple batches supported or not \n + +*@par Third-party framework compatibility +*Compatible with tensorflow Operator LowerBound. +*/ + +REG_OP(LowerBound) + .INPUT(sorted_x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, \ + DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .INPUT(values, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, \ + DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .ATTR(out_type, Type, DT_INT32) + .OP_END_FACTORY_REG(LowerBound) + +/** +*@brief Reverses variable length slices. \n + +*@par Inputs: +*Input "x" is a k-dimensional tensor. Inputs "num_lower" and "num_upper" +are 0D scalars. +* @li x: A Tensor. The input to reverse. +* @li seq_lengths: A 1D Tensor of type int32 or int64. \n + +*@par Attributes: +*@li seq_dim: An optional int. The dimension along which +reversal is performed. +*@li batch_dim: An optional int. Defaults to "0". The dimension along which +reversal is performed. \n + +*@par Outputs: +*y: A rank k tensor. Has the same shape as input. The extracted banded tensor. \n + +*@attention Constraints: +*ReverseSequence runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator ReverseSequence. +*/ + +REG_OP(ReverseSequence) + .INPUT(x, + TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128})) + .INPUT(seq_lengths, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, + TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128})) + .REQUIRED_ATTR(seq_dim, Int) + .ATTR(batch_dim, Int, 0) + .OP_END_FACTORY_REG(ReverseSequence) + +/** +*@brief Copies a tensor setting everything outside a central band in each innermost matrix. \n + +*@par Inputs: +*Input "x" is a k-dimensional tensor. Inputs "num_lower" and "num_upper" +are 0D scalars. +* @li x: A rank k tensor. +* @li num_lower: A 0D tensor. Number of superdiagonals to keep. If negative, +keeps entire upper triangle. +* @li num_upper: A 0D tensor. Number of superdiagonals to keep. If negative, +keeps entire upper triangle. \n + +*@par Outputs: +*y: A rank k tensor. Has the same shape as input. The extracted banded tensor. \n + +*@attention Constraints: +*MatrixBandPart runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator MatrixBandPart. +*/ + +REG_OP(MatrixBandPart) + .INPUT(x, TensorType({ DT_INT8, DT_UINT8, \ + DT_INT16, DT_UINT16, DT_INT32, DT_INT64, + DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL, + DT_COMPLEX64, DT_COMPLEX128 })) + .INPUT(num_lower, TensorType({ DT_INT32, DT_INT64 })) + .INPUT(num_upper, TensorType({ DT_INT32, DT_INT64 })) + .OUTPUT(y, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ + DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL, + DT_COMPLEX64, DT_COMPLEX128})) + .OP_END_FACTORY_REG(MatrixBandPart) + +/** +*@brief Finds unique elements in a 1D tensor. \n + +*@par Inputs: +*x: 1D tensor. +*Input "x" is a k-dimensional tensor. Inputs "num_lower" and "num_upper" +are 0D scalars. \n + +*@par Attributes: +*out_idx: An optional DType from: "int32, int64". +Defaults to "int32". \n + +*@par Outputs: +*@li y: A Tensor. Has the same type as "x". +*@li idx: A Tensor of type "out_idx". +*@li count: A Tensor of type "out_idx". \n + +*@attention Constraints: +*UniqueWithCounts runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator UniqueWithCounts. +*/ + +REG_OP(UniqueWithCounts) + .INPUT(x, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ + DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_STRING })) + .OUTPUT(y, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ + DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_STRING })) + .OUTPUT(idx, TensorType({ DT_INT32, DT_INT64 })) + .OUTPUT(count, TensorType({ DT_INT32, DT_INT64 })) + .REQUIRED_ATTR(out_idx, Type) + .OP_END_FACTORY_REG(UniqueWithCounts) + +/** +*@brief Finds unique elements in a 1D tensor. \n + +*@par Inputs: +*x: 1D tensor. +*Input "x" is a k-dimensional tensor. Inputs "num_lower" and "num_upper" +are 0D scalars. \n + +*@par Attributes: +*out_idx: An optional DType from: "int32, int64". Defaults to "int32". \n + +*@par Outputs: +*@li y: "x" in the unique output "y". +*@li idx: A tensor the same size as "x". The index of each value of "x". \n + +*@attention Constraints: +*Unique runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Unique. +*/ + +REG_OP(Unique) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .OUTPUT(idx, TensorType({DT_INT32, DT_INT64})) + .ATTR(out_idx, Type, DT_INT32) + .OP_END_FACTORY_REG(Unique) + +/** +*@brief Finds unique elements in a 1D tensor. \n + +*@par Inputs: +*Input "x" is a k-dimensional tensor. Inputs "num_lower" and "num_upper" +are 0D scalars. +*Including: +* @li x: 1D tensor. +* @li axis: A Tensor of type int32. Defaults to "None". \n + +*@par Attributes: +*out_idx: An optional DType from: "int32, int64". +Defaults to "int32". \n + +*@par Outputs: +*@li y: "x" in the unique output "y". +*@li idx: A tensor the same size as "x". The index of each value of "x". \n + +*@attention Constraints: +*UniqueExt2 runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator UniqueExt2. +*/ + +REG_OP(UniqueExt2) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .INPUT(axis, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .OUTPUT(idx, TensorType({DT_INT32, DT_INT64})) + .ATTR(out_idx, Type, DT_INT32) + .OP_END_FACTORY_REG(UniqueExt2) + +/** +*@brief Computes the inverse permutation of a tensor. \n + +*@par Inputs: +*x: A k-dimensional tensor. \n + +*@par Outputs: +*y: A 1D tensor. \n + +*@attention Constraints: +*InvertPermutation runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator InvertPermutation. +*/ + +REG_OP(InvertPermutation) + .INPUT(x, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(InvertPermutation) + +/** +*@brief Checks a tensor for NaN and Inf values. \n + +*@par Inputs: +*x: A k-dimensional tensor. \n + +*@par Attributes: +*message: Prefix of the error message. \n + +*@par Outputs: +*y: The output tensor. \n + +*@attention Constraints: +*CheckNumerics runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator CheckNumerics. +*/ + +REG_OP(CheckNumerics) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .REQUIRED_ATTR(message, String) + .OP_END_FACTORY_REG(CheckNumerics) + +/** +*@brief Converts an array of flat indices into a tuple of coordinate arrays. \n + +*@par Inputs: +*Input "indices" is a 0D or 1D tensor. Input "dims" is a 1D tensor. +* @li indices: A 0D or 1D int Tensor whose elements are indices into +the flattened version of an array of dimensions "dims". +* @li dims: A 1D int Tensor of the same type as "indices". +*The shape of the array to use for unraveling indices. \n + +*@par Outputs: +*y: A Tensor. Has the same type as "indices". \n + +*@attention Constraints: +*UnravelIndex runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator UnravelIndex. +*/ + +REG_OP(UnravelIndex) + .INPUT(indices, TensorType({DT_INT32, DT_INT64})) + .INPUT(dims, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(UnravelIndex) + +/** +*@brief Applies upper_bound(sorted_search_values, values) along each row. \n + +*@par Inputs: +*Inputs "sorted_x" and "values" are 2D tensors. +* @li sorted_x: A 2D Tensor where each row is ordered. +* @li values: A 2D Tensor with the same numbers of rows as "sorted_x. \n + +*@par Attributes: +*out_type: sets the optional out_type attribute to value. \n + +*@par Outputs: +*y: A Tensor with the same shape as "values". \n + +*@attention Constraints: +*UpperBound runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator UpperBound. +*/ + +REG_OP(UpperBound) + .INPUT(sorted_x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .INPUT(values, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .REQUIRED_ATTR(out_type, Type) + .OP_END_FACTORY_REG(UpperBound) + +/** +*@brief Finds unique elements in a 1D tensor. \n + +*@par Inputs: +*Inputs "x" and "axis" are 1D vectors. +* @li x: A 1D tensor. +* @li axis: A 1D tensor. \n + +*@par Attributes: +*out_idx: An optional DType from: "int32, int64". +Defaults to "int32". \n + +*@par Outputs: +*@li y: "x" in the unique output "y". +*@li idx: A tensor the same size as "x". The index of each value of "x". +*@li count: A tensor the same size as "x". The index of each value of "x". \n + +*@attention Constraints: +*UniqueWithCountsExt2 runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator UniqueWithCountsExt2. +*/ + +REG_OP(UniqueWithCountsExt2) + .INPUT(x, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ + DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_STRING })) + .INPUT(axis, TensorType({ DT_INT32, DT_INT64 })) + .OUTPUT(y, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ + DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_STRING })) + .OUTPUT(idx, TensorType({ DT_INT32, DT_INT64 })) + .OUTPUT(count, TensorType({ DT_INT32, DT_INT64 })) + .REQUIRED_ATTR(out_idx, Type) + .OP_END_FACTORY_REG(UniqueWithCountsExt2) + +/** +*@brief Fills the tensor with the mirror value. \n + +*@par Inputs: +*Inputs "x" and "paddings" are 1D scalars. +* @li x: The tensor to be padded. +* @li paddings: A two-column matrix specifying the padding sizes. +The number of rows Has the same rank as "x". \n + +*@par Attributes: +*mode: Either "REFLECT" or "SYMMETRIC". In reflect mode the padded regions +do not include the borders, while in symmetric mode the padded regions +do include the borders. \n + +*@par Outputs: +*y: The padded tensor. \n + +*@attention Constraints: +*MirrorPad runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator MirrorPad. +*/ + +REG_OP(MirrorPad) + .INPUT(x, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ + DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL, \ + DT_COMPLEX64, DT_COMPLEX128 })) + .INPUT(paddings, TensorType({ DT_INT32, DT_INT64 })) + .OUTPUT(y, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ + DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL, \ + DT_COMPLEX64, DT_COMPLEX128 })) + .REQUIRED_ATTR(mode, String) + .OP_END_FACTORY_REG(MirrorPad) + +/** +*@brief Calculates the difference between two numbers or a list of strings. \n + +*@par Inputs: +*Inputs "x" and "y" are 1D vectors. +* @li x: A Tensor. 1D. Values to keep. +* @li y: A Tensor. Must have the same type as x. 1D. Values to remove. \n + +*@par Attributes: +*out_idx: An optional DType from: "int32, int64". Defaults to "int32". \n + +*@par Outputs: +*@li out: A Tensor. Has the same type as "x". +*@li idx: A Tensor of type "out_idx". \n + +*@attention Constraints: +*ListDiff runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator ListDiff. +*/ + +REG_OP(ListDiff) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8, + DT_INT16, DT_UINT16, DT_INT32, DT_INT64})) + .INPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8, + DT_INT16, DT_UINT16, DT_INT32, DT_INT64})) + .OUTPUT(out, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8, + DT_INT16, DT_UINT16, DT_INT32, DT_INT64})) + .OUTPUT(idx, TensorType({DT_INT32, DT_INT64})) + .ATTR(out_idx, Type, DT_INT32) + .OP_END_FACTORY_REG(ListDiff) + +/** +*@brief Create an empty tensor, using the shape and dtype specified in attributes. \n + +*@par Attributes: +*@li dtype: Specify the data type of the empty tensor. +*@li shape: Specify the shape of the empty tensor. \n + +*@par Outputs: +*y: The empty constant tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator _ParallelConcatStart. +*/ +REG_OP(_ParallelConcatStart) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(dtype, Type, DT_INT32) + .ATTR(shape, ListInt, {}) + .OP_END_FACTORY_REG(_ParallelConcatStart) + +/** +*@brief Creates a constant tensor from a tensor-like object. This operator is used for inference. +Operator Const has the same definition as operator Constant. \n + +*@par Attributes: +*value: Required. The value and type of the resulting tensor, and no restrictions on type. \n + +*@par Outputs: +*y: A constant tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Const. +*/ +REG_OP(Const) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(value, Tensor, Tensor()) + .OP_END_FACTORY_REG(Const) + +/** +*@brief Creates a constant tensor for training. \n + +*@par Attributes: +*value: Required. The value and type of the resulting tensor, and no restrictions on type. \n + +*@par Outputs: +*y: The constant tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Const. +*/ +REG_OP(Constant) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(value, Tensor, Tensor()) + .OP_END_FACTORY_REG(Constant) + +/** +*@brief Returns a copy of the input tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Snapshot. +*/ +REG_OP(Snapshot) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(Snapshot) + +/** +*@brief Gives a guarantee to the runtime that the input tensor is a constant. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: The input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator GuaranteeConst. +*/ +REG_OP(GuaranteeConst) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(GuaranteeConst) + +/** +*@brief Returns the target shape for broadcasting shapes "x1" and "x2". \n + +*@par Inputs: +*@li x1: A tensor of type int32 or int64. A shape. +*@li x2: A tensor of the same type as "x1". The other shape. \n + +*@par Outputs: +*y: A tensor. The broadcasted shape. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator BroadcastArgs. +*/ +REG_OP(BroadcastArgs) + .INPUT(x1, TensorType({DT_INT32, DT_INT64})) + .INPUT(x2, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(BroadcastArgs) + +/** +*@brief Outputs its input tensor as is and triggers an error if a gradient is requested. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*message: Will be printed in the error at the attempt to request a gradient. \n + +*@par Outputs: +*y: The input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator PreventGradient. +*/ +REG_OP(PreventGradient) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(message, String, "") + .OP_END_FACTORY_REG(PreventGradient) + +/** +*@brief Returns the reduction indices for computing gradients of "x1" and "x2" with broadcast. \n + +*@par Inputs: +*@li x1: A tensor of type int32 or int64. +*@li x2: A tensor of type int32 or int64. +"x2" has the same type as "x1". \n + +*@par Outputs: +*@li y1: A tensor. Reduction indices of "x1". +*@li y2: A tensor. Reduction indices of "x2". \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator BroadcastGradientArgs. +*/ +REG_OP(BroadcastGradientArgs) + .INPUT(x1, TensorType({DT_INT32, DT_INT64})) + .INPUT(x2, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y1, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y2, TensorType({DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(BroadcastGradientArgs) + +/** +*@brief Stops gradient computation. None is returned for the node where the gradient computation is stopped. + + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: The input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator StopGradient. +*/ +REG_OP(StopGradient) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(StopGradient) + +/** +*@brief Return a tensor with the same shape and contents as input. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Identity. +*/ +REG_OP(Identity) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(Identity) + +/** +*@brief Returns a list of tensors with the same shapes and contents as the input tensors. \n + +*@par Inputs: +*x: A list of input tensors. It's a dynamic input \n + +*@par Outputs: +*y: A list of Tensor objects, with the same length as the input tensor list. +It's a dynamic output. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator IdentityN. +*/ +REG_OP(IdentityN) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(IdentityN) + +/** +*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n + +*@par Inputs: +*@li x: A tensor. +*@li axis: The dimension index at which to expand. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator ExpandDims. +*/ +REG_OP(ExpandDims) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, + DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .INPUT(axis, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, + DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(ExpandDims) + +/** +*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n + +*@par Inputs: +*@li x: Original tensor. +*@li axis: List of ints. \n + +*@par Outputs: +*y: Reshape tensor with same data as input. \n + +*@par Third-party framework compatibility +*Compatible with the Onnx operator Unsqueeze. +*/ + +REG_OP(Unsqueeze) + .INPUT(x, TensorType({DT_FLOAT32, DT_INT32, DT_UINT8, DT_BOOL})) + .OUTPUT(y, TensorType({DT_FLOAT32, DT_INT32, DT_UINT8, DT_BOOL})) + .ATTR(axes, ListInt, {}) + .OP_END_FACTORY_REG(Unsqueeze) + +/** +*@brief Reshapes a tensor. Only the tensor shape is changed, without changing the data. \n + +*@par Inputs: +*@li x: A tensor. +*@li shape: A tensor. Defines the shape of the output tensor. \n + +*@par Attributes: +*@li axis: An optional int32 or int64. The first dimension to reshape. Defaults to "0". +*@li num_axes: An optional int32 or int64. The extent of the reshape. Defaults to "-1". \n + +*@par Outputs: +*y: A tensor. \n + +*@par Attention: +*This operator cannot be directly called by the acllopExecute API. \n + +*@par Third-party framework compatibility +*@li Compatible with the TensorFlow operator Reshape. +*@li Compatible with the Caffe operator Reshape. +*/ +REG_OP(Reshape) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, + DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .INPUT(shape, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, + DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(axis, Int, 0) + .ATTR(num_axes, Int, -1) + .OP_END_FACTORY_REG(Reshape) + +/** +*@brief Removes dimensions of size 1 from the shape of a tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*axis: An optional list of int32 or int64. If not specified, squeezes all dimensions of size 1. If specified, only squeezes the dimensions listed. It is an error to squeeze a dimension that is not 1. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Squeeze. +*/ +REG_OP(Squeeze) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(axis, ListInt, {}) + .OP_END_FACTORY_REG(Squeeze) + +/** +*@brief Returns an integer representing the rank of input tensor. The rank of a tensor is the number of indices required to uniquely select each element of the tensor, that is, the dimension size of the tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: A tensor. The rank of input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Rank. +*/ +REG_OP(Rank) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(Rank) + +/** +*@brief Returns the size of a tensor, that is, an integer of the number of elements of the tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*out_type: An optional int32 or int64. The output data type. Defaults to "int32". \n + +*@par Outputs: +*y: A tensor. The size of the input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Size. +*/ +REG_OP(Size) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT32,DT_INT64})) + .ATTR(dtype, Int, DT_INT32) + .OP_END_FACTORY_REG(Size) + +/** +*@brief Input data for other operators. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*index: Index of the input tensor.The data type must be int32 or int64. +Assume that net has three data nodes, one should be set 0, another should +be set 1, and the left should be set 2. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the Caffe operator Data. +*/ +REG_OP(Data) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(index, Int, 0) + .OP_END_FACTORY_REG(Data) + +/** +*@brief Inserts a placeholder for a tensor that will be always fed. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*@li peerIndex: An integer type. The index of the corresponding "end" node connected to. +*@li parentId: A string, used to check if the nodes are from the saved parent node. +*@li parentOpType: A string. Op type of the original node. +*@li anchorIndex: An integer, used to check if the node is from the saved anchor. \n + +*@par Outputs: +*y: The created placeholder tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator PlaceHolder. +*/ +REG_OP(PlaceHolder) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(peerIndex, Int, 0) // the index of the corresponding 'end' node it's connected to + .ATTR(parentId, String, "") // check if these node are from save parent node + .ATTR(parentOpType, String, "") // op type of original node + .ATTR(anchorIndex, Int, 0) // check if these node are from save anchor + .OP_END_FACTORY_REG(PlaceHolder) + +/** +*@brief Inserts a placeholder with default value for a tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*@li dtype: data type of tensor. +*@li shape: tensor shape. \n + +*@par Outputs: +*y: The created placeholder tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator PlaceholderWithDefault. +*/ +REG_OP(PlaceholderWithDefault) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .REQUIRED_ATTR(shape, ListInt) + .OP_END_FACTORY_REG(PlaceholderWithDefault) + +/** +*@brief Reads and returns the value of the input variable tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator ReadVariableOp. +*/ +REG_OP(ReadVariableOp) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(dtype, Int, DT_INT32) + .OP_END_FACTORY_REG(ReadVariableOp) + +/** +*@brief Mark outputs of one sub graph which partitioned by engine type. + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Attributes: +*@li peerIndex: The index of the corresponding 'placeholder' node it's connected to. +*@li parentOpType: Op type of original node. + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(End) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(peerIndex, Int, 0) + .ATTR(parentOpType, String, "") + .OP_END_FACTORY_REG(End) + +/** +*@brief Operations for writing summary data, for use in analysis and visualization. + +*@par Inputs: +* One input: +*x: Collections of summary data. + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(Summary) + .INPUT(x, TensorType::ALL()) + .OP_END_FACTORY_REG(Summary) + +/** +*@brief Returns the shape of a tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n + +*@par Outputs: +*y: A tensor. The shape of the input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Size. +*/ +REG_OP(Shape) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .ATTR(dtype, Int, DT_INT32) + .OP_END_FACTORY_REG(Shape) + +/** +*@brief Returns shape of tensors. \n + +*@par Inputs: +*x: A list of input tensors. It's a dynamic input. \n + +*@par Attributes: +*dtype: An optional int32 or int64. The output data type. Defaults to "int32". \n + +*@par Outputs: +*y: A list of tensors with the same length as the input list of tensors. +It's a dynamic output. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator ShapeN. +*/ +REG_OP(ShapeN) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .DYNAMIC_OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .ATTR(dtype, Int, DT_INT32) + .OP_END_FACTORY_REG(ShapeN) + +/** +*@brief Creates a tensor with the given "shape" and "dtype". \n + +*@par Inputs: +*shape: The shape of the output tensor. \n + +*@par Attributes: +*@li dtype: Optional. The data type of the output tensor. Defaults to "int32". +*@li init: An optional bool. If true, initializes the returned tensor with the default value of "dtype". Defaults to "false". \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Empty. +*/ +REG_OP(Empty) + .INPUT(shape, TensorType({DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(dtype, Int, DT_INT32) + .ATTR(init, Bool, 0) + .OP_END_FACTORY_REG(Empty) + +/** +*@brief Gradient op for MirrorPad op. Folds a mirror-padded tensor. \n + +*@par Inputs: +*Inputs "x" and "y" are 1D vectors. +* @li x: A Tensor. The input tensor to be folded. +* @li paddings: A Tensor of type int32 or int64. A two-column matrix +specifying the padding sizes. \n + +*@par Attributes: +*mode: A string from: "REFLECT", "SYMMETRIC". The mode used in the MirrorPad op. \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x". \n + +*@attention Constraints: +*MirrorPadGrad runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator MirrorPadGrad. +*/ + +REG_OP(MirrorPadGrad) + .INPUT(x, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ + DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, \ + DT_COMPLEX64, DT_COMPLEX128 })) + .INPUT(paddings, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({ DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ + DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE, \ + DT_COMPLEX64, DT_COMPLEX128 })) + .REQUIRED_ATTR(mode, String) + .OP_END_FACTORY_REG(MirrorPadGrad) + +/** +*@brief Returns locations of nonzero / true values in a tensor. \n + +*@par Inputs: +*Including: +*x: A Tensor. Must be one of the following types: +DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, +DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL. \n + +*@par Outputs: +*y: A Tensor of type DT_INT64. \n + +*@attention Constraints: +*Where runs on the Ascend AI CPU, which delivers poor performance.\n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Where. +*/ + +REG_OP(Where) + .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ + DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL})) + .OUTPUT(y, TensorType({DT_INT64})) + .OP_END_FACTORY_REG(Where) + +/** +*@brief Derived from the Caffe operator Split that splits an input blob to +* multiple output blobs for feeding a blob into multiple output layers. +*The Split node is removed from the graph after the split operation is completed. \n + +*@par Inputs: +*x: A Tensor. Must be one of the following types: +fp16, fp32, int8, uint8, int16, uint16, int32, uint32, int64, uint64. \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x".It's required and the value should equal to output_num. \n + +*@par Attributes: +*@li N: A required int. The parameter will get the number of dynamic outputs. +*/ +REG_OP(Copy) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ + DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64})) + .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ + DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64})) + .REQUIRED_ATTR(N, Int) + .OP_END_FACTORY_REG(Copy); + +/** +*@brief Generates fingerprint values. \n + +*@par Inputs: +*@li data: Must have rank 1 or higher. +*@li method: Fingerprint method used by this op. Currently available method is +`farmhash::fingerprint64`. \n + +*@par Outputs: +y: A two-dimensional `Tensor` of type `tf.uint8`. The first dimension equals to +`data`'s first dimension, and the second dimension size depends on the +fingerprint algorithm. \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow Fingerprint operator. +*/ + +REG_OP(Fingerprint) + .INPUT(data, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ + DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL})) + .INPUT(method, TensorType({DT_STRING})) + .OUTPUT(y, TensorType({DT_UINT8})) + .OP_END_FACTORY_REG(Fingerprint) + +/** +*@brief Change the shape of output according to the attr outShape +* + +*@par Inputs: +*x: A Tensor. \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x".It's required and the value should equal to output_num. \n + +*@par Attributes: +*outShape: The shape of output will be inferred according to the attribute +*/ +REG_OP(TransShape) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(outShape,ListInt ,{}) + .OP_END_FACTORY_REG(TransShape); + +/** +*@brief Computes the (possibly normalized) Levenshtein Edit Distance. \n + +*@par Inputs: +*@li hypothesis_indices: The indices of the hypothesis list SparseTensor. +This is an N x R int64 matrix. +*@li hypothesis_shape: The values of the hypothesis list SparseTensor. +This is an N-length vector. +*@li hypothesis_shape: The shape of the hypothesis list SparseTensor. +This is an R-length vector. +*@li truth_indices: The indices of the truth list SparseTensor. +This is an M x R int64 matrix. +*@li truth_shape: The values of the truth list SparseTensor. +This is an M-length vector. +*@li truth_shape: The shape of the truth list SparseTensor. +This is an R-length vector + +*@par Attributes: +*@li normalize: boolean (if true, edit distances are normalized by length of truth). \n + +*@par Outputs: +*@li output: A dense float tensor with rank R - 1. \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow EditDistance operator. +*/ +REG_OP(EditDistance) + .INPUT(hypothesis_indices, TensorType({DT_INT64})) + .INPUT(hypothesis_values, TensorType::BasicType()) + .INPUT(hypothesis_shape, TensorType({DT_INT64})) + .INPUT(truth_indices, TensorType({DT_INT64})) + .INPUT(truth_values, TensorType::BasicType()) + .INPUT(truth_shape, TensorType({DT_INT64})) + .ATTR(normalize, Bool, true) + .OUTPUT(output, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(EditDistance) + +} // namespace ge + +#endif // OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ diff --git a/metadef/third_party/fwkacllib/inc/runtime/base.h b/metadef/third_party/fwkacllib/inc/runtime/base.h new file mode 100644 index 00000000..b9b2cbe5 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/base.h @@ -0,0 +1,340 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_BASE_H__ +#define __CCE_RUNTIME_BASE_H__ + +#include +#include "toolchain/prof_callback.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +// If you need export the function of this library in Win32 dll, use __declspec(dllexport) +#ifndef RTS_API +#ifdef RTS_DLL_EXPORT +#define RTS_API __declspec(dllexport) +#else +#define RTS_API +#endif +#endif + +typedef int32_t rtError_t; +static const int32_t RT_ERROR_NONE = 0; // success + +/** + * @ingroup dvrt_base + * @brief runtime exception numbers. + */ +typedef enum tagRtExceptionType { + RT_EXCEPTION_NONE = 0, + RT_EXCEPTION_TS_DOWN = 1, + RT_EXCEPTION_TASK_TIMEOUT = 2, + RT_EXCEPTION_TASK_FAILURE = 3, + RT_EXCEPTION_DEV_RUNNING_DOWN = 4, + RT_EXCEPTION_STREAM_ID_FREE_FAILED = 5 +} rtExceptionType; + +/** + * @ingroup dvrt_base + * @brief Switch type. + */ +typedef enum tagRtCondition { + RT_EQUAL = 0, + RT_NOT_EQUAL, + RT_GREATER, + RT_GREATER_OR_EQUAL, + RT_LESS, + RT_LESS_OR_EQUAL +} rtCondition_t; + +/** + * @ingroup dvrt_base + * @brief Data Type of Extensible Switch Task. + */ +typedef enum tagRtSwitchDataType { + RT_SWITCH_INT32 = 0, + RT_SWITCH_INT64 = 1, +} rtSwitchDataType_t; + +typedef enum tagRtStreamFlagType { + RT_HEAD_STREAM = 0, // first stream + RT_INVALID_FLAG = 0xFFFFFFFF, +} rtStreamFlagType_t; + +typedef enum tagRtLimitType { + RT_LIMIT_TYPE_LOW_POWER_TIMEOUT = 0, // timeout for power down , ms +} rtLimitType_t; + +typedef struct rtExceptionInfo { + uint32_t taskid; + uint32_t streamid; + uint32_t tid; + uint32_t deviceid; +} rtExceptionInfo; + +typedef struct rtTaskFailInfo { + uint32_t taskid; + uint32_t streamid; + uint32_t tid; + uint32_t deviceid; + uint32_t retcode; +} rtTaskFailInfo; + +typedef void (*rtErrorCallback)(rtExceptionType); + +typedef void (*rtTaskFailCallback)(rtExceptionInfo *exceptionInfo); + +typedef void (*rtTaskFailCallbackByModule)(rtTaskFailInfo *exceptionInfo); + +typedef void (*rtDeviceStateCallback)(uint32_t devId, bool isOpen); + +/** + * @ingroup dvrt_base + * @brief stream handle. + */ +typedef void *rtStream_t; + +/** + * @ingroup dvrt_base + * @brief runtime event handle. + */ +typedef void *rtEvent_t; + +/** + * @ingroup dvrt_base + * @brief label handle. + */ +typedef void *rtLabel_t; + +/** + * @ingroup profiling_base + * @brief runtime handle. + */ +RTS_API rtError_t rtSetProfDirEx(const char *profDir, const char *address, const char *jobCtx); + +/** + * @ingroup profiling_base + * @brief init profiler object. + */ +RTS_API rtError_t rtProfilerInit(const char *profDir, const char *address, const char *jobCtx); + +/** + * @ingroup profiling_base + * @brief config rts profiler. + */ +RTS_API rtError_t rtProfilerConfig(uint16_t type); + +/** + * @ingroup profiling_base + * @brief start rts profiler. + */ +RTS_API rtError_t rtProfilerStart(uint64_t profConfig, int32_t numsDev, uint32_t* deviceList); + +/** + * @ingroup profiling_base + * @brief stop rts profiler. + */ +RTS_API rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t* deviceList); + +/** + * @ingroup profiling_base + * @brief ts send keypoint profiler log. + */ +RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream); + +/** + * @ingroup profiling_base + * @brief ts set profiling reporter callback. + */ +RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback); + +/** + * @ingroup dvrt_base + * @brief Returns the last error from a runtime call. + */ +RTS_API rtError_t rtGetLastError(); + +/** + * @ingroup dvrt_base + * @brief Returns the last error from a runtime call. + */ +RTS_API rtError_t rtPeekAtLastError(); + +/** + * @ingroup dvrt_base + * @brief register callback for error code + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetExceptCallback(rtErrorCallback callback); + +/** + * @ingroup dvrt_base + * @brief register callback for task fail + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback); + +/** + * @ingroup dvrt_base + * @brief register callback for deviceid + * @param [in] uniName unique register name, can't be null + * @param [in] callback Device state callback function + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCallback callback); + +/** + * @ingroup dvrt_base + * @brief register callback for fail task + * @param [in] uniName unique register name, can't be null + * @param [in] callback fail task callback function + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFailCallbackByModule callback); + +/** + * @ingroup dvrt_base + * @brief notify handle. + */ +typedef void *rtNotify_t; + +/** + * @ingroup dvrt_base + * @brief create label instance + * @param [out] label created label + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelCreate(rtLabel_t *label); + +/** + * @ingroup dvrt_base + * @brief set label and stream instance + * @param [in] label set label + * @param [in] stream set stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelSet(rtLabel_t label, rtStream_t stream); + +/** + * @ingroup dvrt_base + * @brief destroy label instance + * @param [in] label label to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelDestroy(rtLabel_t label); + +/** + * @ingroup dvrt_base + * @brief label switch instance + * @param [in] ptr address to get value compared + * @param [in] condition + * @param [in] value to compare + * @param [in] true_label goto label + * @param [in] stream to submit label_switch task + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelSwitch(void *ptr, rtCondition_t condition, uint32_t value, rtLabel_t trueLabel, + rtStream_t stream); + +/** + * @ingroup dvrt_base + * @brief goto label instance + * @param [in] label goto label + * @param [in] stream to submit label_goto task + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelGoto(rtLabel_t label, rtStream_t stream); + +/** + * @ingroup dvrt_base + * @brief name label instance + * @param [in] label instance + * @param [in] name label name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNameLabel(rtLabel_t label, const char *name); + +/** + * @ingroup dvrt_base + * @brief label switch by index + * @param [in] ptr index value ptr + * @param [in] max index max value + * @param [in] labelInfoPtr label content info ptr + * @param [in] stream set stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoPtr, rtStream_t stream); + +/** + * @ingroup dvrt_base + * @brief stream goto label + * @param [in] label goto label + * @param [in] stream stream to submit label_goto task + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelGotoEx(rtLabel_t label, rtStream_t stream); + +/** + * @ingroup dvrt_base + * @brief labels to dev info + * @param [in] label model label list + * @param [in] labelNumber label number + * @param [in] dst device ptr + * @param [in] dstMax dst size + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelListCpy(rtLabel_t *label, uint32_t labelNumber, void *dst, uint32_t dstMax); + +/** + * @ingroup dvrt_base + * @brief labels to dev info + * @param [out] label created label handle + * @param [in] stream label bind stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream); + +/** + * @ingroup dvrt_base + * @brief get current thread last stream id and task id + * @param [out] stream id and task id + * @param [in] null + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for input null ptr + */ +RTS_API rtError_t rtGetTaskIdAndStreamID(uint32_t *taskId, uint32_t *streamId); + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif + +#endif // __CCE_RUNTIME_BASE_H__ diff --git a/metadef/third_party/fwkacllib/inc/runtime/config.h b/metadef/third_party/fwkacllib/inc/runtime/config.h new file mode 100644 index 00000000..12a407d7 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/config.h @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_CONFIG_H__ +#define __CCE_RUNTIME_CONFIG_H__ + +#include "base.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +#define PLAT_COMBINE(arch, chip, ver) ((arch << 16) | (chip << 8) | (ver)) +#define PLAT_GET_ARCH(type) ((type >> 16) & 0xffff) +#define PLAT_GET_CHIP(type) ((type >> 8) & 0xff) +#define PLAT_GET_VER(type) (type & 0xff) + +typedef enum tagRtArchType { + ARCH_BEGIN = 0, + ARCH_V100 = ARCH_BEGIN, + ARCH_V200, + ARCH_END, +} rtArchType_t; + +typedef enum tagRtChipType { + CHIP_BEGIN = 0, + CHIP_MINI = CHIP_BEGIN, + CHIP_CLOUD, + CHIP_MDC, + CHIP_LHISI, + CHIP_DC, + CHIP_END, +} rtChipType_t; + +typedef enum tagRtVersion { + VER_BEGIN = 0, + VER_NA = VER_BEGIN, + VER_ES, + VER_CS, + VER_END, +} rtVersion_t; + +/* match rtChipType_t */ +typedef enum tagRtPlatformType { + PLATFORM_BEGIN = 0, + PLATFORM_MINI_V1 = PLATFORM_BEGIN, + PLATFORM_CLOUD_V1, + PLATFORM_MINI_V2, + PLATFORM_LHISI_ES, + PLATFORM_LHISI_CS, + PLATFORM_DC, + PLATFORM_END, +} rtPlatformType_t; + +typedef enum tagRtCubeFracMKNFp16 { + RT_CUBE_MKN_FP16_2_16_16 = 0, + RT_CUBE_MKN_FP16_4_16_16, + RT_CUBE_MKN_FP16_16_16_16, + RT_CUBE_MKN_FP16_Default, +} rtCubeFracMKNFp16_t; + +typedef enum tagRtCubeFracMKNInt8 { + RT_CUBE_MKN_INT8_2_32_16 = 0, + RT_CUBE_MKN_INT8_4_32_4, + RT_CUBE_MKN_INT8_4_32_16, + RT_CUBE_MKN_INT8_16_32_16, + RT_CUBE_MKN_INT8_Default, +} rtCubeFracMKNInt8_t; + +typedef enum tagRtVecFracVmulMKNFp16 { + RT_VEC_VMUL_MKN_FP16_1_16_16 = 0, + RT_VEC_VMUL_MKN_FP16_Default, +} rtVecFracVmulMKNFp16_t; + +typedef enum tagRtVecFracVmulMKNInt8 { + RT_VEC_VMUL_MKN_INT8_1_32_16 = 0, + RT_VEC_VMUL_MKN_INT8_Default, +} rtVecFracVmulMKNInt8_t; + +typedef struct tagRtAiCoreSpec { + uint32_t cubeFreq; + uint32_t cubeMSize; + uint32_t cubeKSize; + uint32_t cubeNSize; + rtCubeFracMKNFp16_t cubeFracMKNFp16; + rtCubeFracMKNInt8_t cubeFracMKNInt8; + rtVecFracVmulMKNFp16_t vecFracVmulMKNFp16; + rtVecFracVmulMKNInt8_t vecFracVmulMKNInt8; +} rtAiCoreSpec_t; + +typedef struct tagRtAiCoreRatesPara { + uint32_t ddrRate; + uint32_t l2Rate; + uint32_t l2ReadRate; + uint32_t l2WriteRate; + uint32_t l1ToL0ARate; + uint32_t l1ToL0BRate; + uint32_t l0CToUBRate; + uint32_t ubToL2; + uint32_t ubToDDR; + uint32_t ubToL1; +} rtAiCoreMemoryRates_t; + +typedef struct tagRtMemoryConfig { + uint32_t flowtableSize; + uint32_t compilerSize; +} rtMemoryConfig_t; + +typedef struct tagRtPlatformConfig { uint32_t platformConfig; } rtPlatformConfig_t; + +/** + * @ingroup + * @brief get AI core count + * @param [in] aiCoreCnt + * @return aiCoreCnt + */ +RTS_API rtError_t rtGetAiCoreCount(uint32_t *aiCoreCnt); + +/** + * @ingroup + * @brief get AI cpu count + * @param [in] aiCpuCnt + * @return aiCpuCnt + */ +RTS_API rtError_t rtGetAiCpuCount(uint32_t *aiCpuCnt); + +/** + * @ingroup + * @brief get AI core frequency + * @param [in] aiCoreSpec + * @return aiCoreSpec + */ +RTS_API rtError_t rtGetAiCoreSpec(rtAiCoreSpec_t *aiCoreSpec); + +/** + * @ingroup + * @brief AI get core band Info + * @param [in] aiCoreMemoryRates + * @return aiCoreMemoryRates + */ +RTS_API rtError_t rtGetAiCoreMemoryRates(rtAiCoreMemoryRates_t *aiCoreMemoryRates); + +/** + * @ingroup + * @brief AI get core buffer Info,FlowTable Size,Compiler Size + * @param [in] memoryConfig + * @return memoryConfig + */ +RTS_API rtError_t rtGetMemoryConfig(rtMemoryConfig_t *memoryConfig); + + +/** + * @ingroup + * @brief get l2 buffer Info,virtual baseaddr,Size + * @param [in] stream + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtMemGetL2Info(rtStream_t stream, void **ptr, uint32_t *size); + +/** + * @ingroup + * @brief get runtime version. The version is returned as (1000 major + 10 minor). For example, RUNTIME 9.2 would be represented by 9020. + * @param [out] runtimeVersion + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetRuntimeVersion(uint32_t *runtimeVersion); +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif + +#endif // __CCE_RUNTIME_STREAM_H__ diff --git a/metadef/third_party/fwkacllib/inc/runtime/context.h b/metadef/third_party/fwkacllib/inc/runtime/context.h new file mode 100644 index 00000000..4be49a8c --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/context.h @@ -0,0 +1,164 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_CONTEXT_H__ +#define __CCE_RUNTIME_CONTEXT_H__ + +#include "base.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +/** + * @ingroup rt_context + * @brief runtime context handle. + */ +typedef void *rtContext_t; + +typedef enum tagDryRunFlag { + RT_DRYRUN_FLAG_FALSE = 0, + RT_DRYRUN_FLAG_TRUE = 1, +} rtDryRunFlag_t; + +typedef enum tagCtxMode { + RT_CTX_NORMAL_MODE = 0, + RT_CTX_GEN_MODE = 1, +} rtCtxMode_t; + +typedef struct tagRtGroupInfo { + int32_t groupId; + uint32_t flag; + uint32_t aicoreNum; + uint32_t aicpuNum; + uint32_t aivectorNum; + uint32_t sdmaNum; + uint32_t activeStreamNum; + void* extrPtr; +} rtGroupInfo_t; + +/** + * @ingroup rt_context + * @brief create context and associates it with the calling thread + * @param [out] ctx created context + * @param [in] flags context creation flag. set to 0. + * @param [in] device device to create context on + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxCreate(rtContext_t *ctx, uint32_t flags, int32_t device); + +/** + * @ingroup rt_context + * @brief create context and associates it with the calling thread + * @param [out] ctx created context + * @param [in] flags context creation flag. set to 0. + * @param [in] device device to create context on + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxCreateEx(rtContext_t *ctx, uint32_t flags, int32_t device); + +/** + * @ingroup rt_context + * @brief destroy context instance + * @param [in] ctx context to destroy + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxDestroy(rtContext_t ctx); + +/** + * @ingroup rt_context + * @brief destroy context instance + * @param [in] ctx context to destroy + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxDestroyEx(rtContext_t ctx); + +/** + * @ingroup rt_context + * @brief binds context to the calling CPU thread. + * @param [in] ctx context to bind. if NULL, unbind current context. + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxSetCurrent(rtContext_t ctx); + +/** + * @ingroup rt_context + * @brief block for a context's tasks to complete + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxSynchronize(void); + +/** + * @ingroup rt_context + * @brief returns the context bound to the calling CPU thread. + * @param [out] ctx returned context + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxGetCurrent(rtContext_t *ctx); + +/** + * @ingroup rt_context + * @brief returns the primary context of device. + * @param [out] ctx returned context + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetPriCtxByDeviceId(int32_t device, rtContext_t *ctx); + +/** + * @ingroup rt_context + * @brief returns the device ID for the current context + * @param [out] device returned device id + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxGetDevice(int32_t *device); + +/** + * @ingroup + * @brief set group id + * @param [in] groupid + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtSetGroup(int32_t groupId); + +/** + * @ingroup + * @brief get group info + * @param [in] groupid count + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtGetGroupInfo(int32_t groupId, rtGroupInfo_t *groupInfo, uint32_t count); + +/** + * @ingroup + * @brief get group count + * @param [in] groupid count + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtGetGroupCount(uint32_t *count); + +/** + * @ingroup rt_context + * @brief set context INF mode + * @param [in] mode + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetCtxINFMode(bool mode); +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif + + +#endif // __CCE_RUNTIME_CONTEXT_H__ diff --git a/metadef/third_party/fwkacllib/inc/runtime/dev.h b/metadef/third_party/fwkacllib/inc/runtime/dev.h new file mode 100644 index 00000000..d1a91a9b --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/dev.h @@ -0,0 +1,363 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_DEVICE_H__ +#define __CCE_RUNTIME_DEVICE_H__ + +#include "base.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +#define RT_CAPABILITY_SUPPORT (0x1) +#define RT_CAPABILITY_NOT_SUPPORT (0x0) + +typedef struct tagRTDeviceInfo { + uint8_t env_type; // 0: FPGA 1: EMU 2: ESL + uint32_t ctrl_cpu_ip; + uint32_t ctrl_cpu_id; + uint32_t ctrl_cpu_core_num; + uint32_t ctrl_cpu_endian_little; + uint32_t ts_cpu_core_num; + uint32_t ai_cpu_core_num; + uint32_t ai_core_num; + uint32_t ai_core_freq; + uint32_t ai_cpu_core_id; + uint32_t ai_core_id; + uint32_t aicpu_occupy_bitmap; + uint32_t hardware_version; + uint32_t ts_num; +} rtDeviceInfo_t; + +typedef enum tagRtRunMode { + RT_RUN_MODE_OFFLINE = 0, + RT_RUN_MODE_ONLINE = 1, + RT_RUN_MODE_AICPU_SCHED = 2, + RT_RUN_MODE_RESERVED +} rtRunMode; + +typedef enum tagRtAicpuDeployType { + AICPU_DEPLOY_CROSS_OS = 0x0, + AICPU_DEPLOY_CROSS_PROCESS = 0x1, + AICPU_DEPLOY_CROSS_THREAD = 0x2, + AICPU_DEPLOY_RESERVED +} rtAicpuDeployType_t; + +typedef enum tagRtFeatureType { + FEATURE_TYPE_MEMCPY = 0, + FEATURE_TYPE_RSV +} rtFeatureType_t; + +typedef enum tagMemcpyInfo { + MEMCPY_INFO_SUPPORT_ZEROCOPY = 0, + MEMCPY_INFO_RSV +} rtMemcpyInfo_t; + +/** + * @ingroup dvrt_dev + * @brief get total device number. + * @param [in|out] count the device number + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetDeviceCount(int32_t *count); +/** + * @ingroup dvrt_dev + * @brief get device ids + * @param [in|out] get details of device ids + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_DRV_ERR for error + */ +RTS_API rtError_t rtGetDeviceIDs(uint32_t *devices, uint32_t len); + +/** + * @ingroup dvrt_dev + * @brief get device infomation. + * @param [in] device the device id + * @param [in] moduleType module type + typedef enum { + MODULE_TYPE_SYSTEM = 0, system info + MODULE_TYPE_AICPU, aicpu info + MODULE_TYPE_CCPU, ccpu_info + MODULE_TYPE_DCPU, dcpu info + MODULE_TYPE_AICORE, AI CORE info + MODULE_TYPE_TSCPU, tscpu info + MODULE_TYPE_PCIE, PCIE info + } DEV_MODULE_TYPE; + * @param [in] infoType info type + typedef enum { + INFO_TYPE_ENV = 0, + INFO_TYPE_VERSION, + INFO_TYPE_MASTERID, + INFO_TYPE_CORE_NUM, + INFO_TYPE_OS_SCHED, + INFO_TYPE_IN_USED, + INFO_TYPE_ERROR_MAP, + INFO_TYPE_OCCUPY, + INFO_TYPE_ID, + INFO_TYPE_IP, + INFO_TYPE_ENDIAN, + } DEV_INFO_TYPE; + * @param [out] value the device info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_DRV_ERR for error + */ +RTS_API rtError_t rtGetDeviceInfo(uint32_t deviceId, int32_t moduleType, int32_t infoType, int64_t *value); + +/** + * @ingroup dvrt_dev + * @brief set target device for current thread + * @param [int] device the device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetDevice(int32_t device); + +/** + * @ingroup dvrt_dev + * @brief set target device for current thread + * @param [int] device the device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetDeviceEx(int32_t device); + +/** + * @ingroup dvrt_dev + * @brief get Index by phyId. + * @param [in] phyId the physical device id + * @param [out] devIndex the logic device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetDeviceIndexByPhyId(uint32_t phyId, uint32_t *devIndex); + +/** + * @ingroup dvrt_dev + * @brief get phyId by Index. + * @param [in] devIndex the logic device id + * @param [out] phyId the physical device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetDevicePhyIdByIndex(uint32_t devIndex, uint32_t *phyId); + +/** + * @ingroup dvrt_dev + * @brief enable direction:devIdDes---->phyIdSrc. + * @param [in] devIdDes the logical device id + * @param [in] phyIdSrc the physical device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEnableP2P(uint32_t devIdDes, uint32_t phyIdSrc, uint32_t flag); + +/** + * @ingroup dvrt_dev + * @brief disable direction:devIdDes---->phyIdSrc. + * @param [in] devIdDes the logical device id + * @param [in] phyIdSrc the physical device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDisableP2P(uint32_t devIdDes, uint32_t phyIdSrc); + +/** + * @ingroup dvrt_dev + * @brief get cability of P2P omemry copy betwen device and peeredevic. + * @param [in] device the logical device id + * @param [in] peerDevice the physical device id + * @param [outv] *canAccessPeer 1:enable 0:disable + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceCanAccessPeer(int32_t* canAccessPeer, uint32_t device, uint32_t peerDevice); + +/** + * @ingroup dvrt_dev + * @brief get status + * @param [in] devIdDes the logical device id + * @param [in] phyIdSrc the physical device id + * @param [in|out] status status value + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetP2PStatus(uint32_t devIdDes, uint32_t phyIdSrc, uint32_t *status); + +/** + * @ingroup dvrt_dev + * @brief get value of current thread + * @param [in|out] pid value of pid + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtDeviceGetBareTgid(uint32_t *pid); + +/** + * @ingroup dvrt_dev + * @brief get target device of current thread + * @param [in|out] device the device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetDevice(int32_t *device); + +/** + * @ingroup dvrt_dev + * @brief reset all opened device + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceReset(int32_t device); + +/** + * @ingroup dvrt_dev + * @brief reset opened device + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceResetEx(int32_t device); + +/** + * @ingroup dvrt_dev + * @brief get total device infomation. + * @param [in] device the device id + * @param [in] type limit type RT_LIMIT_TYPE_LOW_POWER_TIMEOUT=0 + * @param [in] value limit value + * @param [out] info the device info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceSetLimit(int32_t device, rtLimitType_t type, uint32_t value); + +/** + * @ingroup dvrt_dev + * @brief Wait for compute device to finish + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceSynchronize(void); + +/** + * @ingroup dvrt_dev + * @brief get priority range of current device + * @param [in|out] leastPriority least priority + * @param [in|out] greatestPriority greatest priority + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceGetStreamPriorityRange(int32_t *leastPriority, int32_t *greatestPriority); + +/** + * @ingroup dvrt_dev + * @brief Set exception handling callback function + * @param [in] callback rtExceptiontype + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetExceptCallback(rtErrorCallback callback); + +/** + * @ingroup dvrt_dev + * @brief Setting Scheduling Type of Graph + * @param [in] tsId the ts id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetTSDevice(uint32_t tsId); + +/** + * @ingroup dvrt_dev + * @brief init aicpu executor + * @param [out] runtime run mode + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_DRV_ERR for can not get run mode + */ +RTS_API rtError_t rtGetRunMode(rtRunMode *mode); + +/** + * @ingroup dvrt_dev + * @brief get aicpu deploy + * @param [out] aicpu deploy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_DRV_ERR for can not get aicpu deploy + */ +RTS_API rtError_t rtGetAicpuDeploy(rtAicpuDeployType_t *deployType); + +/** + * @ingroup dvrt_dev + * @brief set chipType + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetSocVersion(const char *version); + +/** + * @ingroup dvrt_dev + * @brief get chipType + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetSocVersion(char *version, const uint32_t maxLen); + +/** + * @ingroup dvrt_dev + * @brief get status + * @param [in] devId the logical device id + * @param [in] otherDevId the other logical device id + * @param [in] infoType info type + * @param [in|out] value pair info + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetPairDevicesInfo(uint32_t devId, uint32_t otherDevId, int32_t infoType, int64_t *value); + +/** + * @ingroup dvrt_dev + * @brief get capability infomation. + * @param [in] featureType feature type + typedef enum tagRtFeatureType { + FEATURE_TYPE_MEMCPY = 0, + FEATURE_TYPE_RSV, + } rtFeatureType_t; + * @param [in] featureInfo info type + typedef enum tagMemcpyInfo { + MEMCPY_INFO_SUPPORT_ZEROCOPY = 0, + MEMCPY_INFO _RSV, + } rtMemcpyInfo_t; + * @param [out] value the capability info RT_CAPABILITY_SUPPORT or RT_CAPABILITY_NOT_SUPPORT + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetRtCapability(rtFeatureType_t featureType, int32_t featureInfo, int64_t *value); + +/** + * @ingroup dvrt_dev + * @brief set target device for current thread + * @param [int] device the device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetDeviceWithoutTsd(int32_t device); + +/** + * @ingroup dvrt_dev + * @brief reset all opened device + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceResetWithoutTsd(int32_t device); +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif + +#endif // __CCE_RUNTIME_DEVICE_H__ diff --git a/metadef/third_party/fwkacllib/inc/runtime/dvfsprofile.h b/metadef/third_party/fwkacllib/inc/runtime/dvfsprofile.h new file mode 100644 index 00000000..6e451695 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/dvfsprofile.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_DVFSPROFILE_H__ +#define __CCE_RUNTIME_DVFSPROFILE_H__ + +#include "base.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +typedef enum dvfsProfileMode { + DVFS_PROFILE_PERFORMANCE_PRIORITY, + DVFS_PROFILE_BALANCE_PRIORITY, + DVFS_PROFILE_POWER_PRIORITY, + DVFS_PROFILE_PRIORITY_MAX +} DvfsProfileMode; + +/** + * @ingroup dvrt_dvfsprofile + * @brief Set the performance mode of the device + * @param [in] mode dvfsProfileMode + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetDvfsProfile(DvfsProfileMode mode); + +/** + * @ingroup dvrt_dvfsprofile + * @brief Set the performance mode of the device + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for invalid value + */ +RTS_API rtError_t rtUnsetDvfsProfile(); + +/** + * @ingroup dvrt_dvfsprofile + * @brief Get the current performance mode of the device + * @param [in|out] pmode dvfsProfileMode type pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetDvfsProfile(DvfsProfileMode *pmode); + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif + +#endif // __CCE_RUNTIME_PROFILE_H__ diff --git a/metadef/third_party/fwkacllib/inc/runtime/event.h b/metadef/third_party/fwkacllib/inc/runtime/event.h new file mode 100644 index 00000000..41e611ea --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/event.h @@ -0,0 +1,246 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_EVENT_H__ +#define __CCE_RUNTIME_EVENT_H__ + +#include "base.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +/** + * @ingroup event_flags + * @brief event op bit flags + */ +#define RT_EVENT_DEFAULT (0x00) +#define RT_EVENT_WITH_FLAG (0x01) + +/** + * @ingroup dvrt_event + * @brief create event instance + * @param [in|out] event created event + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventCreate(rtEvent_t *event); + +/** + * @ingroup dvrt_event + * @brief create event instance with flag + * @param [in|out] event created event flag event op flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventCreateWithFlag(rtEvent_t *event, uint32_t flag); + +/** + * @ingroup dvrt_event + * @brief destroy event instance + * @param [in] event event to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventDestroy(rtEvent_t event); + +/** + * @ingroup dvrt_event + * @brief get event id + * @param [in] event_ event to be get + * @param [in|out] event_id event_id id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetEventID(rtEvent_t event, uint32_t *eventId); + +/** + * @ingroup dvrt_event + * @brief event record + * @param [int] event event to record + * @param [int] stream stream handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventRecord(rtEvent_t event, rtStream_t stream); + +/** + * @ingroup dvrt_event + * @brief event reset + * @param [int] event event to reset + * @param [int] stream stream handle + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtEventReset(rtEvent_t event, rtStream_t stream); + +/** + * @ingroup dvrt_event + * @brief wait event to be complete + * @param [in] event event to wait + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventSynchronize(rtEvent_t event); + +/** + * @ingroup dvrt_event + * @brief Queries an event's status + * @param [in] event event to query + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_EVENT_NOT_COMPLETE for not complete + */ +RTS_API rtError_t rtEventQuery(rtEvent_t event); + +/** + * @ingroup dvrt_event + * @brief computes the elapsed time between events. + * @param [in] time time between start and end in ms + * @param [in] start starting event + * @param [in] end ending event + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtEventElapsedTime(float *time, rtEvent_t start, rtEvent_t end); + +/** + * @ingroup dvrt_event + * @brief get the elapsed time from a event after event recorded. + * @param [in] time time in ms + * @param [in] event event handle + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtEventGetTimeStamp(uint64_t *time, rtEvent_t event); + +/** + * @ingroup dvrt_event + * @brief name an event + * @param [in] event event to be named + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input of event, name + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtNameEvent(rtEvent_t event, const char *name); + +/** + * @ingroup dvrt_event + * @brief Create a notify + * @param [in] device_id device id + * @param [in|out] notify_ notify to be created + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNotifyCreate(int32_t deviceId, rtNotify_t *notify); + +/** + * @ingroup dvrt_event + * @brief Destroy a notify + * @param [in] notify_ notify to be destroyed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtNotifyDestroy(rtNotify_t notify); + +/** + * @ingroup dvrt_event + * @brief Record a notify + * @param [in] notify_ notify to be recorded + * @param [in] stream_ input stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx + */ +RTS_API rtError_t rtNotifyRecord(rtNotify_t notify, rtStream_t stream); + +/** + * @ingroup dvrt_event + * @brief Wait for a notify + * @param [in] notify_ notify to be wait + * @param [in] stream_ input stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx + */ +RTS_API rtError_t rtNotifyWait(rtNotify_t notify, rtStream_t stream); + +/** + * @ingroup dvrt_event + * @brief Name a notify + * @param [in] notify_ notify to be named + * @param [in|out] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNameNotify(rtNotify_t notify, const char *name); + +/** + * @ingroup dvrt_event + * @brief get notify id + * @param [in] notify_ notify to be get + * @param [in|out] notify_id notify id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetNotifyID(rtNotify_t notify, uint32_t *notifyId); + +/** + * @ingroup dvrt_event + * @brief Set a notify to IPC notify + * @param [in] notify_ notify to be set to IPC notify + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input of + */ +RTS_API rtError_t rtIpcSetNotifyName(rtNotify_t notify, char *name, uint32_t len); + +/** + * @ingroup dvrt_event + * @brief Open IPC notify + * @param [out] notify the opened notify + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtIpcOpenNotify(rtNotify_t *notify, const char *name); + +/** + * @ingroup dvrt_event + * @brief Get the physical address corresponding to notify + * @param [in] notify notify to be queried + * @param [in] devAddrOffset device physical address offset + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtNotifyGetAddrOffset(rtNotify_t notify, uint64_t *devAddrOffset); + +/** + * @ingroup dvrt_event + * @brief Ipc set notify pid + * @param [in] name name to be queried + * @param [in] pid process id + * @param [in] num length of pid[] + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtSetIpcNotifyPid(const char *name, int32_t pid[], int num); + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif + +#endif // __CCE_RUNTIME_EVENT_H__ diff --git a/metadef/third_party/fwkacllib/inc/runtime/kernel.h b/metadef/third_party/fwkacllib/inc/runtime/kernel.h new file mode 100644 index 00000000..5f519442 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/kernel.h @@ -0,0 +1,566 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_KERNEL_H__ +#define __CCE_RUNTIME_KERNEL_H__ + +#include "base.h" +#include "stream.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +/** + * @ingroup rt_kernel + * @brief shared memory data control + */ +typedef struct tagRtSmData { + uint64_t L2_mirror_addr; // preload or swap source address + uint32_t L2_data_section_size; // every data size + uint8_t L2_preload; // 1 - preload from mirrorAddr, 0 - no preload + uint8_t modified; // 1 - data will be modified by kernel, 0 - no modified + uint8_t priority; // data priority + int8_t prev_L2_page_offset_base; // remap source section offset + uint8_t L2_page_offset_base; // remap destination section offset + uint8_t L2_load_to_ddr; // 1 - need load out, 0 - no need + uint8_t reserved[2]; // reserved +} rtSmData_t; + +/** + * @ingroup rt_kernel + * @brief shared memory description + */ +typedef struct tagRtSmCtrl { + rtSmData_t data[8]; // data description + uint64_t size; // max page Num + uint8_t remap[64]; /* just using for static remap mode, default:0xFF + array index: virtual l2 page id, array value: physic l2 page id */ + uint8_t l2_in_main; // 0-DDR, 1-L2, default:0xFF + uint8_t reserved[3]; +} rtSmDesc_t; + +typedef rtSmDesc_t rtL2Ctrl_t; + +/** + * @ingroup rt_kernel + * @brief device binary type + */ +typedef struct tagRtDevBinary { + uint32_t magic; // magic number + uint32_t version; // version of binary + const void *data; // binary data + uint64_t length; // binary length +} rtDevBinary_t; + +/** + * @ingroup rt_kernel + * @brief function mode type + */ +#define ONLINE_PROF_MAX_PMU_NUM (8) + +typedef struct ProfilefDataInfo { + const void *stubFunc; + uint32_t blockDim; + const void *args; + uint32_t argsSize; + rtSmDesc_t *smDesc; + rtStream_t stream; + uint64_t totalcycle; + uint64_t ovcycle; + uint64_t pmu_cnt[ONLINE_PROF_MAX_PMU_NUM]; +} rtProfDataInfo_t; + +/** + * @ingroup rt_kernel + * @brief function mode type + */ +typedef enum { + FUNC_MODE_NORMAL = 0, + FUNC_MODE_PCTRACE_USERPROFILE_RECORDLOOP, + FUNC_MODE_PCTRACE_USERPROFILE_SKIPLOOP, + FUNC_MODE_PCTRACE_CYCLECNT_RECORDLOOP, + FUNC_MODE_PCTRACE_CYCLECNT_SKIPLOOP, + FUNC_MODE_BUTT +} rtFuncModeType_t; + +/** + * @ingroup rt_kernel + * @brief kernel info + */ +typedef struct rtKernelInfo { + uint64_t task_offset; // kernel offset in module + /* flowtable */ + void *arg; // launch kernel arg + uint32_t arg_size; + /* module */ + void *module_addr; // module::baseaddr_ + uint32_t module_size; +} * rtKernelInfo_t; + +/** + * @ingroup rt_KernelConfigDump + * @brief device dump type + */ +typedef enum tagRtDumpKind { + RT_DATA_DUMP_KIND_INVALID = -1, + RT_DATA_DUMP_KIND_DUMP = 0, + RT_DATA_DUMP_KIND_RESERVED +} rtDumpKind_t; + +/** + * @ingroup rt_kernel + * @brief report callback + */ +typedef rtError_t (*rtKernelReportCallback)(rtStream_t stream, rtKernelInfo_t kernelInfo); + +/** + * @ingroup rt_kernel + * @brief stream report callback + */ +typedef void (*rtCallback_t)(void *fnData); + +/** + * @ingroup rt_kernel + * @brief magic number of plain binary for aicore + */ +#define RT_DEV_BINARY_MAGIC_PLAIN 0xabceed50 + +/** + * @ingroup rt_kernel + * @brief magic number of plain binary for aicpu + */ +#define RT_DEV_BINARY_MAGIC_PLAIN_AICPU 0xabceed51 + +/** + * @ingroup rt_kernel + * @brief magic number of plain binary for aivector + */ +#define RT_DEV_BINARY_MAGIC_PLAIN_AIVEC 0xabceed52 + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aicore + */ +#define RT_DEV_BINARY_MAGIC_ELF 0x43554245 + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aicpu + */ +#define RT_DEV_BINARY_MAGIC_ELF_AICPU 0x41415243 + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aivector + */ +#define RT_DEV_BINARY_MAGIC_ELF_AIVEC 0x41415246 + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aicube + */ +#define RT_DEV_BINARY_MAGIC_ELF_AICUBE 0x41415247 + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aivector + */ +#define RT_DEV_BINARY_MAGIC_ELF_AIVECTOR 0x41415248 + +/** + * @ingroup rt_kernel_flags + * @brief kernel op bit flags + */ +#define RT_KERNEL_DEFAULT (0x00) +#define RT_KERNEL_CONVERT (0x01) +#define RT_KERNEL_DUMPFLAG (0x02) +#define RT_FUSION_KERNEL_DUMPFLAG (0x04) +#define RT_KERNEL_CUSTOM_AICPU (0x08) + +/** + * @ingroup rt_kernel + * @brief kernel L1 Fusion Dump bit flags + */ +#define RT_DDR_ADDR (0x0) + +/** + * @ingroup rt_kernel + * @brief register device binary + * @param [in] bin device binary description + * @param [out] handle device binary handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDevBinaryRegister(const rtDevBinary_t *bin, void **handle); + +/** + * @ingroup rt_kernel + * @brief register fast memeory device binary + * @param [in] handle device binary handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtBinaryRegisterToFastMemory(void *handle); + +/** + * @ingroup rt_kernel + * @brief unregister device binary + * @param [in] handle device binary handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDevBinaryUnRegister(void *handle); + +/** + * @ingroup rt_kernel + * @brief register device binary metadata + * @param [in] handle device binary description + * @param [in] metadata device binary metadata + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMetadataRegister(void *handle, const char *metadata); + +/** + * @ingroup rt_kernel + * @brief register device binary dependency + * @param [in] mHandle master device binary description + * @param [in] sHandle slave device binary description + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDependencyRegister(void *mHandle, void *sHandle); + +/** + * @ingroup rt_kernel + * @brief register device function + * @param [in] binHandle device binary handle + * @param [in] stubFunc stub function + * @param [in] stubName stub function name + * @param [in] devFunc device function description. symbol name or address + * offset, depending binary type. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtFunctionRegister(void *binHandle, const void *stubFunc, const char *stubName, const void *devFunc, + uint32_t funcMode); + +/** + * @ingroup rt_kernel + * @brief find stub function by name + * @param [in] stubName stub function name + * @param [out] stubFunc stub function + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetFunctionByName(const char *stubName, void **stubFunc); + +/** + * @ingroup rt_kernel + * @brief find addr by stub func + * @param [in] stubFunc stub function + * @param [out] addr + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetAddrByFun(const void *stubFunc, void **addr); +/** + * @ingroup rt_kernel + * @brief query registered or not by stubName + * @param [in] stubName stub function name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtQueryFunctionRegistered(const char *stubName); + +/** + * @ingroup rt_kernel + * @brief config data dump + * @param [in] dumpSizePerBlock dump size + * @param [in] blockDim block dimentions + * @param [in] dumpBaseAddr dump base address + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelConfigDump(uint32_t kind, uint32_t dumpSizePerBlock, uint32_t blockDim, void **dumpBaseAddr, + rtStream_t stream); + +/** + * @ingroup rt_kernel + * @brief launch kernel to device + * @param [in] stubFunc stub function + * @param [in] blockDim block dimentions + * @param [in] args argments address for kernel function + * @param [in] argsSize argements size + * @param [in] smDesc shared memory description + * @param [in] stream associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize, + rtSmDesc_t *smDesc, rtStream_t stream); + +/** + * @ingroup rt_kernel + * @brief launch kernel to device + * @param [in] stubFunc stub function + * @param [in] blockDim block dimentions + * @param [in] args argments address for kernel function + * @param [in] argsSize argements size + * @param [in] smDesc shared memory description + * @param [in] stream associated stream + * @param [in] flag dump flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize, + rtSmDesc_t *smDesc, rtStream_t stream, uint32_t flags); + +/** + * @ingroup rt_kernel + * @brief launch kernel to device + * @param [in] args argments address for kernel function + * @param [in] argsSize argements size + * @param [in] flags launch flags + * @param [in] stream associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags, rtStream_t stream); + +/** + * @ingroup rt_kernel + * @brief launch cpu kernel to device + * @param [in] soName so name + * @param [in] kernelName kernel name + * @param [in] blockDim block dimentions + * @param [in] args argments address for kernel function + * @param [in] argsSize argments size + * @param [in] smDesc shared memory description + * @param [in] stream associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtCpuKernelLaunch(const void *soName, const void *kernelName, uint32_t blockDim, const void *args, + uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream); + +/** + * @ingroup rt_kernel + * @brief launch cpu kernel to device with dump identifier + * @param [in] soName so name + * @param [in] kernelName kernel name + * @param [in] blockDim block dimentions + * @param [in] args argments address for kernel function + * @param [in] argsSize argments size + * @param [in] smDesc shared memory description + * @param [in] stream associated stream + * @param [in] flag dump flag or others function flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kernelName, uint32_t blockDim, + const void *args, uint32_t argsSize, rtSmDesc_t *smDesc, rtStream_t stream, + uint32_t flags); + +typedef void *rtModel_t; +/** + * @ingroup rt_kernel + * @brief L1 fusion dump addr transfered to device + * @param [in] model handle info + * @param [in] addr ddr address of L1 Fusion Dump + * @param [in] dumpSize memory size + * @param [in] flag memory flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ + RTS_API rtError_t rtDumpAddrSet(rtModel_t model, void *addr, uint32_t dumpSize, uint32_t flag); + +/** + * @ingroup rt_kernel + * @brief load dump info to aicpu + * @param [in] dumpInfo dump info + * @param [in] length length of dump info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDatadumpInfoLoad(const void *dumpInfo, uint32_t length); + +#ifndef __CLANG_CCE_RUNTIME_H__ +#define __CLANG_CCE_RUNTIME_H__ +/** + * @ingroup rt_kernel + * @brief configure call argment for next rtLaunch in current thread + * @param [in] numBlocks block dimentions + * @param [in] smDesc shared memory description + * @param [in] stream associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +#ifdef __cplusplus +RTS_API rtError_t rtConfigureCall(uint32_t numBlocks, rtSmDesc_t *smDesc = nullptr, rtStream_t stream = nullptr); +#else +RTS_API rtError_t rtConfigureCall(uint32_t numBlocks, rtSmDesc_t *smDesc, rtStream_t stream); +#endif +#endif // __CLANG_CCE_RUNTIME_H__ + +/** + * @ingroup rt_kernel + * @brief setup argment for next rtLaunch in current thread + * @param [in] arg argment address for kernel function + * @param [in] size argment size + * @param [in] offset argment table offset + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetupArgument(const void *arg, uint32_t size, uint32_t offset); + +/** + * @ingroup rt_kernel + * @brief launch kernel to device with previous setting kernel argment + * and call argment + * @param [in] stubFunc stub function + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLaunch(const void *stubFunc); + +/** + * @ingroup rt_kernel + * @brief implicitly transfered data to device. + * lifecycle end after next kernel task finish + * @param [in] ptr host memory + * @param [in] size host memory size + * @param [in] flag reserved. set to 0 + * @param [out] arg returned arg. used for next kernel's arg. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelConfigTransArg(const void *ptr, uint64_t size, uint32_t flag, void **arg); + +/** + * @ingroup rt_kernel + * @brief start fusion kernels. + * @param [in] stream stream for fusion kernels + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelFusionStart(rtStream_t stream); + +/** + * @ingroup rt_kernel + * @brief end fusion kernels. + * @param [in] stream stream for fusion kernels + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelFusionEnd(rtStream_t stream); + +/** + * @ingroup rt_kernel + * @brief set kernelinfo callback + * @param [in] callback + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetKernelReportCallback(rtKernelReportCallback callBack); + +/** + * @ingroup rt_kernel + * @brief subscribe stream callback report. + * @param [in] threadId thread id for stream + * @param [in] stream stream for subscribe + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSubscribeReport(uint64_t threadId, rtStream_t stream); + +/** + * @ingroup rt_kernel + * @brief add callback launch task in stream. + * @param [in] callBackFunc app callback function + * @param [in] fnData user data + * @param [in] stream subscribed stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtCallbackLaunch(rtCallback_t callBackFunc, void *fnData, rtStream_t stream, bool isBlock); + +/** + * @ingroup rt_kernel + * @brief process callback report. + * @param [in] timeout if timeout=-1, while(1); else timeout + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtProcessReport(int32_t timeout); + +/** + * @ingroup rt_kernel + * @brief unsubscribe callback report. + * @param [in] threadId thread id for stream + * @param [in] stream stream for subscribe + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtUnSubscribeReport(uint64_t threadId, rtStream_t stream); + +/** + * @ingroup profiling_base + * @brief start online prof. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStartOnlineProf(rtStream_t stream, uint32_t sampleNum); + +/** + * @ingroup profiling_base + * @brief stop online prof. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStopOnlineProf(rtStream_t stream); + +/** + * @ingroup profiling_base + * @brief get online prof. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetOnlineProfData(rtStream_t stream, rtProfDataInfo_t *pProfData, uint32_t profDataNum); + +/** + * @ingroup profiling_base + * @brief start mdc profiler. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStartMDCProfiler(void **addr, uint32_t length); + +/** + * @ingroup profiling_base + * @brief stop mdc profiler. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStopMDCProfiler(void *addr); + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif + +#endif // __CCE_RUNTIME_KERNEL_H__ + diff --git a/metadef/third_party/fwkacllib/inc/runtime/mem.h b/metadef/third_party/fwkacllib/inc/runtime/mem.h new file mode 100644 index 00000000..e65d8604 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/mem.h @@ -0,0 +1,542 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_MEM_H__ +#define __CCE_RUNTIME_MEM_H__ + +/*lint -e7*/ +#include +/*lint +e7*/ +#include "base.h" +#include "config.h" +#include "stream.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +/** + * @ingroup dvrt_mem + * @brief memory type + */ +#define RT_MEMORY_DEFAULT ((uint32_t)0x0) // default memory on device +#define RT_MEMORY_HBM ((uint32_t)0x2) // HBM memory on device +#define RT_MEMORY_DDR ((uint32_t)0x4) // DDR memory on device +#define RT_MEMORY_SPM ((uint32_t)0x8) // shared physical memory on device +#define RT_MEMORY_P2P_HBM ((uint32_t)0x10) // HBM memory on other 4P device +#define RT_MEMORY_P2P_DDR ((uint32_t)0x11) // DDR memory on other device +#define RT_MEMORY_DDR_NC ((uint32_t)0x20) // DDR memory of non-cache +#define RT_MEMORY_TS_4G ((uint32_t)0x40) +#define RT_MEMORY_TS ((uint32_t)0x80) +#define RT_MEMORY_RESERVED ((uint32_t)0x100) + +#define RT_MEMORY_L1 ((uint32_t)0x1<<16) +#define RT_MEMORY_L2 ((uint32_t)0x1<<17) + +/** + * @ingroup dvrt_mem + * @brief memory info type + */ +#define RT_MEM_INFO_TYPE_DDR_SIZE ((uint32_t)0x1) +#define RT_MEM_INFO_TYPE_HBM_SIZE ((uint32_t)0x2) +#define RT_MEM_INFO_TYPE_DDR_P2P_SIZE ((uint32_t)0x3) +#define RT_MEM_INFO_TYPE_HBM_P2P_SIZE ((uint32_t)0x4) + +/** + * @ingroup dvrt_mem + * @brief memory Policy + */ +#define RT_MEMORY_POLICY_NONE ((uint32_t)0x0) // Malloc mem prior hage page, then default page +#define RT_MEMORY_POLICY_HUGE_PAGE_FIRST ((uint32_t)0x1 << 10) // Malloc mem prior hage page, then default page +#define RT_MEMORY_POLICY_HUGE_PAGE_ONLY ((uint32_t)0x1 << 11) // Malloc mem only use hage page +#define RT_MEMORY_POLICY_DEFAULT_PAGE_ONLY ((uint32_t)0x1 << 12) // Malloc mem only use default page +#define RT_MEMORY_POLICY_HUGE_PAGE_FIRST_P2P ((uint32_t)0x1 << 13) // Malloc mem prior hage page, then default page, use for p2p +#define RT_MEMORY_POLICY_HUGE_PAGE_ONLY_P2P ((uint32_t)0x1 << 14) // Malloc mem only use hage page, use for p2p +#define RT_MEMORY_POLICY_DEFAULT_PAGE_ONLY_P2P ((uint32_t)0x1 << 15) // Malloc mem only use default page, use for p2p + +#define MEM_ALLOC_TYPE_BIT ((uint32_t)0x3FF) // mem type bit in <0, 9> + +/** + * @ingroup dvrt_mem + * @brief memory type | memory Policy + */ +typedef uint32_t rtMemType_t; + +/** + * @ingroup dvrt_mem + * @brief memory advise type + */ +#define RT_MEMORY_ADVISE_EXE (0x02) +#define RT_MEMORY_ADVISE_THP (0x04) +#define RT_MEMORY_ADVISE_PLE (0x08) +#define RT_MEMORY_ADVISE_PIN (0x16) + +/** + * @ingroup dvrt_mem + * @brief memory copy type + */ +typedef enum tagRtMemcpyKind { + RT_MEMCPY_HOST_TO_HOST = 0, // host to host + RT_MEMCPY_HOST_TO_DEVICE, // host to device + RT_MEMCPY_DEVICE_TO_HOST, // device to host + RT_MEMCPY_DEVICE_TO_DEVICE, // device to device, 1P && P2P + RT_MEMCPY_MANAGED, // managed memory + RT_MEMCPY_ADDR_DEVICE_TO_DEVICE, + RT_MEMCPY_HOST_TO_DEVICE_EX, // host to device ex (only used for 8 bytes) + RT_MEMCPY_DEVICE_TO_HOST_EX, // device to host ex + RT_MEMCPY_RESERVED, +} rtMemcpyKind_t; + +typedef enum tagRtMemInfoType { + RT_MEMORYINFO_DDR, + RT_MEMORYINFO_HBM, + RT_MEMORYINFO_DDR_HUGE, // Hugepage memory of DDR + RT_MEMORYINFO_DDR_NORMAL, // Normal memory of DDR + RT_MEMORYINFO_HBM_HUGE, // Hugepage memory of HBM + RT_MEMORYINFO_HBM_NORMAL, // Normal memory of HBM + RT_MEMORYINFO_DDR_P2P_HUGE, // Hugepage memory of DDR + RT_MEMORYINFO_DDR_P2P_NORMAL, // Normal memory of DDR + RT_MEMORYINFO_HBM_P2P_HUGE, // Hugepage memory of HBM + RT_MEMORYINFO_HBM_P2P_NORMAL, // Normal memory of HBM +} rtMemInfoType_t; + +typedef enum tagRtRecudeKind { + RT_MEMCPY_SDMA_AUTOMATIC_ADD = 10, // D2D, SDMA inline reduce, include 1P, and P2P + RT_RECUDE_KIND_END +} rtRecudeKind_t; + +typedef enum tagRtDataType { + RT_DATA_TYPE_FP32 = 0, // fp32 + RT_DATA_TYPE_FP16 = 1, // fp16 + RT_DATA_TYPE_INT16 = 2, // int16 + RT_DATA_TYPE_END +} rtDataType_t; + +/** + * @ingroup dvrt_mem + * @brief memory copy channel type + */ +typedef enum tagRtMemcpyChannelType { + RT_MEMCPY_CHANNEL_TYPE_INNER = 0, // 1P + RT_MEMCPY_CHANNEL_TYPE_PCIe, + RT_MEMCPY_CHANNEL_TYPE_HCCs, // not support now + RT_MEMCPY_CHANNEL_TYPE_RESERVED, +} rtMemcpyChannelType_t; + +/** + * @ingroup rt_kernel + * @brief ai core memory size + */ +typedef struct rtAiCoreMemorySize { + uint32_t l0ASize; + uint32_t l0BSize; + uint32_t l0CSize; + uint32_t l1Size; + uint32_t ubSize; + uint32_t l2Size; + uint32_t l2PageNum; + uint32_t blockSize; + uint64_t bankSize; + uint64_t bankNum; + uint64_t burstInOneBlock; + uint64_t bankGroupNum; +} rtAiCoreMemorySize_t; + +/** + * @ingroup dvrt_mem + * @brief memory type + */ +typedef enum tagRtMemoryType { + RT_MEMORY_TYPE_HOST = 1, + RT_MEMORY_TYPE_DEVICE = 2 , + RT_MEMORY_TYPE_SVM = 3, + RT_MEMORY_TYPE_DVPP = 4 +} rtMemoryType_t; + +/** + * @ingroup dvrt_mem + * @brief memory attribute + */ +typedef struct tagRtPointerAttributes { + rtMemoryType_t memoryType; // host memory or device memory + rtMemoryType_t locationType; + uint32_t deviceID; // device ID + uint32_t pageSize; +} rtPointerAttributes_t; + + +typedef struct rtMallocHostSharedMemoryIn { + const char* name; + const uint64_t size; + uint32_t flag; +} rtMallocHostSharedMemoryIn; + +typedef struct rtMallocHostSharedMemoryOut { + int fd; + void* ptr; + void* devPtr; +} rtMallocHostSharedMemoryOut; + +typedef struct rtFreeHostSharedMemoryIn { + const char* name; + const uint64_t size; + int fd; + void* ptr; + void* devPtr; +} rtFreeHostSharedMemoryIn; + + +/** + * @ingroup dvrt_mem + * @brief alloc device memory + * @param [in|out] devPtr memory pointer + * @param [in] size memory size + * @param [in] type memory type + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMalloc(void **devPtr, uint64_t size, rtMemType_t type); + +/** + * @ingroup dvrt_mem + * @brief free device memory + * @param [in|out] devPtr memory pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtFree(void *devPtr); + +/** + * @ingroup dvrt_mem + * @brief alloc device memory for dvpp + * @param [in|out] devPtr memory pointer + * @param [in] size memory size + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDvppMalloc(void **devPtr, uint64_t size); + +/** + * @ingroup dvrt_mem + * @brief free device memory for dvpp + * @param [in|out] devPtr memory pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDvppFree(void *devPtr); + +/** + * @ingroup dvrt_mem + * @brief alloc host memory + * @param [in|out] hostPtr memory pointer + * @param [in] size memory size + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMallocHost(void **hostPtr, uint64_t size); + +/** + * @ingroup dvrt_mem + * @brief free host memory + * @param [in] hostPtr memory pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtFreeHost(void *hostPtr); + +/** + * @ingroup dvrt_mem + * @brief alloc host shared memory + * @param [in] in alloc host shared memory inputPara pointer + * @param [in] out alloc host shared memory outputInfo pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ + +RTS_API rtError_t rtMallocHostSharedMemory(rtMallocHostSharedMemoryIn *in, + rtMallocHostSharedMemoryOut *out); + +/** + * @ingroup dvrt_mem + * @brief free host memory + * @param [in] in free host shared memory inputPara pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ + +RTS_API rtError_t rtFreeHostSharedMemory(rtFreeHostSharedMemoryIn *in); + +/** + * @ingroup dvrt_mem + * @brief alloc managed memory + * @param [in|out] ptr memory pointer + * @param [in] size memory size + * @param [in] flag reserved, set to 0. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag); + +/** + * @ingroup dvrt_mem + * @brief free managed memory + * @param [in] ptr memory pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemFreeManaged(void *ptr); +/** + * @ingroup dvrt_mem + * @brief alloc cached device memory + * @param [in| devPtr memory pointer + * @param [in] size memory size + * @param [in] type memory type + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMallocCached(void **devPtr, uint64_t size, rtMemType_t type); + +/** + * @ingroup dvrt_mem + * @brief flush device mempory + * @param [in] base virtal base address + * @param [in] len memory size + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtFlushCache(void *base, size_t len); + +/** + * @ingroup dvrt_mem + * @brief invalid device mempory + * @param [in] base virtal base address + * @param [in] len memory size + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtInvalidCache(void *base, size_t len); + +/** + * @ingroup dvrt_mem + * @brief synchronized memcpy + * @param [in] dst destination address pointer + * @param [in] Max length of destination address memory + * @param [in] src source address pointer + * @param [in] count the number of byte to copy + * @param [in] kind memcpy type + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpy(void *dst, uint64_t destMax, const void *src, uint64_t count, rtMemcpyKind_t kind); + +/** + * @ingroup dvrt_mem + * @brief asynchronized memcpy + * @param [in] dst destination address pointer + * @param [in] Max length of destination address memory + * @param [in] src source address pointer + * @param [in] count the number of byte to copy + * @param [in] kind memcpy type + * @param [in] stream asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpyAsync(void *dst, uint64_t destMax, const void *src, uint64_t count, rtMemcpyKind_t kind, + rtStream_t stream); + +/** + * @ingroup dvrt_mem + * @brief asynchronized reduce memcpy + * @param [in] dst destination address pointer + * @param [in] Max length of destination address memory + * @param [in] src source address pointer + * @param [in] count the number of byte to copy + * @param [in] kind memcpy type + * @param [in] type data type + * @param [in] stream asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtReduceAsync(void *dst, uint64_t destMax, const void *src, uint64_t count, rtRecudeKind_t kind, + rtDataType_t type, rtStream_t stream); + +/** + * @ingroup dvrt_mem + * @brief query memory size + * @param [in] aiCoreMemorySize + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtAiCoreMemorySizes(rtAiCoreMemorySize_t *aiCoreMemorySize); + +/** + * @ingroup dvrt_mem + * @brief set memory size, Setting before model reasoning, Bright screen to prevent model can not be fully + integrated network due to memory limitations.Requirement come from JiaMinHu.Only use for Tiny. + * @param [in] aiCoreMemorySize + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetAiCoreMemorySizes(rtAiCoreMemorySize_t *aiCoreMemorySize); + +/** + * @ingroup dvrt_mem + * @brief set memory with uint32_t value + * @param [in] devPtr + * @param [in] Max length of destination address memory + * @param [in] value + * @param [in] count byte num + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemset(void *devPtr, uint64_t destMax, uint32_t value, uint64_t count); + +/** + * @ingroup dvrt_mem + * @brief set memory with uint32_t value async + * @param [in] devPtr + * @param [in] Max length of destination address memory + * @param [in] value + * @param [in] count byte num + * @param [in] stream + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream); + +/** + * @ingroup dvrt_mem + * @brief get current device memory total and free + * @param [out] free + * @param [out] total + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemGetInfo(size_t *free, size_t *total); + +/** + * @ingroup dvrt_mem + * @brief get current device memory total and free + * @param [in] memInfoType + * @param [out] free + * @param [out] total + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total); + +/** + * @ingroup dvrt_mem + * @brief set memory with uint32_t value + * @param [in] devPtr + * @param [in] len + * @param [in] device + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemPrefetchToDevice(void *devPtr, uint64_t len, int32_t device); + +/** + * @ingroup dvrt_mem + * @brief get memory attribute:Host or Device + * @param [in] ptr + * @param [out] attributes + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtPointerGetAttributes(rtPointerAttributes_t *attributes, const void *ptr); + +/** + * @ingroup dvrt_mem + * @brief make memory shared interprocess and assigned a name + * @param [in] ptr device memory address pointer + * @param [in] name identification name + * @param [in] byteCount identification byteCount + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtIpcSetMemoryName(const void *ptr, uint64_t byteCount, char *name, uint32_t len); + +/** + * @ingroup dvrt_mem + * @brief destroy a interprocess shared memory + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtIpcDestroyMemoryName(const char *name); + +/** + * @ingroup dvrt_mem + * @brief open a interprocess shared memory + * @param [in|out] ptr device memory address pointer + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtIpcOpenMemory(void **ptr, const char *name); + +/** + * @ingroup dvrt_mem + * @brief close a interprocess shared memory + * @param [in] ptr device memory address pointer + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtIpcCloseMemory(const void *ptr); + +/** + * @ingroup dvrt_mem + * @brief HCCL Async memory cpy + * @param [in] index sq index + * @param [in] wqeIndex moudle index + * @param [in] stream asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtRDMASend(uint32_t index, uint32_t wqeIndex, rtStream_t stream); + +/** + * @ingroup dvrt_mem + * @brief Ipc set mem pid + * @param [in] name name to be queried + * @param [in] pid process id + * @param [in] num length of pid[] + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtSetIpcMemPid(const char *name, int32_t pid[], int num); + +/** + * @ingroup dvrt_mem + * @brief HCCL Async memory cpy + * @param [in] dbindex single device 0 + * @param [in] dbinfo doorbell info + * @param [in] stream asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtRDMADBSend(uint32_t dbIndex, uint64_t dbInfo, rtStream_t stream); + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif + +#endif // __CCE_RUNTIME_MEM_H__ diff --git a/metadef/third_party/fwkacllib/inc/runtime/rt.h b/metadef/third_party/fwkacllib/inc/runtime/rt.h new file mode 100644 index 00000000..83cafa3c --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/rt.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_RT_H__ +#define __CCE_RUNTIME_RT_H__ + +#include "base.h" +#include "config.h" +#include "context.h" +#include "dev.h" +#include "dvfsprofile.h" +#include "event.h" +#include "kernel.h" +#include "mem.h" +#include "rt_model.h" +#include "stream.h" + +#endif // __CCE_RUNTIME_RT_H__ diff --git a/metadef/third_party/fwkacllib/inc/runtime/rt_model.h b/metadef/third_party/fwkacllib/inc/runtime/rt_model.h new file mode 100644 index 00000000..b72b142d --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/rt_model.h @@ -0,0 +1,457 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_MODEL_H__ +#define __CCE_RUNTIME_MODEL_H__ + +#include "base.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +typedef enum tagModelTaskType { + RT_MODEL_TASK_KERNEL = 0, + RT_MODEL_TASK_EVENT_RECORD, + RT_MODEL_TASK_EVENT_WAIT, + RT_MODEL_TASK_FUSION_START, + RT_MODEL_TASK_FUSION_END, + RT_MODEL_TASK_KERNEL_EX, + RT_MODEL_TASK_HCCL, + RT_MODEL_TASK_STREAM_SWITCH, + RT_MODEL_TASK_STREAM_ACTIVE, + RT_MODEL_TASK_LABEL_SET, + RT_MODEL_TASK_LABEL_SWITCH, + RT_MODEL_TASK_LABEL_GOTO, + RT_MODEL_TASK_PROFILER_TRACE, + RT_MODEL_TASK_MEMCPY_ASYNC, + RT_MODEL_TASK_NOTIFY_RECORD, + RT_MODEL_TASK_NOTIFY_WAIT, + RT_MODEL_TASK_REDUCE_ASYNC, + RT_MODEL_TASK_RDMA_SEND, + RT_MODEL_TASK_EVENT_RESET = 18, + RT_MODEL_TASK_MODEL_END_GRAPH, + RT_MODEL_TASK_STREAM_SWITCH_N, + RT_MODEL_TASK_RDMA_DB_SEND, + RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, + RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX, + RT_MODEL_TASK_STREAM_LABEL_GOTO, + RT_MODEL_TASK_MODEL_EXIT, +} rtModelTaskType_t; + +typedef enum tagModelStreamType { + RT_MODEL_HEAD_STREAM = 0, + RT_MODEL_WAIT_ACTIVE_STREAM = 1 +} rtModelStreamType_t; + +typedef enum tagModelQueueFlag { + RT_MODEL_INPUT_QUEUE = 0, + RT_MODEL_OUTPUT_QUEUE = 1 +} rtModelQueueFlag_t; + +#define EXECUTOR_NONE ((uint32_t)0x0) +#define EXECUTOR_TS ((uint32_t)0x01) +#define EXECUTOR_AICPU ((uint32_t)0x02) + +/* + * @ingroup rt_model + * @brief debug flag for kernel exception dump + */ +#define RT_DEBUG_FLAG_AICORE_OVERFLOW (0x1 << 0) +#define RT_DEBUG_FLAG_ATOMIC_ADD_OVERFLOW (0x1 << 1) + +/** + * @ingroup + * @brief the type defination of aicpu model task command + */ +typedef enum tagTsAicpuModelCmd { + TS_AICPU_MODEL_LOAD = 1, + TS_AICPU_MODEL_EXECUTE, + TS_AICPU_MODEL_DESTROY, + TS_AICPU_MODEL_ABORT, + TS_AICPU_MODEL_RESERVED, +} tsAicpuModelCmd; + +typedef struct tagAicpuTaskInfo { + uint32_t taskID; + uint32_t streamID; + uint32_t kernelType; + uint64_t kernelName; + uint64_t kernelSo; + uint64_t paraBase; + uint32_t taskFlag; +} rtAicpuTaskInfo_t; + +typedef struct tagModelStreamInfo { + uint32_t streamID; + uint32_t streamFlag; +} rtModelStreamInfo_t; + +typedef struct tagModelQueueInfo { + uint32_t queueID; + uint32_t flag; +} rtModelQueueInfo_t; + +typedef struct tagAicpuModelInfo { + uint32_t moduleID; + uint32_t tsId; + uint16_t streamInfoNum; + uint16_t aicpuTaskNum; + uint64_t streamInfoPtr; + uint64_t aicpuTaskPtr; + uint16_t queueSize; + uint64_t queueInfoPtr; +} rtAicpuModelInfo_t; + +typedef struct tagKernelTaskInfo { + uint16_t blockDim; + uint16_t argsCount; + uint16_t argsSize; + uint16_t reserved; + char *stubFunc; + uint8_t *smDesc; + uint8_t *args; + uint16_t *argsOffset; +} rtKernelTaskInfo_t; + +typedef struct tagKernelTaskInfoEx { + uint32_t flags; + uint32_t argsSize; + void *args; + uint32_t reserved[6]; +} rtKernelTaskInfoEx_t; + +typedef struct tagEventTaskInfo { + uint32_t eventID; + uint32_t reserved[9]; +} rtEventTaskInfo_t; + +typedef struct tagStreamSwitchTaskInfo { + int64_t value; + uint64_t pValuePtr; + uint32_t trueStreamID; + uint32_t dataType; + uint32_t reserved[4]; +} rtStreamSwitchTaskInfo_t; + +typedef struct tagStreamSwitchNTaskInfo { + uint64_t pValuePtr; + uint64_t pTrueStreamPtr; + uint32_t size; + uint32_t elementSize; + uint32_t dataType; + uint32_t reserved[3]; +} rtStreamSwitchNTaskInfo_t; + +typedef struct tagStreamActiveTaskInfo { + uint32_t activeStreamID; + uint32_t reserved[9]; +} rtStreamActiveTaskInfo_t; + +typedef struct tagSetTaskInfo { + uint16_t labelId; + uint32_t reserved[9]; +} rtLabelSetTaskInfo_t; + +typedef struct tagSwitchTaskInfo { + uint32_t value; + uint32_t reserved[9]; +} rtLabelSwitchTaskInfo_t; + +typedef struct tagLabelGotoTaskInfo { + uint16_t labelId; + uint32_t reserved[9]; +} rtLabelGotoTaskInfo_t; + +typedef struct tagProfilerTraceTaskInfo { + uint64_t profilerTraceId; + uint32_t notify : 8; + uint32_t reserved_ : 24; + uint32_t flags; + uint32_t reserved[6]; +} rtProfilerTrace_t; + +typedef struct tagrtMemcpyAsyncTaskInfo { + void *dst; + uint64_t destMax; + void *src; + uint64_t count; + uint32_t kind; + uint32_t reserved; +} rtMemcpyAsyncTaskInfo_t; + +typedef struct tagrtNotifyTaskInfo { + uint32_t notifyID; + uint32_t reserved[9]; +} rtNotifyTaskInfo_t; + +typedef struct tagrtReduceAsyncTaskInfo { + void *dst; + uint64_t destMax; + void *src; + uint64_t count; + uint32_t kind; + uint32_t type; +} rtReduceAsyncTaskInfo_t; + +typedef struct tagrtRdmaSendTaskInfo { + uint32_t index; + uint32_t wqe_index; + uint32_t reserved[8]; +} rtRdmaSendTaskInfo_t; + +typedef struct tagrtRdmaDbSendTaskInfo { + uint64_t dbInfo; + uint32_t dbIndex; + uint32_t reserved[7]; // offset 7 +} rtRdmaDbSendTaskInfo_t; + +typedef struct tagrtModelEndGraphTaskInfo { + uint32_t modelId; + uint32_t executorFlag; + uint32_t reserved[8]; +} rtModelEndGraphTaskInfo_t; + +typedef struct tagrtModelExitInfo { + uint32_t modelId; + uint32_t streamId; + uint32_t reserved[8]; +} rtModelExitTaskInfo_t; + + +typedef struct tagrtStreamLabelSwitchByIndexTask_t { + uint64_t indexPtr; + uint64_t labelInfoPtr; + uint32_t max; + uint8_t reserved[20]; +} rtStreamLabelSwitchByIndexTask_t; + +typedef struct tagrtStreamLabelGotoTask_t { + uint16_t labelId; + uint16_t modelId; + uint8_t reserved[36]; +} rtStreamLabelGotoTask_t; + +typedef struct tagTaskInfo { + uint32_t type; + uint32_t streamID; + union { + rtKernelTaskInfoEx_t kernelTaskEx; + rtKernelTaskInfo_t kernelTask; + rtEventTaskInfo_t eventTask; + rtStreamSwitchTaskInfo_t streamSwitchTask; + rtStreamActiveTaskInfo_t streamActiveTask; + rtLabelSetTaskInfo_t labelSetTask; + rtLabelSwitchTaskInfo_t labelSwitchTask; + rtLabelGotoTaskInfo_t labelGotoTask; + rtProfilerTrace_t profilertraceTask; + rtMemcpyAsyncTaskInfo_t memcpyAsyncTask; + rtNotifyTaskInfo_t notifyTask; + rtReduceAsyncTaskInfo_t reduceAsyncTask; + rtRdmaSendTaskInfo_t rdmaSendTask; + rtRdmaDbSendTaskInfo_t rdmaDbSendTask; + rtModelEndGraphTaskInfo_t modelEndGraphTask; + rtModelExitTaskInfo_t modelExitTask; + rtStreamSwitchNTaskInfo_t streamSwitchNTask; + rtStreamLabelSwitchByIndexTask_t streamLabelSwitchIndexTask; + rtStreamLabelGotoTask_t streamLabelGotoTask; + uint32_t reserved[10]; + } u; +} rtTaskInfo_t; + +typedef struct tagLabelDevInfo_t { + uint16_t modelId; + uint16_t streamId; + uint16_t labelId; +}rtLabelDevInfo; + +typedef void *rtModel_t; +typedef rtError_t (*rtTaskGenCallback)(rtModel_t model, rtTaskInfo_t *taskInfo); + +/** + * @ingroup rt_model + * @brief set callback for generate model + * @param [in] callBack callback function + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetTaskGenCallback(rtTaskGenCallback callback); + +/** + * @ingroup rt_model + * @brief create model instance + * @param [out] model created model + * @param [in] flag reserved + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelCreate(rtModel_t *model, uint32_t flag); + +/** + * @ingroup rt_model + * @brief destroy model instance + * @param [in] model model to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelDestroy(rtModel_t model); + +/** + * @ingroup rt_model + * @brief bind model and stream instance + * @param [in] model binded model + * @param [in] stream binded stream + * @param [in] flag reserved + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelBindStream(rtModel_t model, rtStream_t stream, uint32_t flag); + +/** + * @ingroup rt_model + * @brief unbind model and stream instance + * @param [in] model unbinded model + * @param [in] stream unbinded stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelUnbindStream(rtModel_t model, rtStream_t stream); + +/** + * @ingroup rt_model + * @brief tell runtime Model has been Loaded + * @param [in] model model to execute + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtModelLoadComplete(rtModel_t model); + +/** + * @ingroup rt_model + * @brief execute model instance + * @param [in] model model to execute + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelExecute(rtModel_t model, rtStream_t stream, uint32_t flag); + +/** + * @ingroup rt_model + * @brief get model the last persist task id + * @param [in] model model to execute + * @param [out] taskid last task id of the model + * @param [out] streamid last steam id of the model + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelGetTaskId(rtModel_t model, uint32_t *taskid, uint32_t *streamid); + +/** + * @ingroup rt_model + * @brief add a end graph task to stream + * @param [in] model model to execute + * @param [in] end graph stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEndGraph(rtModel_t model, rtStream_t stream); + +/** + * @ingroup rt_model + * @brief add a end graph task with flag to stream + * @param [in] model model to execute + * @param [in] end graph stream + * @param [in] flags AICPU datadump + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEndGraphEx(rtModel_t model, rtStream_t stream, uint32_t flags); + +/** + * @ingroup rt_model + * @brief add a end graph task to stream + * @param [in] model model to execute + * @param [in] flags EXECUTOR_TS | EXECUTOR_AICPU + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelExecutorSet(rtModel_t model, uint8_t flags); + +/** + * @ingroup rt_model + * @brief abort model + * @param [in] model model to abort + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelAbort(rtModel_t model); + +/** + * @ingroup rt_model + * @brief end graph task to model default stream + * @param [in] model model to execute + * @param [in] end graph stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelExit(rtModel_t model, rtStream_t stream); + +/** + * @ingroup rt_model + * @brief bind queue + * @param [in] model model to bind + * @param [in] queueId queueId to bind + * @param [in] flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelBindQueue(rtModel_t model, uint32_t queueId, rtModelQueueFlag_t flag); + +/** + * @ingroup rt_model + * @brief get model id + * @param [in] model + * @param [out] modelId model id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelGetId(rtModel_t model, uint32_t *modelId); + +/* + * @ingroup rt_model + * @brief enable debug for dump overflow exception + * @param [in] addr: ddr address of kernel exception dumpped + * @param [in] model: model handle + * @param [in] flag: debug flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDebugRegister(rtModel_t model, uint32_t flag, const void *addr, + uint32_t *streamId, uint32_t *taskId); + +/* + * @ingroup rt_model + * @brief disable debug for dump overflow exception + * @param [in] model: model handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDebugUnRegister(rtModel_t model); + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif + +#endif // __CCE_RUNTIME_MODEL_H__ diff --git a/metadef/third_party/fwkacllib/inc/runtime/stream.h b/metadef/third_party/fwkacllib/inc/runtime/stream.h new file mode 100644 index 00000000..388fd3c2 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/runtime/stream.h @@ -0,0 +1,195 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef __CCE_RUNTIME_STREAM_H__ +#define __CCE_RUNTIME_STREAM_H__ + +#include "base.h" +#include "event.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +/** + * @ingroup stream_flags + * @brief stream op bit flags + */ +#define RT_STREAM_DEFAULT (0x00) +#define RT_STREAM_PERSISTENT (0x01) +#define RT_STREAM_FORCE_COPY (0x02) +#define RT_STREAM_HUGE (0x04) +#define RT_STREAM_AICPU (0x08) +#define RT_STREAM_FORBIDDEN_DEFAULT (0x10) +#define RT_STREAM_HEAD (0x20) +#define RT_STREAM_PRIMARY_DEFAULT (0x40) + +/** + * @ingroup stream_type + * @brief stream type + */ +#define RT_NORMAL_STREAM (0x00) +#define RT_HUGE_STREAM (0x01) + +/** + * priority level default value when create a stream + */ +#define RT_STREAM_PRIORITY_DEFAULT (0) + +/** + * @ingroup dvrt_stream + * @brief create stream instance + * @param [in|out] stream created stream + * @param [in] priority stream priority + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamCreate(rtStream_t *stream, int32_t priority); + +/** + * @ingroup dvrt_stream + * @brief create stream instance + * @param [in|out] stream created stream + * @param [in] priority stream priority + * @param [in] flags stream op flags + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamCreateWithFlags(rtStream_t *stream, int32_t priority, uint32_t flags); + +/** + * @ingroup dvrt_stream + * @brief destroy stream instance. + * @param [in] stream the stream to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamDestroy(rtStream_t stream); + +/** + * @ingroup dvrt_stream + * @brief wait an recorded event for stream + * @param [in] stream the wait stream + * @param [in] event the event to wait + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamWaitEvent(rtStream_t stream, rtEvent_t event); + +/** + * @ingroup dvrt_stream + * @brief wait stream to be complete + * @param [in] stream stream to wait + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamSynchronize(rtStream_t stream); + +/** + * @ingroup dvrt_stream + * @brief queries an asynchronous stream for completion status + * @param [in] stream stream to query + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_STREAM_NOT_COMPLETE for not complete + */ +RTS_API rtError_t rtStreamQuery(rtStream_t stream); + +/** + * @ingroup dvrt_stream + * @brief get stream id from a stream handle + * @param [in] stream stream hadle + * @param [in] streamId stream id + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetStreamId(rtStream_t stream, int32_t *streamId); + +/** + * @ingroup dvrt_stream + * @brief inquire max stream count and max task count per stream + * @param [in] streamType Stream Type + * @param [in] MaxStrCount Max stream count + * @param [in] MaxTaskCount max task count per stream + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetMaxStreamAndTask(uint32_t streamType, uint32_t *maxStrCount, uint32_t *maxTaskCount); + +/** + * @ingroup dvrt_stream + * @brief Name a stream + * @param [in] stream stream to be named + * @param [in] name identification name + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNameStream(rtStream_t stream, const char *name); + +/** + * @ingroup dvrt_stream + * @brief switch to the corresponding stream according to the contents of the ptr + * @param [in] ptr Determine the address where the value of the true and false branches is located + * @param [in] condition switch condition + * @param [in] value switch value + * @param [in] trueStream Stream that needs to be activated when the value is non-zero + * @param [in] stream input stream to init task + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamSwitch(void *ptr, rtCondition_t condition, int64_t value, rtStream_t trueStream, + rtStream_t stream); + +/** + * @brief execute extensible stream switch task + * @param [in] ptr pointer of value + * @param [in] condition judge condition + * @param [in] value_ptr pointer of target value + * @param [in] true_stream stream to be activated when value is not zero + * @param [in] stream stream id + * @param [in] dataType data type of target value + * @return RT_ERROR_NONE for complete + */ +RTS_API rtError_t rtStreamSwitchEx(void *ptr, rtCondition_t condition, void *valuePtr, rtStream_t trueStream, + rtStream_t stream, rtSwitchDataType_t dataType); + +/** + * @ingroup dvrt_stream + * @brief Active a stream + * @param [in] activeStream stream to be activated + * @param [in] stream input stream to init task + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamActive(rtStream_t activeStream, rtStream_t stream); + +/** + * @brief execute extensible stream case switch task + * @param [in] ptr pointer of value + * @param [in] size pointer num of value + * @param [in] valuePtr pointer of target value, length = size * elementSize + * @param [in] trueStreamPtr streams to be activated + * @param [in] elementSize size of to be activated true streams + * @param [in] stream input stream to init task + * @param [in] dataType data type of target value + * @return RT_ERROR_NONE for complete + */ +RTS_API rtError_t rtStreamSwitchN(void *ptr, uint32_t size, void *valuePtr, rtStream_t *trueStreamPtr, + uint32_t elementSize, rtStream_t stream, rtSwitchDataType_t dataType); +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif + +#endif // __CCE_RUNTIME_STREAM_H__ diff --git a/metadef/third_party/fwkacllib/inc/toolchain/adx_datadump_server.h b/metadef/third_party/fwkacllib/inc/toolchain/adx_datadump_server.h new file mode 100644 index 00000000..a1c39a51 --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/toolchain/adx_datadump_server.h @@ -0,0 +1,36 @@ +/** +* @file adx_datadump_server.h +* +* Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +*/ + +#ifndef ADX_DATADUMP_SERVER_H +#define ADX_DATADUMP_SERVER_H +#ifdef __cplusplus +extern "C" { +#endif +/** + * @brief initialize server for normal datadump function. + * @return + * IDE_DAEMON_OK: datadump server init success + * IDE_DAEMON_ERROR: datadump server init failed + */ +int AdxDataDumpServerInit(); + +/** + * @brief uninitialize server for normal datadump function. + * @return + * IDE_DAEMON_OK: datadump server uninit success + * IDE_DAEMON_ERROR: datadump server uninit failed + */ +int AdxDataDumpServerUnInit(); + +#ifdef __cplusplus +} +#endif +#endif + diff --git a/metadef/third_party/fwkacllib/inc/toolchain/prof_callback.h b/metadef/third_party/fwkacllib/inc/toolchain/prof_callback.h new file mode 100644 index 00000000..3fad74bc --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/toolchain/prof_callback.h @@ -0,0 +1,135 @@ +/** + * Copyright 2020-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file prof_callback.h + * @brief declaraion of profiling callbacks + */ + +#ifndef MSPROFILER_PROF_CALLBACK_H_ +#define MSPROFILER_PROF_CALLBACK_H_ + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + + +#include "stddef.h" +#include "stdint.h" + +/** + * @name MsprofErrorCode + * @brief error code + */ +enum MsprofErrorCode { + MSPROF_ERROR_NONE = 0, + MSPROF_ERROR_MEM_NOT_ENOUGH, + MSPROF_ERROR_GET_ENV, + MSPROF_ERROR_CONFIG_INVALID, + MSPROF_ERROR_ACL_JSON_OFF, + MSPROF_ERROR, +}; + +#define MSPROF_ENGINE_MAX_TAG_LEN (31) + +/** + * @name ReporterData + * @brief struct of data to report + */ +struct ReporterData { + char tag[MSPROF_ENGINE_MAX_TAG_LEN + 1]; // the sub-type of the module, data with different tag will be writen + int deviceId; // the index of device + size_t dataLen; // the length of send data + unsigned char *data; // the data content +}; + +/** + * @name MsprofReporterModuleId + * @brief module id of data to report + */ +enum MsprofReporterModuleId { + MSPROF_MODULE_DATA_PREPROCESS = 0, // DATA_PREPROCESS + MSPROF_MODULE_HCCL, // HCCL + MSPROF_MODULE_ACL, // AclModule + MSPROF_MODULE_FRAMEWORK, // Framework + MSPROF_MODULE_RUNTIME // runtime +}; + +/** + * @name MsprofReporterCallbackType + * @brief reporter callback request type + */ +enum MsprofReporterCallbackType { + MSPROF_REPORTER_REPORT = 0, // report data + MSPROF_REPORTER_INIT, // init reporter + MSPROF_REPORTER_UNINIT, // uninit reporter +}; + +/** + * @name MsprofReporterCallback + * @brief callback to start reporter/stop reporter/report date + * @param moduleId [IN] enum MsprofReporterModuleId + * @param type [IN] enum MsprofReporterCallbackType + * @param data [IN] callback data (nullptr on INTI/UNINIT) + * @param len [IN] callback data size (0 on INIT/UNINIT) + * @return enum MsprofErrorCode + */ +typedef int32_t (*MsprofReporterCallback)(uint32_t moduleId, uint32_t type, void *data, uint32_t len); + + +#define MSPROF_OPTIONS_DEF_LEN_MAX (2048) + +/** + * @name MsprofGeOptions + * @brief struct of MSPROF_CTRL_INIT_GE_OPTIONS + */ +struct MsprofGeOptions { + char jobId[MSPROF_OPTIONS_DEF_LEN_MAX]; + char options[MSPROF_OPTIONS_DEF_LEN_MAX]; +}; + +/** + * @name MsprofCtrlCallbackType + * @brief ctrl callback request type + */ +enum MsprofCtrlCallbackType { + MSPROF_CTRL_INIT_ACL_ENV = 0, // start profiling with acl env + MSPROF_CTRL_INIT_ACL_JSON, // start profiling with acl.json + MSPROF_CTRL_INIT_GE_OPTIONS, // start profiling with ge env and options + MSPROF_CTRL_FINALIZE // stop profiling +}; + +/** + * @name MsprofCtrlCallback + * @brief callback to start/stop profiling + * @param type [IN] enum MsprofCtrlCallbackType + * @param data [IN] callback data + * @param len [IN] callback data size + * @return enum MsprofErrorCode + */ +typedef int32_t (*MsprofCtrlCallback)(uint32_t type, void *data, uint32_t len); + +/** + * @name MsprofSetDeviceCallback + * @brief callback to notify set/reset device + * @param devId [IN] device id + * @param isOpenDevice [IN] true: set device, false: reset device + */ +typedef void (*MsprofSetDeviceCallback)(uint32_t devId, bool isOpenDevice); + +#ifdef __cplusplus +} +#endif + +#endif // MSPROFILER_PROF_CALLBACK_H_ diff --git a/metadef/third_party/fwkacllib/inc/toolchain/prof_engine.h b/metadef/third_party/fwkacllib/inc/toolchain/prof_engine.h new file mode 100644 index 00000000..0e757dcf --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/toolchain/prof_engine.h @@ -0,0 +1,207 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MSPROF_ENGINE_PROF_ENGINE_H_ +#define MSPROF_ENGINE_PROF_ENGINE_H_ +#define MSVP_PROF_API __attribute__((visibility("default"))) + +#include +#include +#include "prof_reporter.h" + +/** + * @file prof_engine.h + * @defgroup ModuleJobConfig the ModuleJobConfig group + * This is the ModuleJobConfig group + */ +namespace Msprof { +namespace Engine { +/** + * @ingroup ModuleJobConfig + * @brief struct ModuleJobConfig + * record config info + */ +struct ModuleJobConfig { + std::map switches; /**< key is the config name, value is the config value(on or off) */ +}; + +/** + * @defgroup PluginIntf the pluginInf group + * This is the pluginInf group + */ + +/** + * @ingroup PluginIntf + * @brief class PluginIntf + */ +class MSVP_PROF_API PluginIntf { + public: + virtual ~PluginIntf() {} + + public: + /** + * @ingroup PluginIntf + * @name : Init + * @brief : API of user plugin, libmsporf call this API to send a Reporter to user plugin + * @par description : + * API of user plugin, libmsporf call this API to send a Reporter to user plugin. + * @param reporter [IN] const Reporter* the Reporter from libmsprof + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see UnInit + */ + virtual int Init(const Reporter *reporter) = 0; + + /** + * @ingroup PluginIntf + * @name : OnNewConfig + * @brief : API of user plugin, libmsprof call this API to send config info to user plugin \n + If the user plugin needn't config, no need to redefine this function + * @param config [IN] const ModuleJobConfig * the config from libmsprof + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see Init | UnInit + */ + virtual int OnNewConfig(const ModuleJobConfig *config) { return 0; } + + /** + * @ingroup PluginIntf + * @name : UnInit + * @brief : API of user plugin, libmsprof call this API to notify plugin stop to send data + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see Init + */ + virtual int UnInit() = 0; +}; + +/** + * @defgroup EngineIntf the EngineIntf group + * This is the EngineIntf group + */ + +/** + * @ingroup EngineIntf + * @brief class EngineIntf + */ +class MSVP_PROF_API EngineIntf { + public: + virtual ~EngineIntf() {} + + public: + /** + * @ingroup EngineIntf + * @name : CreatePlugin + * @brief : API of user engine, libmsporf call this API to get a plugin + * @retval PluginIntf * The pointer of the new plugin + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see ReleasePlugin + */ + virtual PluginIntf *CreatePlugin() = 0; + + /** + * @ingroup EngineIntf + * @name : ReleasePlugin + * @brief : API of user engine, libmsprof call this API to release a plugin + * @param plugin [IN] PluginIntf * the plugin to release + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see CreatePlugin + */ + virtual int ReleasePlugin(PluginIntf *plugin) = 0; +}; + +/** + * @defgroup EngineMgr the EngineMgr group + * This is the EngineMgr group + */ + +/** + * @ingroup EngineMgr + * @name : RegisterEngine + * @brief : API of libmsprof, register an engine with a name + * @param module [IN] const std::string the name of plugin + * @param engine [IN] const EngineIntf* the plugin + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + */ +MSVP_PROF_API int RegisterEngine(const std::string &module, const EngineIntf *engine); + +/** + * @ingroup EngineMgr + * @name : Init + * @brief : API of libmsprof, init an engine with a name + * @param module [IN] const std::string the name of plugin + * @param module [IN] const EngineIntf* the plugin + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see UnInit + */ +MSVP_PROF_API int Init(const std::string &module, const EngineIntf *engine); + +/** + * @ingroup EngineMgr + * @name : Init + * @brief : API of libmsprof, uninit an engine with a name + * @param module [IN] const std::string the name of plugin + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see Init + */ +MSVP_PROF_API int UnInit(const std::string &module); +} // namespace Engine +} // namespace Msprof + +#endif // MSPROF_ENGINE_PROF_ENGINE_H_ \ No newline at end of file diff --git a/metadef/third_party/fwkacllib/inc/toolchain/prof_mgr_core.h b/metadef/third_party/fwkacllib/inc/toolchain/prof_mgr_core.h new file mode 100644 index 00000000..4f013eef --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/toolchain/prof_mgr_core.h @@ -0,0 +1,84 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MSPROF_ENGINE_PROF_MGR_CORE_H_ +#define MSPROF_ENGINE_PROF_MGR_CORE_H_ +#define MSVP_PROF_API __attribute__((visibility("default"))) + +#include +#include + +/** + * @file prof_mgr_core.h + * @brief : struct ProfMgrCfg + */ +struct ProfMgrCfg { + std::string startCfg; /**< start cfg. json format */ +}; + +/** + * @name : ProfMgrConf + * @brief : struct ProfMgrConf for example [{"ai_core_events":"0xa"}].the vector size means Number of iterations + */ +struct ProfMgrConf { + std::vector conf; /**< for op trace.Ge call this api to get each iteration profiling cfg.json format.*/ +}; + +/** + * @name : ProfMgrStartUP + * @brief : start Profiling task + * @param cfg [IN]ProfMgrCfg cfg : config of start_up profiling + * @retval void * (success) + * @retval nullptr (failed) + * + * @par depend: + * @li libmsprof + * @li prof_mgr_core.h + * @since c60 + * @see ProfMgrStop + */ +MSVP_PROF_API void *ProfMgrStartUp(const ProfMgrCfg *cfg); + +/** + * @name : ProfMgrStop + * @brief : stop Profiling task + * @param handle [in] void * handle return by ProfMgrStartUP + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_mgr_core.h + * @since c60 + * @see ProfMgrStartUp + */ +MSVP_PROF_API int ProfMgrStop(void *handle); + +/** + * @name : ProfMgrGetConf + * @brief : get profiler events conf + * @param conf [OUT]ProfMgrConf * return by ProfMgrGetConf + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * @par depend: + * @li libmsprof + * @li prof_mgr_core.h + * @since c60 + * @see ProfMgrStartUp + */ +MSVP_PROF_API int ProfMgrGetConf(const std::string &aicoreMetricsType, ProfMgrConf *conf); + +#endif // MSPROF_ENGINE_PROF_MGR_CORE_H_ \ No newline at end of file diff --git a/metadef/third_party/fwkacllib/inc/toolchain/prof_reporter.h b/metadef/third_party/fwkacllib/inc/toolchain/prof_reporter.h new file mode 100644 index 00000000..ff91351b --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/toolchain/prof_reporter.h @@ -0,0 +1,85 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MSPROF_ENGINE_PROF_REPORTER_H_ +#define MSPROF_ENGINE_PROF_REPORTER_H_ +#ifndef OS_TYPE +#define OS_TYPE 0 +#endif // OS_TYPE + +#if (OS_TYPE != LINUX) +#define MSVP_PROF_API __declspec(dllexport) +#else +#define MSVP_PROF_API __attribute__((visibility("default"))) +#endif + +#include "prof_callback.h" + +/** + * @file prof_reporter.h + * @defgroup reporter the reporter group + * This is the reporter group + */ +namespace Msprof { +namespace Engine { +/** + * @ingroup reporter + * @brief class Reporter + * the Reporter class .used to send data to profiling + */ +class MSVP_PROF_API Reporter { + public: + virtual ~Reporter() {} + + public: + /** + * @ingroup reporter + * @name : Report + * @brief : API of libmsprof, report data to libmsprof, it's a non-blocking function \n + The data will be firstly appended to cache, if the cache is full, data will be ignored + * @param data [IN] const ReporterData * the data send to libmsporf + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_reporter.h + * @since c60 + * @see Flush + */ + virtual int Report(const ReporterData *data) = 0; + + /** + * @ingroup reporter + * @name : Flush + * @brief : API of libmsprof, notify libmsprof send data over, it's a blocking function \n + The all datas of cache will be write to file or send to host + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_reporter.h + * @since c60 + * @see ProfMgrStop + */ + virtual int Flush() = 0; +}; + +} // namespace Engine +} // namespace Msprof + +#endif // MSPROF_ENGINE_PROF_REPORTER_H_ diff --git a/metadef/third_party/fwkacllib/inc/toolchain/slog.h b/metadef/third_party/fwkacllib/inc/toolchain/slog.h new file mode 100644 index 00000000..5faca0ae --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/toolchain/slog.h @@ -0,0 +1,397 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef D_SYSLOG_H_ +#define D_SYSLOG_H_ + +#ifdef __cplusplus +#ifndef LOG_CPP +extern "C" { +#endif +#endif // __cplusplus + +#ifndef LINUX +#define LINUX 0 +#endif // LINUX + +#ifndef WIN +#define WIN 1 +#endif + +#ifndef OS_TYPE +#define OS_TYPE 0 +#endif // OS_TYPE + +#if (OS_TYPE == LINUX) +#define DLL_EXPORT __attribute__((visibility("default"))) +#else +#define DLL_EXPORT _declspec(dllexport) +#endif + +/** + * @ingroup slog + * + * debug level id + */ +#define DLOG_DEBUG 0 + +/** + * @ingroup slog + * + * info level id + */ +#define DLOG_INFO 1 + +/** + * @ingroup slog + * + * warning level id + */ +#define DLOG_WARN 2 + +/** + * @ingroup slog + * + * error level id + */ +#define DLOG_ERROR 3 + +/** + * @ingroup slog + * + * don't print log + */ +#define DLOG_NULL 4 + +/** + * @ingroup slog + * + * trace log print level id + */ +#define DLOG_TRACE 5 + +/** + * @ingroup slog + * + * oplog log print level id + */ +#define DLOG_OPLOG 6 + +/** + * @ingroup slog + * + * event log print level id + */ +#define DLOG_EVENT 0x10 + +/** + * @ingroup slog + * + * max log length + */ +#define MSG_LENGTH 1024 +#define DEBUG_LOG_MASK (0x00010000) +#define SECURITY_LOG_MASK (0x00100000) +#define RUN_LOG_MASK (0x01000000) +#define OPERATION_LOG_MASK (0x10000000) +#define RESERVERD_LENGTH 52 + +typedef struct tagDCODE { + const char *cName; + int cVal; +} DCODE; + +typedef struct tagKV { + char *kname; + char *value; +} KeyValue; + +typedef enum { + APPLICATION = 0, + SYSTEM +} ProcessType; + +typedef struct { + ProcessType type; + unsigned int pid; + unsigned int deviceId; + char reserved[RESERVERD_LENGTH]; +} LogAttr; + +/** + * @ingroup slog + * + * module id + */ +enum { + SLOG, /**< Slog */ + IDEDD, /**< IDE daemon device */ + IDEDH, /**< IDE daemon host */ + HCCL, /**< HCCL */ + FMK, /**< Framework */ + HIAIENGINE, /**< Matrix */ + DVPP, /**< DVPP */ + RUNTIME, /**< Runtime */ + CCE, /**< CCE */ +#if (OS_TYPE == LINUX) + HDC, /**< HDC */ +#else + HDCL, +#endif // OS_TYPE + DRV, /**< Driver */ + MDCFUSION, /**< Mdc fusion */ + MDCLOCATION, /**< Mdc location */ + MDCPERCEPTION, /**< Mdc perception */ + MDCFSM, + MDCCOMMON, + MDCMONITOR, + MDCBSWP, /**< MDC base software platform */ + MDCDEFAULT, /**< MDC undefine */ + MDCSC, /**< MDC spatial cognition */ + MDCPNC, + MLL, + DEVMM, /**< Dlog memory managent */ + KERNEL, /**< Kernel */ + LIBMEDIA, /**< Libmedia */ + CCECPU, /**< ai cpu */ + ASCENDDK, /**< AscendDK */ + ROS, /**< ROS */ + HCCP, + ROCE, + TEFUSION, + PROFILING, /**< Profiling */ + DP, /**< Data Preprocess */ + APP, /**< User Application */ + TS, /**< TS module */ + TSDUMP, /**< TSDUMP module */ + AICPU, /**< AICPU module */ + LP, /**< LP module */ + TDT, + FE, + MD, + MB, + ME, + IMU, + IMP, + GE, /**< Fmk */ + MDCFUSA, + CAMERA, + ASCENDCL, + TEEOS, + ISP, + SIS, + HSM, + DSS, + PROCMGR, // Process Manager, Base Platform + BBOX, + AIVECTOR, + TBE, + FV, + MDCMAP, + TUNE, + INVLID_MOUDLE_ID +}; + +/** + * @ingroup slog + * @brief External log interface, which called by modules + */ +DLL_EXPORT void dlog_init(void); + +/** + * @ingroup slog + * @brief dlog_getlevel: get module loglevel and enableEvent + * + * @param [in]moduleId: moudule id(see slog.h, eg: CCE), others: invalid + * @param [out]enableEvent: 1: enable; 0: disable + * @return: module level(0: debug, 1: info, 2: warning, 3: error, 4: null output) + */ +DLL_EXPORT int dlog_getlevel(int moduleId, int *enableEvent); + +/** + * @ingroup slog + * @brief dlog_setlevel: set module loglevel and enableEvent + * + * @param [in]moduleId: moudule id(see slog.h, eg: CCE), -1: all modules, others: invalid + * @param [in]level: log level(0: debug, 1: info, 2: warning, 3: error, 4: null output) + * @param [in]enableEvent: 1: enable; 0: disable, others:invalid + * @return: 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int dlog_setlevel(int moduleId, int level, int enableEvent); + +/** + * @ingroup slog + * @brief CheckLogLevel: check module level enable or not + * users no need to call it because all dlog interface(include inner interface) has already called + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]logLevel: eg: DLOG_EVENT/DLOG_ERROR/DLOG_WARN/DLOG_INFO/DLOG_DEBUG + * @return: 1:enable, 0:disable + */ +DLL_EXPORT int CheckLogLevel(int moduleId, int logLevel); + +/** + * @ingroup slog + * @brief DlogSetAttr: set log attr, default pid is 0, default device id is 0, default process type is APPLICATION + * @param [in]logAttr: attr info, include pid(must be larger than 0), process type and device id(chip ID) + * @return: 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int DlogSetAttr(LogAttr logAttr); + +/** + * @ingroup slog + * @brief dlog_error: print error log + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]fmt: log content + */ +#define dlog_error(moduleId, fmt, ...) \ + do { \ + DlogErrorInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } while (0) + +/** + * @ingroup slog + * @brief dlog_warn: print warning log + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]fmt: log content + */ +#define dlog_warn(moduleId, fmt, ...) \ + do { \ + if(CheckLogLevel(moduleId, DLOG_WARN) == 1) { \ + DlogWarnInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (0) + +/** + * @ingroup slog + * @brief dlog_info: print info log + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]fmt: log content + */ +#define dlog_info(moduleId, fmt, ...) \ + do { \ + if(CheckLogLevel(moduleId, DLOG_INFO) == 1) { \ + DlogInfoInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (0) + +/** + * @ingroup slog + * @brief dlog_debug: print debug log + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]fmt: log content + */ +#define dlog_debug(moduleId, fmt, ...) \ + do { \ + if(CheckLogLevel(moduleId, DLOG_DEBUG) == 1) { \ + DlogDebugInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (0) + +/** + * @ingroup slog + * @brief dlog_event: print event log + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]fmt: log content + */ +#define dlog_event(moduleId, fmt, ...) \ + do { \ + DlogEventInner(moduleId, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } while (0) + +/** + * @ingroup slog + * @brief Dlog: print log, need caller to specify level + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 5: trace, 6: oplog, 16: event) + * @param [in]fmt: log content + */ +#define Dlog(moduleId, level, fmt, ...) \ + do { \ + if(CheckLogLevel(moduleId, level) == 1) { \ + DlogInner(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (0) + +/** + * @ingroup slog + * @brief DlogSub: print log, need caller to specify level and submodule + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]submodule: eg: engine + * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 5: trace, 6: oplog, 16: event) + * @param [in]fmt: log content + */ +#define DlogSub(moduleId, submodule, level, fmt, ...) \ + do { \ + if(CheckLogLevel(moduleId, level) == 1) { \ + DlogInner(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ + } \ + } while (0) + +/** + * @ingroup slog + * @brief DlogWithKV: print log, need caller to specify level and other paramters + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 5: trace, 6: oplog, 16: event) + * @param [in]pstKVArray: key-value array + * @param [in]kvNum: key-value element num in array + * @param [in]fmt: log content + */ +#define DlogWithKV(moduleId, level, pstKVArray, kvNum, fmt, ...) \ + do { \ + if(CheckLogLevel(moduleId, level) == 1) { \ + DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (0) + +/** + * @ingroup slog + * @brief DlogFlush: flush log buffer to file + */ +DLL_EXPORT void DlogFlush(void); + +/** + * @ingroup slog + * @brief Internal log interface, other modules are not allowed to call this interface + */ +void DlogErrorInner(int moduleId, const char *fmt, ...); +void DlogWarnInner(int moduleId, const char *fmt, ...); +void DlogInfoInner(int moduleId, const char *fmt, ...); +void DlogDebugInner(int moduleId, const char *fmt, ...); +void DlogEventInner(int moduleId, const char *fmt, ...); +void DlogInner(int moduleId, int level, const char *fmt, ...); +void DlogWithKVInner(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...); + +#ifdef __cplusplus +#ifndef LOG_CPP +} +#endif // LOG_CPP +#endif // __cplusplus +#endif // D_SYSLOG_H_ diff --git a/metadef/third_party/fwkacllib/inc/toolchain/tuning_tool/tune_api.h b/metadef/third_party/fwkacllib/inc/toolchain/tuning_tool/tune_api.h new file mode 100644 index 00000000..12b6aa1e --- /dev/null +++ b/metadef/third_party/fwkacllib/inc/toolchain/tuning_tool/tune_api.h @@ -0,0 +1,72 @@ +/** + * @file tune_api.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2020. All rights reserved.\n + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n + * 描述:mstune调优接口头文件 + */ +/** @defgroup mstune mstune调优接口 */ +#ifndef TUNE_API_H +#define TUNE_API_H +#include +#include +#include +#include "graph/graph.h" +#include "ge/ge_api.h" + +/** + * @ingroup mstune + * + * mstune status + */ +enum MsTuneStatus { + MSTUNE_SUCCESS, /** tune success */ + MSTUNE_FAILED, /** tune failed */ +}; + +// Option key: for train options sets +const std::string MSTUNE_SELF_KEY = "mstune"; +const std::string MSTUNE_GEINIT_KEY = "initialize"; +const std::string MSTUNE_GESESS_KEY = "session"; + +/** + * @ingroup mstune + * @par 描述: 命令行调优 + * + * @attention 无 + * @param option [IN] 调优参数 + * @param msg [OUT] 调优异常下返回信息 + * @retval #MSTUNE_SUCCESS 执行成功 + * @retval #MSTUNE_FAILED 执行失败 + * @par 依赖: + * @li tune_api.cpp:该接口所属的开发包。 + * @li tune_api.h:该接口声明所在的头文件。 + * @see 无 + * @since + */ +MsTuneStatus MsTuning(const std::map &option, std::string &msg); + +/** + * @ingroup mstune + * @par 描述: 梯度调优 + * + * @attention 无 + * @param tuningGraph [IN] 调优图 + * @param dependGraph [IN] 调优依赖图 + * @param session [IN] ge连接会话 + * @param option [IN] 参数集. 包含调优参数及ge参数 + * @retval #MSTUNE_SUCCESS 执行成功 + * @retval #MSTUNE_FAILED 执行失败 + * @par 依赖: + * @li tune_api.cpp:该接口所属的开发包。 + * @li tune_api.h:该接口声明所在的头文件。 + * @see 无 + * @since + */ +extern "C" MsTuneStatus MsTrainTuning(ge::Graph &tuningGraph, std::vector &dependGraph, + ge::Session *session, const std::map> &option); + +#endif diff --git a/metadef/third_party/graphengine/ge/common/singleton.h b/metadef/third_party/graphengine/ge/common/singleton.h new file mode 100644 index 00000000..314e824e --- /dev/null +++ b/metadef/third_party/graphengine/ge/common/singleton.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GE_COMMON_SINGLETON_H_ +#define GE_COMMON_SINGLETON_H_ + +#include + +#define DECLARE_SINGLETON_CLASS(T) friend class Singleton + +namespace ge { +static std::mutex single_mutex_; +// Single thread version single instance template +template +class Singleton { + public: + Singleton(Singleton const &) = delete; + Singleton &operator=(Singleton const &) = delete; + + template + static T *Instance(_Args... args) { + std::lock_guard lock(single_mutex_); + if (instance_ == nullptr) { + // std::nothrow, Nullptr returned when memory request failed + instance_.reset(new (std::nothrow) T(args...)); + } + return instance_.get(); + } + + static void Destroy(void) { instance_.reset(); } + + Singleton() = default; + virtual ~Singleton() = default; + + private: + static std::unique_ptr instance_; +}; + +template +std::unique_ptr Singleton::instance_; +} // namespace ge +#endif // GE_COMMON_SINGLETON_H_ diff --git a/metadef/third_party/graphengine/ge/graph/common/omg_util.h b/metadef/third_party/graphengine/ge/graph/common/omg_util.h new file mode 100644 index 00000000..1f93c92b --- /dev/null +++ b/metadef/third_party/graphengine/ge/graph/common/omg_util.h @@ -0,0 +1,101 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_COMMON_OMG_UTIL_H_ +#define GE_GRAPH_COMMON_OMG_UTIL_H_ + +#include +#include +#include +#include + +#include "common/types.h" +#include "common/util.h" +#include "graph/node.h" + +namespace ge { +/// +/// @brief get the Original Type of FrameworkOp +/// @param [in] node +/// @param [out] type +/// @return Status +/// +Status GetOriginalType(const ge::NodePtr &node, string &type); + +/// +/// @brief set op stream_label +/// @param [in] node +/// @param [in] label +/// @return Status +/// +Status SetStreamLabel(const ge::NodePtr &node, const std::string &label); + +/// +/// @brief set op cycle_event flag +/// @param [in] node +/// @return Status +/// +Status SetCycleEvent(const ge::NodePtr &node); + +/// +/// @brief set op active_label_list +/// @param [in] node +/// @param [in] label +/// @return Status +/// +Status SetActiveLabelList(const ge::NodePtr &node, const std::vector &active_label_list); + +/// +/// @brief set op branch_label +/// @param [in] node +/// @param [in] branch_label +/// @return Status +/// +Status SetSwitchBranchNodeLabel(const ge::NodePtr &node, const std::string &branch_label); + +/// +/// @brief set op true_branch flag +/// @param [in] node +/// @param [in] value +/// @return Status +/// +Status SetSwitchTrueBranchFlag(const ge::NodePtr &node, bool value); + +/// +/// @brief set op original name +/// @param [in] node +/// @param [in] orig_name +/// @return Status +/// +Status SetOriginalNodeName(const ge::NodePtr &node, const std::string &orig_name); + +/// +/// @brief set op cyclic_dependence flag +/// @param [in] node +/// @return Status +/// +Status SetCyclicDependenceFlag(const ge::NodePtr &node); + +/// +/// @brief set op next_iteration name +/// @param [in] node +/// @param [in] next +/// @return Status +/// +Status SetNextIteration(const ge::NodePtr &node, const std::string &next); +} // namespace ge + +#endif // GE_GRAPH_COMMON_OMG_UTIL_H_ diff --git a/metadef/third_party/graphengine/ge/graph/optimize/common/params.h b/metadef/third_party/graphengine/ge/graph/optimize/common/params.h new file mode 100644 index 00000000..c174a4d1 --- /dev/null +++ b/metadef/third_party/graphengine/ge/graph/optimize/common/params.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_OPTIMIZE_COMMON_PARAMS_H_ +#define GE_GRAPH_OPTIMIZE_COMMON_PARAMS_H_ + +#include + +#include "common/singleton.h" +#include "common/types.h" + +namespace ge { +class Params : public Singleton { + public: + DECLARE_SINGLETON_CLASS(Params); + + void SetTarget(const char* target) { + std::string tmp_target = (target != nullptr) ? target : ""; + +#if defined(__ANDROID__) || defined(ANDROID) + target_ = "LITE"; + target_8bit_ = TARGET_TYPE_LTTE_8BIT; +#else + target_ = "MINI"; + target_8bit_ = TARGET_TYPE_MINI_8BIT; +#endif + if (tmp_target == "mini") { + target_ = "MINI"; + target_8bit_ = TARGET_TYPE_MINI_8BIT; + } else if (tmp_target == "lite") { + target_ = "LITE"; + target_8bit_ = TARGET_TYPE_LTTE_8BIT; + } + } + + string GetTarget() const { return target_; } + + uint8_t GetTarget_8bit() const { return target_8bit_; } + ~Params() override = default; + + private: + Params() : target_("MINI") {} + + string target_; + uint8_t target_8bit_ = 0; +}; +} // namespace ge + +#endif // GE_GRAPH_OPTIMIZE_COMMON_PARAMS_H_ diff --git a/metadef/third_party/graphengine/ge/graph/passes/variable_format_pass.h b/metadef/third_party/graphengine/ge/graph/passes/variable_format_pass.h new file mode 100644 index 00000000..e2c32903 --- /dev/null +++ b/metadef/third_party/graphengine/ge/graph/passes/variable_format_pass.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ +#define GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ + +#include +#include +#include +#include "graph/types.h" +#include "graph/utils/op_desc_utils.h" +#include "inc/graph_pass.h" + +namespace ge { +class VariableFormatPass : public GraphPass { + public: + Status Run(ge::ComputeGraphPtr graph) override; + + private: + bool GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node); + + bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, + const map > &confirm_ops, ge::NodePtr &use_node); + + Status UpdateApplyMomentumInputFormat(const ge::NodePtr &node); + + Status UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node); +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ diff --git a/metadef/third_party/graphengine/ge/inc/graph_pass.h b/metadef/third_party/graphengine/ge/inc/graph_pass.h new file mode 100644 index 00000000..642b94ea --- /dev/null +++ b/metadef/third_party/graphengine/ge/inc/graph_pass.h @@ -0,0 +1,93 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_INC_GRAPH_PASS_H_ +#define GE_INC_GRAPH_PASS_H_ + +#include +#include + +#include "common/op/attr_value_util.h" +#include "common/op/ge_op_utils.h" +#include "common/types.h" +#include "framework/common/debug/ge_log.h" +#include "graph/compute_graph.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "inc/pass.h" + +namespace ge { +/// +/// @ingroup domi_omg +/// @brief graph pass +/// @author +/// +class GraphPass : public Pass { + public: + /// + /// run graph pass + /// @param [in] graph graph to be optimized + /// @return SUCCESS optimize successfully + /// @return NOT_CHANGED not optimized + /// @return others optimized failed + /// @author + /// + virtual Status Run(ge::ComputeGraphPtr graph) = 0; + virtual Status ClearStatus() { return SUCCESS; }; + static void RecordOriginalNames(std::vector original_nodes, const ge::NodePtr &node) { + GE_CHECK_NOTNULL_JUST_RETURN(node); + std::vector original_names; + for (ge::NodePtr &node_tmp : original_nodes) { + std::vector names_tmp; + ge::OpDescPtr opdesc_tmp = node_tmp->GetOpDesc(); + GE_CHECK_NOTNULL_JUST_RETURN(opdesc_tmp); + Status ret = ge::AttrUtils::GetListStr(opdesc_tmp, "_datadump_original_op_names", names_tmp); + if (ret != domi::SUCCESS) { + GELOGW("get the original_op_names fail."); + } + if (names_tmp.size() != 0) { + original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); + } else { + original_names.push_back(opdesc_tmp->GetName()); + } + } + + if (original_names.size() == 0) { + std::string tmp; + original_names.push_back(tmp); + } + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), "_datadump_original_op_names", original_names), + return, "Set original_op_names fail."); + } + + static bool IsConstNode(const ge::NodePtr &node) { + GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, GELOGE(FAILED, "Node GetOpDesc is nullptr"); return false); + if (node->GetOpDesc()->GetType() == CONSTANTOP) { + return true; + } else if (node->GetOpDesc()->GetType() == FRAMEWORKOP) { + string type; + GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type), + return false, "Get original_type for op %s fail!", node->GetName().c_str()); + GE_IF_BOOL_EXEC(type == CONSTANT, GELOGI("Is const op"); return true); + return false; + } else { + return false; + } + } +}; +} // namespace ge + +#endif // GE_INC_GRAPH_PASS_H_ diff --git a/metadef/third_party/graphengine/ge/inc/pass.h b/metadef/third_party/graphengine/ge/inc/pass.h new file mode 100644 index 00000000..9f8519e1 --- /dev/null +++ b/metadef/third_party/graphengine/ge/inc/pass.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_INC_PASS_H_ +#define GE_INC_PASS_H_ + +#include + +#include "common/fmk_error_codes.h" + +namespace ge { +/// +/// @ingroup domi_omg +/// @brief pass +/// @author +/// +template +class Pass { + public: + virtual ~Pass() {} + /// + /// run pass + /// @author + /// + virtual Status Run(std::shared_ptr) = 0; +}; +} // namespace ge + +#endif // GE_INC_PASS_H_ diff --git a/metadef/third_party/graphengine/inc/external/ge/ge_api_error_codes.h b/metadef/third_party/graphengine/inc/external/ge/ge_api_error_codes.h new file mode 100644 index 00000000..67f5fa05 --- /dev/null +++ b/metadef/third_party/graphengine/inc/external/ge/ge_api_error_codes.h @@ -0,0 +1,131 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_ +#define INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_ + +#include +#include +#include "ge_error_codes.h" + +namespace ge { +#ifdef __GNUC__ +#define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead."))) +#else +#define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead.")) +#endif +class StatusFactory { + public: + static StatusFactory *Instance() { + static StatusFactory instance; + return &instance; + } + + void RegisterErrorNo(uint32_t err, const std::string &desc) { + // Avoid repeated addition + if (err_desc_.find(err) != err_desc_.end()) { + return; + } + err_desc_[err] = desc; + } + + void RegisterErrorNo(uint32_t err, const char *desc) { + if (desc == nullptr) { + return; + } + std::string error_desc = desc; + if (err_desc_.find(err) != err_desc_.end()) { + return; + } + err_desc_[err] = error_desc; + } + + std::string GetErrDesc(uint32_t err) { + auto iter_find = err_desc_.find(err); + if (iter_find == err_desc_.end()) { + return ""; + } + return iter_find->second; + } + + protected: + StatusFactory() {} + ~StatusFactory() {} + + private: + std::map err_desc_; +}; + +class ErrorNoRegisterar { + public: + ErrorNoRegisterar(uint32_t err, const std::string &desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } + ErrorNoRegisterar(uint32_t err, const char *desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } + ~ErrorNoRegisterar() {} +}; + +// Code compose(4 byte), runtime: 2 bit, type: 2 bit, level: 3 bit, sysid: 8 bit, modid: 5 bit, value: 12 bit +#define GE_ERRORNO(runtime, type, level, sysid, modid, name, value, desc) \ + constexpr ge::Status name = \ + ((0xFF & (static_cast(runtime))) << 30) | ((0xFF & (static_cast(type))) << 28) | \ + ((0xFF & (static_cast(level))) << 25) | ((0xFF & (static_cast(sysid))) << 17) | \ + ((0xFF & (static_cast(modid))) << 12) | (0x0FFF & (static_cast(value))); \ + const ErrorNoRegisterar g_##name##_errorno(name, desc); + +#define GE_ERRORNO_EXTERNAL(name, desc) \ + const ErrorNoRegisterar g_##name##_errorno(name, desc); + +using Status = uint32_t; + +// General error code +GE_ERRORNO(0, 0, 0, 0, 0, SUCCESS, 0, "success"); +GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed"); /*lint !e401*/ + +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_PARAM_INVALID, "Parameter invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_NOT_INIT, "GE executor not initialized yet."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID, "Model file path invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_MODEL_ID_INVALID, "Model id invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_MODEL_KEY_PATH_INVALID, "Model key path invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, "Model does not support encryption."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "Data size of model invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_MODEL_ADDR_INVALID, "Model addr invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID, "Queue id of model invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_LOAD_MODEL_REPEATED, "The model loaded repeatedly."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_MODEL_PARTITION_NUM_INVALID, "Model partition num invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID, "Dynamic input addr invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID, "Dynamic input size invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID, "Dynamic batch size invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_AIPP_BATCH_EMPTY, "AIPP batch parameter empty."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_AIPP_NOT_EXIST, "AIPP parameter not exist."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_AIPP_MODE_INVALID, "AIPP mode invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_OP_TASK_TYPE_INVALID, "Task type invalid."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, "Kernel type invalid."); + +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_MEMORY_ALLOCATION, "Memory allocation error."); + +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_INTERNAL_ERROR, "Internal error."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_LOAD_MODEL, "Load model error."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_LOAD_MODEL_PARTITION_FAILED, "Failed to load model partition."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED, "Failed to load weight partition."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_LOAD_TASK_PARTITION_FAILED, "Failed to load task partition."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_LOAD_KERNEL_PARTITION_FAILED, "Failed to load op kernel partition."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_EXEC_RELEASE_MODEL_DATA, "Failed to release the model data."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_COMMAND_HANDLE, "Command handle error."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_GET_TENSOR_INFO, "Get tensor info error."); +GE_ERRORNO_EXTERNAL(ACL_ERROR_GE_UNLOAD_MODEL, "Load model error."); + +} // namespace ge + +#endif // INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_ diff --git a/metadef/third_party/graphengine/inc/external/ge/ge_api_types.h b/metadef/third_party/graphengine/inc/external/ge/ge_api_types.h new file mode 100644 index 00000000..4d80ab13 --- /dev/null +++ b/metadef/third_party/graphengine/inc/external/ge/ge_api_types.h @@ -0,0 +1,445 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GE_GE_API_TYPES_H_ +#define INC_EXTERNAL_GE_GE_API_TYPES_H_ + +#include +#include +#include +#include +#include +#include + +namespace ge +{ +// Option key: graph run mode +const char *const OPTION_GRAPH_RUN_MODE = "ge.graphRunMode"; + +// Option key: ome init +const char *const OPTION_EXEC_SESSION_ID = "ge.exec.sessionId"; +const char *const OPTION_EXEC_DEVICE_ID = "ge.exec.deviceId"; +const char *const OPTION_EXEC_JOB_ID = "ge.exec.jobId"; +const char *const OPTION_EXEC_IS_USEHCOM = "ge.exec.isUseHcom"; +const char *const OPTION_EXEC_IS_USEHVD = "ge.exec.isUseHvd"; +const char *const OPTION_EXEC_RANK_ID = "ge.exec.rankId"; +const char *const OPTION_EXEC_POD_NAME = "ge.exec.podName"; +const char *const OPTION_EXEC_DEPLOY_MODE = "ge.exec.deployMode"; +const char *const OPTION_EXEC_RANK_TABLE_FILE = "ge.exec.rankTableFile"; +const char *const GE_AICPU_FLAG = "ge.aicpuFlag"; +const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; +// Dump flag and para +const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; +const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; +const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; +const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; +const char *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug"; +const char *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode"; +const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; +const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; +const char *const OPTION_EXEC_ENABLE_EXCEPTION_DUMP = "ge.exec.enable_exception_dump"; +const char *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses"; +const char *const OPTION_EXEC_PROFILING_FPPONIT_OPTIONS = "ge.exec.profilingFpPointOptions"; +const char *const OPTION_EXEC_PROFILING_BPPONIT_OPTIONS = "ge.exec.profilingBpPointOptions"; +// profiling flag +const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; +const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; +// Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 +const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; +const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; +const char *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory"; +const char *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOptimization"; + +// Option key: memory init +const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; +const char *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize"; +namespace configure_option { +const char *const STREAM_NUM = "ge.streamNum"; +const char *const HEAD_STREAM = "ge.headStream"; +const char *const PERF_LEVEL = "ge.perfLevel"; +const char *const ENCRYPT_MODE = "ge.encryptMode"; +const char *const EK_FILE = "ge.ekFile"; +const char *const CERT_FILE = "ge.certFile"; +const char *const HW_KEY_FILE = "ge.hwKeyFile"; +const char *const PRIVATE_KEY_FILE = "ge.privateKeyFile"; +const char *const FRAMEWORK_TYPE = "ge.frameworkType"; +const char *const CALIBRATION_CONF_FILE = "ge.calibrationConfFile"; +const char *const INSERT_OP_FILE = "ge.insertOpFile"; +const char *const OUTPUT_NODE_NAME = "ge.outputNodeName"; +const char *const COMPRESS_FLAG = "ge.compressFlag"; +const char *const PRECISION_MODE = "ge.exec.precision_mode"; +const char *const SINGLE_OP_FLAG = "ge.exec.single_op"; +const char *const TRAIN_FLAG = "ge.trainFlag"; +const char *const RUN_FLAG = "ge.runFlag"; +const char *const LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop"; +const char *const TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path"; +const char *const DDK_VERSION_FLAG = "ge.DDK_version"; +const char *const GE_FE_FLAG = "ge.feFlag"; +const char *const STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; +const char *const OUTPUT_DATATYPE = "ge.outputDatatype"; +const char *const OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; +const char *const OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode"; +const char *const HCOM_PARALLEL = "ge.hcomParallel"; +const char *const AUTO_TUNE_MODE = "ge.autoTuneMode"; +const char *const SOC_VERSION = "ge.socVersion"; +const char *const CORE_TYPE = "ge.engineType"; +const char *const AICORE_NUM = "ge.aicoreNum"; +const char *const L1_FUSION = "ge.l1Fusion"; +const char *const BUFFER_OPTIMIZE = "ge.bufferOptimize"; +const char *const ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; +const char *const ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; +const char *const FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; +const char *const SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; +const char *const ORIGINAL_MODEL_FILE = "ge.originalModelFile"; +const char *const INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; +const char *const OP_DEBUG_LEVEL = "ge.opDebugLevel"; +} + +// Configure stream num by Session constructor options param, +// its value should be int32_t type, default value is "1" +const std::string STREAM_NUM = "ge.streamNum"; + +// Configure add head stream to model. +// its value should be "0" or "1", default value is "0" +const std::string HEAD_STREAM = "ge.headStream"; + +// Configure perf level by Session constructor options param, +// its value please see enum PerfLevel, default value is "4" +const std::string PERF_LEVEL = "ge.perfLevel"; + +// Configure encrypt mode by Session constructor options param, +// its value should be int32_t type, default value is "-1" +const std::string ENCRYPT_MODE = "ge.encryptMode"; + +// configure ek file by Session constructor options param, +// its value should be file path, default value is "" +const std::string EK_FILE = "ge.ekFile"; + +// Configure cert file by Session constructor options param, +// its value should be file path, default value is "" +const std::string CERT_FILE = "ge.certFile"; + +// Configure hw key file by Session constructor options param, +// its value should be file path, default value is "" +const std::string HW_KEY_FILE = "ge.hwKeyFile"; + +// Configure private file by Session constructor options param, +// its value should be file path, default value is "" +const std::string PRIVATE_KEY_FILE = "ge.privateKeyFile"; + +// Configure framework type by Session constructor options param, +// its value please see enum FrameworkType, default value is "3" +const std::string FRAMEWORK_TYPE = "ge.frameworkType"; + +// Configure calibration info file by Session constructor options param, +// its value should be file path, default value is "" +const std::string CALIBRATION_CONF_FILE = "ge.calibrationConfFile"; + +// Configure insert op info file by Session constructor options param, +// its value should be file path, default value is "" +const std::string INSERT_OP_FILE = "ge.insertOpFile"; + +// Configure output node name by Session constructor options param, +// its value should be std::string type, default value is "" +const std::string OUTPUT_NODE_NAME = "ge.outputNodeName"; + +// Configure weight compress flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string COMPRESS_FLAG = "ge.compressFlag"; + +const std::string PRECISION_MODE = "ge.exec.precision_mode"; + +// Configure single op flag for FE +// its value should be "0" or "1", default value is "0" +const std::string SINGLE_OP_FLAG = "ge.exec.single_op"; + +// Configure train flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string TRAIN_FLAG = "ge.trainFlag"; + +// Configure run flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string RUN_FLAG = "ge.runFlag"; + +// Configure run flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +// this option is to enable local framework op feature +const std::string LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop"; + +// Configure run flag by Session constructor options param, +// its value should be a path +// this option is to obtain the TBE op plugin path +const std::string TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path"; + +// Configure run flag by Session constructor options param, +// its value should be a path +// this option is to obtain the DDK Version info +const std::string DDK_VERSION_FLAG = "ge.DDK_version"; + +// Configure run flag by Session constructor options param, +// its value should be a path +// this option is to obtain fe flag +const std::string GE_FE_FLAG = "ge.feFlag"; + +// Configure stream max parallel num only by Session constructor options param, +// its value should be stream:int, such as "DNN_V100:2,DNN_HCCL:3", +// default value is "1", such as "DNN_V100:1,DNN_HCCL:1" +// this option is to obtain stream max parallel num +const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; + +// congigure outputDatatype to setting net output type +const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; + +// congigure opSelectImplmode to setting op select implmode +const std::string OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; + +// congigure optypelist_for_implmode to setting which op use implmode +const std::string OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode"; + +// configure whether to enable hcom parallel by session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string HCOM_PARALLEL = "ge.hcomParallel"; + +// configure whether to use dynamic batch size +const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; + +const std::string INPUT_SHAPE = "ge.inputShape"; + +const std::string DYNAMIC_NODE_TYPE = "ge.dynamicNodeType"; +// configure whether to use dynamic image size +const char *const kDynamicImageSize = "ge.dynamicImageSize"; + +// Configure whether to use dynamic dims +const char *const kDynamicDims = "ge.dynamicDims"; + +// Configure auto tune mode, this option only take effect while AUTO_TUNE_FLAG is Y, +// example: GA|RL, support configure multiple, split by | +const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; + +// Configure soc version , example: "Ascend310" +const std::string SOC_VERSION = "ge.socVersion"; + +// Configure core type "VectorEngine", default value is "AIcoreEngine" +const std::string CORE_TYPE = "ge.engineType"; + +// Configure AICORE NUM +const std::string AICORE_NUM = "ge.aicoreNum"; + +// Configure L1FUSION +const std::string L1_FUSION = "ge.l1Fusion"; + +// Configure l1,l2,and others optimize option +const std::string BUFFER_OPTIMIZE = "ge.bufferOptimize"; + +// Configure Small Channel flag +const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; + +// Configure Compress Weight flag +const std::string ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; + +// Configure fusion switch file path +const std::string FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; + +// Save original model +const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; + +// Save original model file name +const std::string ORIGINAL_MODEL_FILE = "ge.originalModelFile"; + +const char *const OPTION_GE_MAX_DUMP_FILE_NUM = "ge.maxDumpFileNum"; +const char *const OPTION_GE_MAX_DUMP_FILE_SIZE = "ge.maxDumpFileSize"; +const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; + +// Configure for print op pass +// Its value should be "0" or "1", default value is "1" +const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; + +// Configure operator compilation path +// Its value should be file path, default value is "./" +const char *const DEBUG_DIR = "ge.debugDir"; + +// Configure operator compiler cache path +// Its value should be file path, default value is "./" +const char *const OP_COMPILER_CACHE_DIR = "ge.op_compiler_cache_dir"; + +// Configure operator compiler cache mode +// Its value should be "disable", "enable" or "force", default value is "disable" +const char *const OP_COMPILER_CACHE_MODE = "ge.op_compiler_cache_mode"; + +// Configure whether to use single stream. +// Its value should be "true" or "false", default value is "false" +const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; + +// Configure input fp16 nodes +const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; + +// Configure debug level, its value should be 0(default), 1 or 2. +// 0: close debug; 1: open TBE compiler; 2: open ccec compiler +const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; + +// Configure model bank path +const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path"; + +// Configure op bank path +const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; + +// Graph run mode +enum GraphRunMode +{ + PREDICTION = 0, + TRAIN +}; + +// Input/Output tensor info +struct InputTensorInfo +{ + uint32_t data_type; // data type + std::vector dims; // shape description + void *data; // tensor data + int64_t length; // tensor length +}; + +struct OutputTensorInfo +{ + uint32_t data_type; // data type + std::vector dims; // shape description + std::unique_ptr data; // tensor data + int64_t length; // tensor length + OutputTensorInfo() : data_type(0), dims({}), data(nullptr), length(0) + { + } + OutputTensorInfo(OutputTensorInfo &&out) + : data_type(out.data_type), dims(out.dims), data(std::move(out.data)), length(out.length) + { + } + + OutputTensorInfo &operator=(OutputTensorInfo &&out) + { + if (this != &out) + { + data_type = out.data_type; + dims = out.dims; + data = std::move(out.data); + length = out.length; + } + return *this; + } + OutputTensorInfo(const OutputTensorInfo &) = delete; + OutputTensorInfo &operator=(const OutputTensorInfo &) = delete; +}; + +using Status = uint32_t; +using RunAsyncCallback = std::function &)>; +// for ir build +namespace ir_option +{ +static const char *const INPUT_FORMAT = "input_format"; +static const char *const INPUT_SHAPE = "input_shape"; +static const char *const OP_NAME_MAP = "op_name_map"; +static const char *const IS_DYNAMIC_INPUT = "is_dynamic_input"; +static const char *const IS_INPUT_ADJUST_HW_LAYOUT = "is_input_adjust_hw_layout"; +static const char *const IS_OUTPUT_ADJUST_HW_LAYOUT = "is_output_adjust_hw_layout"; +static const char *const ENABLE_SCOPE_FUSION_PASSES = "enable_scope_fusion_passes"; +static const char *const OUTPUT = "output"; +static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; +static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; +static const char *const DYNAMIC_DIMS = kDynamicDims; +static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); +static const char *const PRECISION_MODE = ge::PRECISION_MODE.c_str(); +static const char *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY; +static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); +static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); +static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); +static const char *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM; +static const char *const AICORE_NUM = ge::AICORE_NUM.c_str(); +static const char *const FUSION_SWITCH_FILE = ge::FUSION_SWITCH_FILE.c_str(); +static const char *const ENABLE_SMALL_CHANNEL = ge::ENABLE_SMALL_CHANNEL.c_str(); +static const char *const OP_SELECT_IMPL_MODE = ge::OP_SELECT_IMPL_MODE.c_str(); +static const char *const OUTPUT_TYPE = ge::OUTPUT_DATATYPE.c_str(); +static const char *const BUFFER_OPTIMIZE = ge::BUFFER_OPTIMIZE.c_str(); +static const char *const ENABLE_COMPRESS_WEIGHT = ge::ENABLE_COMPRESS_WEIGHT.c_str(); +static const char *const COMPRESS_WEIGHT_CONF = "compress_weight_conf"; +static const char *const OUT_NODES = ge::OUTPUT_NODE_NAME.c_str(); +static const char *const INPUT_FP16_NODES = ge::INPUT_FP16_NODES.c_str(); +static const char *const LOG_LEVEL = "log"; +static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c_str(); +static const char *const DEBUG_DIR = ge::DEBUG_DIR; +static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; +static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; +static const char *const MDL_BANK_PATH_FLAG = ge::MDL_BANK_PATH_FLAG.c_str(); +static const char *const OP_BANK_PATH_FLAG = ge::OP_BANK_PATH_FLAG.c_str(); +static const char *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); + +// for interface: aclgrphBuildModel +const std::set ir_builder_suppported_options = {INPUT_FORMAT, + INPUT_SHAPE, + OP_NAME_MAP, + DYNAMIC_BATCH_SIZE, + DYNAMIC_IMAGE_SIZE, + DYNAMIC_DIMS, + INSERT_OP_FILE, + PRECISION_MODE, + EXEC_DISABLE_REUSED_MEMORY, + AUTO_TUNE_MODE, + OUTPUT_TYPE, + OUT_NODES, + INPUT_FP16_NODES, + LOG_LEVEL, + OP_DEBUG_LEVEL, + DEBUG_DIR, + OP_COMPILER_CACHE_DIR, + OP_COMPILER_CACHE_MODE, + MDL_BANK_PATH_FLAG, + OP_BANK_PATH_FLAG}; + +// for interface: aclgrphParse +const std::set ir_parser_suppported_options = {INPUT_FORMAT, + INPUT_SHAPE, + OP_NAME_MAP, + IS_DYNAMIC_INPUT, + INPUT_FP16_NODES, + IS_INPUT_ADJUST_HW_LAYOUT, + IS_OUTPUT_ADJUST_HW_LAYOUT, + OUTPUT, + OUTPUT_TYPE, + OUT_NODES, + COMPRESS_WEIGHT_CONF, + ENABLE_SCOPE_FUSION_PASSES, + LOG_LEVEL}; + +// for interface: aclgrphBuildInitialize +const std::set global_options = {CORE_TYPE, + SOC_VERSION, + BUFFER_OPTIMIZE, + ENABLE_COMPRESS_WEIGHT, + COMPRESS_WEIGHT_CONF, + PRECISION_MODE, + EXEC_DISABLE_REUSED_MEMORY, + AUTO_TUNE_MODE, + ENABLE_SINGLE_STREAM, + AICORE_NUM, + FUSION_SWITCH_FILE, + ENABLE_SMALL_CHANNEL, + OP_SELECT_IMPL_MODE, + OPTYPELIST_FOR_IMPLMODE, + OP_DEBUG_LEVEL, + DEBUG_DIR, + OP_COMPILER_CACHE_DIR, + OP_COMPILER_CACHE_MODE}; +} // namespace ir_option +} // namespace ge + +#endif // INC_EXTERNAL_GE_GE_API_TYPES_H_ diff --git a/metadef/third_party/graphengine/inc/external/ge/ge_error_codes.h b/metadef/third_party/graphengine/inc/external/ge/ge_error_codes.h new file mode 100644 index 00000000..30631d8b --- /dev/null +++ b/metadef/third_party/graphengine/inc/external/ge/ge_error_codes.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_GE_GE_ERROR_CODES_H_ +#define INC_EXTERNAL_GE_GE_ERROR_CODES_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif +static const uint32_t ACL_ERROR_GE_PARAM_INVALID = 145000; +static const uint32_t ACL_ERROR_GE_EXEC_NOT_INIT = 145001; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID = 145002; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_ID_INVALID = 145003; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_KEY_PATH_INVALID = 145004; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION = 145005; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID = 145006; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_ADDR_INVALID = 145007; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID = 145008; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_MODEL_REPEATED = 145009; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_PARTITION_NUM_INVALID = 145010; +static const uint32_t ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID = 145011; +static const uint32_t ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID = 145012; +static const uint32_t ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID = 145013; +static const uint32_t ACL_ERROR_GE_AIPP_BATCH_EMPTY = 145014; +static const uint32_t ACL_ERROR_GE_AIPP_NOT_EXIST = 145015; +static const uint32_t ACL_ERROR_GE_AIPP_MODE_INVALID = 145016; +static const uint32_t ACL_ERROR_GE_OP_TASK_TYPE_INVALID = 145017; +static const uint32_t ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID = 145018; +static const uint32_t ACL_ERROR_GE_MEMORY_ALLOCATION = 245000; +static const uint32_t ACL_ERROR_GE_INTERNAL_ERROR = 545000; +static const uint32_t ACL_ERROR_GE_LOAD_MODEL = 545001; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_MODEL_PARTITION_FAILED = 545002; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED = 545003; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_TASK_PARTITION_FAILED = 545004; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_KERNEL_PARTITION_FAILED = 545005; +static const uint32_t ACL_ERROR_GE_EXEC_RELEASE_MODEL_DATA = 545006; +static const uint32_t ACL_ERROR_GE_COMMAND_HANDLE = 545007; +static const uint32_t ACL_ERROR_GE_GET_TENSOR_INFO = 545008; +static const uint32_t ACL_ERROR_GE_UNLOAD_MODEL = 545009; +#ifdef __cplusplus +} // namespace ge +#endif +#endif // INC_EXTERNAL_GE_GE_ERROR_CODES_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/debug/ge_log.h b/metadef/third_party/graphengine/inc/framework/common/debug/ge_log.h new file mode 100644 index 00000000..4a32af36 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/debug/ge_log.h @@ -0,0 +1,121 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ +#define INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ + +#include + +#include "framework/common/ge_inner_error_codes.h" +#include "toolchain/slog.h" +#ifdef __GNUC__ +#include +#include +#else +#include "mmpa/mmpa_api.h" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#define GE_MODULE_NAME static_cast(GE) + +// trace status of log +enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; + +class GeLog { +public: +#ifdef __GNUC__ +static pid_t GetTid() { + thread_local static pid_t tid = syscall(__NR_gettid); + return tid; +} +#else +static int GetTid() { + thread_local static int tid = static_cast(GetCurrentThreadId()); + return tid; +} +#endif +}; + +inline bool IsLogEnable(int module_name, int log_level) { + int32_t enable = CheckLogLevel(module_name, log_level); + // 1:enable, 0:disable + if (enable == 1) { + return true; + } + return false; +} + +#define GELOGE(ERROR_CODE, fmt, ...) \ + dlog_error(GE_MODULE_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GeLog::GetTid(), __FUNCTION__, ERROR_CODE, \ + ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__) +#define GELOGW(fmt, ...) \ + if (IsLogEnable(GE_MODULE_NAME, DLOG_WARN)) dlog_warn(GE_MODULE_NAME, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__) +#define GELOGI(fmt, ...) \ + if (IsLogEnable(GE_MODULE_NAME, DLOG_INFO)) dlog_info(GE_MODULE_NAME, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__) +#define GELOGD(fmt, ...) \ + if (IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) dlog_debug(GE_MODULE_NAME, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__) +#define GEEVENT(fmt, ...) dlog_event(GE_MODULE_NAME, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__) +#define GELOGO(fmt, ...) \ + Dlog(GE_MODULE_NAME, DLOG_OPLOG, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__) +#define GELOGT(VALUE, fmt, ...) \ + do { \ + TraceStatus stat = VALUE; \ + const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \ + int idx = static_cast(stat); \ + char *k = const_cast("status"); \ + char *v = const_cast(TraceStatStr[idx]); \ + KeyValue kv = {k, v}; \ + DlogWithKV(static_cast(GE_MODULE_NAME), DLOG_TRACE, &kv, 1, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__); \ + } while (0) + +#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ + dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GeLog::GetTid(), __FUNCTION__, ERROR_CODE, \ + ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__) +#define GE_LOG_WARN(MOD_NAME, fmt, ...) \ + if (IsLogEnable(MOD_NAME, DLOG_WARN)) dlog_warn(MOD_NAME, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__) +#define GE_LOG_INFO(MOD_NAME, fmt, ...) \ + if (IsLogEnable(MOD_NAME, DLOG_INFO)) dlog_info(MOD_NAME, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__) +#define GE_LOG_DEBUG(MOD_NAME, fmt, ...) \ + if (IsLogEnable(MOD_NAME, DLOG_DEBUG)) dlog_debug(MOD_NAME, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__) +#define GE_LOG_EVENT(MOD_NAME, fmt, ...) dlog_event(MOD_NAME, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__) +#define GE_LOG_OPLOG(MOD_NAME, fmt, ...) \ + Dlog(MOD_NAME, DLOG_OPLOG, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__) + +#define GE_LOG_TRACE(MOD_NAME, value, fmt, ...) \ + do { \ + TraceStatus stat = value; \ + const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \ + int idx = static_cast(stat); \ + char *k = const_cast("status"); \ + char *v = const_cast(TraceStatStr[idx]); \ + KeyValue kv = {k, v}; \ + DlogWithKV(static_cast(MOD_NAME), DLOG_TRACE, &kv, 1, "%lu %s:" fmt, GeLog::GetTid(), __FUNCTION__, ##__VA_ARGS__); \ + } while (0) + +// print memory when it is greater than 1KB. +#define GE_PRINT_DYNAMIC_MEMORY(FUNC, PURPOSE, SIZE) \ + do { \ + if ((SIZE) > 1024) { \ + GELOGI("MallocMemory, func=%s, size=%zu, purpose=%s", (#FUNC), static_cast(SIZE), (PURPOSE)); \ + } \ + } while (0); +#ifdef __cplusplus +} +#endif +#endif // INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/debug/log.h b/metadef/third_party/graphengine/inc/framework/common/debug/log.h new file mode 100644 index 00000000..6d449919 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/debug/log.h @@ -0,0 +1,256 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ +#define INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ + +#include + +#include "runtime/rt.h" +#include "common/string_util.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "ge/ge_api_error_codes.h" + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) +#else +#include +#if defined(BUILD_VERSION_PERF) +#define DOMI_LOGE(fmt, ...) +#else +// The Android system has strict log control. Do not modify the log. +#define DOMI_LOGE(fmt, ...) \ + __android_log_print(ANDROID_LOG_ERROR, "NPU_FMK", "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) +#endif +#endif + +// ge marco +#define GE_LOGI_IF(condition, ...) \ + if ((condition)) { \ + GELOGI(__VA_ARGS__); \ + } + +#define GE_LOGW_IF(condition, ...) \ + if ((condition)) { \ + GELOGW(__VA_ARGS__); \ + } + +#define GE_LOGE_IF(condition, ...) \ + if ((condition)) { \ + DOMI_LOGE(__VA_ARGS__); \ + } + +// If expr is not SUCCESS, print the log and return the same value +#define GE_CHK_STATUS_RET(expr, ...) \ + do { \ + const ge::Status _status = (expr); \ + if (_status != ge::SUCCESS) { \ + DOMI_LOGE(__VA_ARGS__); \ + return _status; \ + } \ + } while (0); + +// If expr is not SUCCESS, print the log and do not execute return +#define GE_CHK_STATUS(expr, ...) \ + do { \ + const ge::Status _status = (expr); \ + if (_status != ge::SUCCESS) { \ + DOMI_LOGE(__VA_ARGS__); \ + } \ + } while (0); + +// If expr is not SUCCESS, return the same value +#define GE_CHK_STATUS_RET_NOLOG(expr) \ + do { \ + const ge::Status _status = (expr); \ + if (_status != ge::SUCCESS) { \ + return _status; \ + } \ + } while (0); + +// If expr is not GRAPH_SUCCESS, print the log and return FAILED +#define GE_CHK_GRAPH_STATUS_RET(expr, ...) \ + do { \ + if ((expr) != ge::GRAPH_SUCCESS) { \ + DOMI_LOGE(__VA_ARGS__); \ + return FAILED; \ + } \ + } while (0); + +// If expr is not SUCCESS, print the log and execute a custom statement +#define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \ + do { \ + const ge::Status _status = (expr); \ + GE_CHK_BOOL_EXEC(_status == SUCCESS, exec_expr, __VA_ARGS__); \ + } while (0); + +// If expr is not true, print the log and return the specified status +#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GELOGE(_status, __VA_ARGS__); \ + return _status; \ + } \ + } while (0); + +// If expr is not true, print the log and return the specified status +#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + return _status; \ + } \ + } while (0); + +// If expr is not true, print the log and execute a custom statement +#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (!b) { \ + DOMI_LOGE(__VA_ARGS__); \ + exec_expr; \ + } \ + } + +// If expr is not true, print the log and execute a custom statement +#define GE_CHK_BOOL_EXEC_WARN(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (!b) { \ + GELOGW(__VA_ARGS__); \ + exec_expr; \ + } \ + } +// If expr is not true, print the log and execute a custom statement +#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (!b) { \ + GELOGI(__VA_ARGS__); \ + exec_expr; \ + } \ + } + +// If expr is not true, print the log and execute a custom statement +#define GE_CHK_BOOL_TRUE_EXEC_INFO(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (b) { \ + GELOGI(__VA_ARGS__); \ + exec_expr; \ + } \ + } + +// If expr is true, print logs and execute custom statements +#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (b) { \ + DOMI_LOGE(__VA_ARGS__); \ + exec_expr; \ + } \ + } +// If expr is true, print the Information log and execute a custom statement +#define GE_CHK_TRUE_EXEC_INFO(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (b) { \ + GELOGI(__VA_ARGS__); \ + exec_expr; \ + } \ + } + +// If expr is not SUCCESS, print the log and execute the expression + return +#define GE_CHK_BOOL_TRUE_RET_VOID(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (b) { \ + DOMI_LOGE(__VA_ARGS__); \ + exec_expr; \ + return; \ + } \ + } + +// If expr is not SUCCESS, print the log and execute the expression + return _status +#define GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(expr, _status, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (b) { \ + DOMI_LOGE(__VA_ARGS__); \ + exec_expr; \ + return _status; \ + } \ + } + +// If expr is not true, execute a custom statement +#define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ + { \ + bool b = (expr); \ + if (!b) { \ + exec_expr; \ + } \ + } + +// -----------------runtime related macro definitions------------------------------- +// If expr is not RT_ERROR_NONE, print the log +#define GE_CHK_RT(expr) \ + do { \ + rtError_t _rt_ret = (expr); \ + if (_rt_ret != RT_ERROR_NONE) { \ + DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ + } \ + } while (0); + +// If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression +#define GE_CHK_RT_EXEC(expr, exec_expr) \ + { \ + rtError_t _rt_ret = (expr); \ + if (_rt_ret != RT_ERROR_NONE) { \ + DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ + exec_expr; \ + } \ + } + +// If expr is not RT_ERROR_NONE, print the log and return +#define GE_CHK_RT_RET(expr) \ + do { \ + rtError_t _rt_ret = (expr); \ + if (_rt_ret != RT_ERROR_NONE) { \ + DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ + return RT_ERROR_TO_GE_STATUS(_rt_ret); \ + } \ + } while (0); + +// If expr is true, execute exec_expr without printing logs +#define GE_IF_BOOL_EXEC(expr, exec_expr) \ + { \ + if (expr) { \ + exec_expr; \ + } \ + } + +// If make_shared is abnormal, print the log and execute the statement +#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ + try { \ + exec_expr0; \ + } catch (const std::bad_alloc &) { \ + DOMI_LOGE("Make shared failed"); \ + exec_expr1; \ + } + +#endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/fmk_error_codes.h b/metadef/third_party/graphengine/inc/framework/common/fmk_error_codes.h new file mode 100644 index 00000000..ec1f26d0 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/fmk_error_codes.h @@ -0,0 +1,85 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_FMK_ERROR_CODES_H_ +#define INC_FRAMEWORK_COMMON_FMK_ERROR_CODES_H_ + +#include +#include + +#include "framework/common/fmk_types.h" +#include "register/register_error_codes.h" + +#define MODID_OMG 1 // OMG module ID +#define MODID_OME 2 // OME module ID +#define MODID_CALIBRATION 3 // Calibration module ID + +// Each module uses the following four macros to define error codes: +#define DECLARE_ERRORNO_OMG(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OMG, name, value) +#define DECLARE_ERRORNO_OME(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OME, name, value) +#define DECLARE_ERRORNO_CALIBRATION(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_CALIBRATION, name, value) + +#define DEF_ERRORNO(name, desc) const ErrorNoRegisterar g_##name##_errorno(name, desc); + +// Interface for Obtaining Error Code Description +#define GET_ERRORNO_STR(value) domi::StatusFactory::Instance()->GetErrDesc(value) + +namespace domi { +class StatusFactory { + public: + static StatusFactory *Instance(); + + void RegisterErrorNo(uint32_t err, const std::string &desc); + + std::string GetErrDesc(uint32_t err); + + protected: + StatusFactory() {} + ~StatusFactory() {} + + private: + std::map err_desc_; +}; + +class ErrorNoRegisterar { + public: + ErrorNoRegisterar(uint32_t err, const std::string &desc) { StatusFactory::Instance()->RegisterErrorNo(err, desc); } + ~ErrorNoRegisterar() {} +}; + +// Common errocode +DECLARE_ERRORNO_COMMON(MEMALLOC_FAILED, 0); // 50331648 +DECLARE_ERRORNO_COMMON(CCE_FAILED, 2); // 50331650 +DECLARE_ERRORNO_COMMON(RT_FAILED, 3); // 50331651 +DECLARE_ERRORNO_COMMON(INTERNAL_ERROR, 4); // 50331652 +DECLARE_ERRORNO_COMMON(CSEC_ERROR, 5); // 50331653 +DECLARE_ERRORNO_COMMON(TEE_ERROR, 6); // 50331653 +DECLARE_ERRORNO_COMMON(UNSUPPORTED, 100); +DECLARE_ERRORNO_COMMON(OUT_OF_MEMORY, 101); + +// Omg errorcode +DECLARE_ERRORNO_OMG(PARSE_MODEL_FAILED, 0); +DECLARE_ERRORNO_OMG(PARSE_WEIGHTS_FAILED, 1); +DECLARE_ERRORNO_OMG(NOT_INITIALIZED, 2); +DECLARE_ERRORNO_OMG(TIMEOUT, 3); + +// Ome errorcode +DECLARE_ERRORNO_OME(MODEL_NOT_READY, 0); +DECLARE_ERRORNO_OME(PUSH_DATA_FAILED, 1); +DECLARE_ERRORNO_OME(DATA_QUEUE_ISFULL, 2); +} // namespace domi + +#endif // INC_FRAMEWORK_COMMON_FMK_ERROR_CODES_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/fmk_types.h b/metadef/third_party/graphengine/inc/framework/common/fmk_types.h new file mode 100644 index 00000000..f84390da --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/fmk_types.h @@ -0,0 +1,23 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_FMK_TYPES_H_ +#define INC_FRAMEWORK_COMMON_FMK_TYPES_H_ + +#include "graph/types.h" +#include "register/register_types.h" + +#endif // INC_FRAMEWORK_COMMON_FMK_TYPES_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/ge_inner_error_codes.h b/metadef/third_party/graphengine/inc/framework/common/ge_inner_error_codes.h new file mode 100644 index 00000000..3697a526 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/ge_inner_error_codes.h @@ -0,0 +1,319 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*lint -e* */ +#ifndef INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ +#define INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ + +#include +#include +#include "ge/ge_api_error_codes.h" + +namespace ge { +// System ID +enum SystemIdType { SYSID_GE = 8 }; +// Runtime location +enum LogRuntime { + RT_HOST = 0b01, + RT_DEVICE = 0b10, +}; + +// Sub model +enum SubModuleId { + COMMON_MODULE = 0, + CLIENT_MODULE = 1, + INIT_MODULE = 2, + SESSION_MODULE = 3, + GRAPH_MODULE = 4, + ENGINE_MODULE = 5, + OPS_MODULE = 6, + PLUGIN_MODULE = 7, + RUNTIME_MODULE = 8, + EXECUTOR_MODULE = 9, + GENERATOR_MODULE = 10, +}; + +// Error code type +enum ErrorCodeType { + ERROR_CODE = 0b01, + EXCEPTION_CODE = 0b10, +}; + +// Error level +enum ErrorLevel { + COMMON_LEVEL = 0b000, + SUGGESTION_LEVEL = 0b001, + MINOR_LEVEL = 0b010, + MAJOR_LEVEL = 0b011, + CRITICAL_LEVEL = 0b100, +}; + +// Each module defines error codes using the following macros +#define GE_ERRORNO_COMMON(name, value, desc) \ + GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, COMMON_MODULE, name, value, desc) +#define GE_ERRORNO_CLIENT(name, value, desc) \ + GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, CLIENT_MODULE, name, value, desc) +#define GE_ERRORNO_INIT(name, value, desc) \ + GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, INIT_MODULE, name, value, desc) +#define GE_ERRORNO_SESSION(name, value, desc) \ + GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, SESSION_MODULE, name, value, desc) +#define GE_ERRORNO_GRAPH(name, value, desc) \ + GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, GRAPH_MODULE, name, value, desc) +#define GE_ERRORNO_ENGINE(name, value, desc) \ + GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, ENGINE_MODULE, name, value, desc) +#define GE_ERRORNO_OPS(name, value, desc) \ + GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, OPS_MODULE, name, value, desc) +#define GE_ERRORNO_PLUGIN(name, value, desc) \ + GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, PLUGIN_MODULE, name, value, desc) +#define GE_ERRORNO_RUNTIME(name, value, desc) \ + GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, RUNTIME_MODULE, name, value, desc) +#define GE_ERRORNO_EXECUTOR(name, value, desc) \ + GE_ERRORNO(RT_DEVICE, ERROR_CODE, COMMON_LEVEL, SYSID_GE, EXECUTOR_MODULE, name, value, desc) +#define GE_ERRORNO_GENERATOR(name, value, desc) \ + GE_ERRORNO(RT_HOST, ERROR_CODE, COMMON_LEVEL, SYSID_GE, GENERATOR_MODULE, name, value, desc) + +// Get error code description +#define GE_GET_ERRORNO_STR(value) ge::StatusFactory::Instance()->GetErrDesc(value) + +// Common module error code definition +GE_ERRORNO_COMMON(MEMALLOC_FAILED, 0, "Failed to allocate memory!"); // 1343225856 +GE_ERRORNO_COMMON(PARAM_INVALID, 1, "Parameter's invalid!"); // 1343225857 +GE_ERRORNO_COMMON(CCE_FAILED, 2, "Failed to call CCE API!"); // 1343225858 +GE_ERRORNO_COMMON(RT_FAILED, 3, "Failed to call runtime API!"); // 1343225859 +GE_ERRORNO_COMMON(INTERNAL_ERROR, 4, "Internal errors"); // 1343225860 +GE_ERRORNO_COMMON(CSEC_ERROR, 5, "Failed to call libc_sec API!"); // 1343225861 +GE_ERRORNO_COMMON(TEE_ERROR, 6, "Failed to call tee API!"); // 1343225862 +GE_ERRORNO_COMMON(END_OF_SEQUENCE, 7, "End of sequence!"); // 1343225863 +GE_ERRORNO_COMMON(PATH_INVALID, 8, "Path is invalid!"); // 1343225864 + +// Error code for plugin manager +GE_ERRORNO_COMMON(GE_PLGMGR_PATH_INVALID, 30, "Path is invalid!"); // 1343225886 +GE_ERRORNO_COMMON(GE_PLGMGR_SO_NOT_EXIST, 31, "Failed to find any valid so file!"); // 1343225887 +GE_ERRORNO_COMMON(GE_PLGMGR_FUNC_NOT_EXIST, 32, "Failed to find any function!"); // 1343225888 +GE_ERRORNO_COMMON(GE_PLGMGR_INVOKE_FAILED, 33, "Failed to invoke any function!"); // 1343225889 + +GE_ERRORNO_COMMON(UNSUPPORTED, 100, "Parameter's unsupported!"); + +GE_ERRORNO_COMMON(OUT_OF_MEMORY, 101, "Out of memory!"); + +// Client module error code definition +GE_ERRORNO_CLIENT(GE_CLI_INIT_FAILED, 1, "GEInitialize Failed."); // 1343229953 +GE_ERRORNO_CLIENT(GE_CLI_FINAL_FAILED, 2, "GEFinalize Failed."); // 1343229954 +GE_ERRORNO_CLIENT(GE_CLI_SESS_CONSTRUCT_FAILED, 3, "Session constructor Failed."); // 1343229955 +GE_ERRORNO_CLIENT(GE_CLI_SESS_DESTROY_FAILED, 4, "Session destructor Failed."); // 1343229956 +GE_ERRORNO_CLIENT(GE_CLI_SESS_ADD_FAILED, 5, "Session AddGraph Failed."); // 1343229957 +GE_ERRORNO_CLIENT(GE_CLI_SESS_ADD_GRAPH_FAILED, 6, + "Session AddGraph Failed converting protobuf GraphProto."); // 1343229958 +GE_ERRORNO_CLIENT(GE_CLI_SESS_REMOVE_FAILED, 7, "Session RemoveGraph Failed."); // 1343229959 +GE_ERRORNO_CLIENT(GE_CLI_SESS_RUN_FAILED, 8, "Session RunGraph Failed."); // 1343229960 +GE_ERRORNO_CLIENT(GE_CLI_SESS_RUN_TENSOR_FAILED, 9, + "Session RunGraph Failed converting protobuf TensorProto."); // 1343229961 +GE_ERRORNO_CLIENT(GE_CLI_GE_ALREADY_INITIALIZED, 10, "GE is already initialized."); // 1343229962 +GE_ERRORNO_CLIENT(GE_CLI_GE_NOT_INITIALIZED, 11, "GE is not yet initialized or is finalized."); // 1343229963 + +// Init module error code definition +GE_ERRORNO_INIT(GE_MULTI_INIT, 0, "Multiple initializations are not supported."); // 1343234048 +GE_ERRORNO_INIT(GE_FINALIZE_NOT_INIT, 1, "Finalize is not allowed before initialization."); // 1343234049 +GE_ERRORNO_INIT(GE_MULTI_FINALIZE, 2, "Multiple finalizations are not supported."); // 1343234050 +GE_ERRORNO_INIT(GE_PROF_MULTI_INIT, 3, "Multiple profiling initializations are not supported."); // 1343234051 +GE_ERRORNO_INIT(GE_PROF_NOT_INIT, 4, "Profing initializations have not been done."); // 1343234052 +GE_ERRORNO_INIT(GE_PROF_MODE_CONFLICT, 5, + "Profiling command mode which is preferred is running, the api mode will not work."); // 1343234053 + +// Session module error code definition +GE_ERRORNO_SESSION(GE_SESS_INIT_FAILED, 0, "Failed to initialize session."); // 1343238144 +GE_ERRORNO_SESSION(GE_SESS_ALREADY_RUNNING, 1, "Session already running,not support parallel run."); // 1343238145 +GE_ERRORNO_SESSION(GE_SESS_GRAPH_NOT_EXIST, 2, "Graph ID not exist."); // 1343238146 +GE_ERRORNO_SESSION(GE_SESS_GRAPH_ALREADY_EXIST, 3, "Graph ID already exist."); // 1343238147 +GE_ERRORNO_SESSION(GE_SESS_GRAPH_IS_RUNNING, 4, "Graph is running."); // 1343238148 +GE_ERRORNO_SESSION(GE_SESSION_NOT_EXIST, 5, "Can not find session with specific session id."); // 1343238149 +GE_ERRORNO_SESSION(GE_SESSION_MANAGER_NOT_INIT, 6, "Session manager has not been initialized."); // 1343238150 + +// Graph module error code definition +GE_ERRORNO_GRAPH(GE_GRAPH_INIT_FAILED, 0, "Failed to initialize graph."); // 1343242240 +GE_ERRORNO_GRAPH(GE_GRAPH_ALREADY_RUNNING, 1, "graph already running,not support parallel run."); // 1343242241 +GE_ERRORNO_GRAPH(GE_GRAPH_GRAPH_NOT_EXIST, 2, "graph ID not exist."); // 1343242242 +GE_ERRORNO_GRAPH(GE_GRAPH_GRAPH_ALREADY_EXIST, 3, "Graph ID already exist."); // 1343242243 +GE_ERRORNO_GRAPH(GE_GRAPH_GRAPH_IS_RUNNING, 4, "Graph is running."); // 1343242244 +GE_ERRORNO_GRAPH(GE_GRAPH_MALLOC_FAILED, 5, "Graph malloc failed."); // 1343242245 +GE_ERRORNO_GRAPH(GE_GRAPH_FREE_FAILED, 6, "Graph FREE failed."); // 1343242246 +GE_ERRORNO_GRAPH(GE_GRAPH_NOT_MALLOC_BUFFER, 7, "Graph FREE failed, not malloc buffer."); // 1343242247 +GE_ERRORNO_GRAPH(GE_GRAPH_PARAM_NULLPTR, 8, "Graph param is NULL."); // 1343242248 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, 9, "Get computeGraph by graphNode failed."); // 1343242249 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_NODE_NULL, 10, "Run graph node is null."); // 1343242250 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by graphNode failed."); // 1343242251 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 +GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 +GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 +GE_ERRORNO_GRAPH(GE_GRAPH_GET_IN_OUT_FAILED, 19, "OME GetInputOutputDescInfo failed."); // 1343242259 +GE_ERRORNO_GRAPH(GE_GRAPH_DATA_INPUT_FAILED, 20, "OME DataInput failed."); // 1343242260 +GE_ERRORNO_GRAPH(GE_GRAPH_EXECUTE_FAILED, 21, "Execute graph failed."); // 1343242261 +GE_ERRORNO_GRAPH(GE_GRAPH_DUPLICATE_ENGINE, 22, "Duplicate engine."); // 1343242262 +GE_ERRORNO_GRAPH(GE_GRAPH_EMPTY_SUBGRAPH, 23, "Empty sub graph info."); // 1343242263 +GE_ERRORNO_GRAPH(GE_GRAPH_EXECUTE_NOT_INIT, 24, "Call SetCondition first."); // 1343242264 +GE_ERRORNO_GRAPH(GE_GRAPH_PREPARE_FAILED, 25, "Prepare failed."); // 1343242265 +GE_ERRORNO_GRAPH(GE_GRAPH_SERIALIZE_FAILED, 26, "OMG SerializeModelDef failed."); // 1343242266 +GE_ERRORNO_GRAPH(GE_GRAPH_SAVE_FAILED, 27, "OMG SaveModel failed."); // 1343242267 +GE_ERRORNO_GRAPH(GE_GRAPH_PRERUN_FAILED, 28, "PreRun failed."); // 1343242268 +GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ID_INVALID, 29, "Graph subGraph id is invalid."); // 1343242269 +GE_ERRORNO_GRAPH(GE_GRAPH_INFERSHAPE_FAILED, 30, "Prepare Graph infershape failed"); // 1343242270 +GE_ERRORNO_GRAPH(GE_GRAPH_ISNULL, 31, "RunGraph input compute graph is NULL."); // 1343242271 +GE_ERRORNO_GRAPH(GE_GRAPH_SYNC_MODEL_FAILED, 32, "Graph SyncExecuteModel failed."); // 1343242272 +GE_ERRORNO_GRAPH(GE_GRAPH_RUNGRAPH_FAILED, 33, "Graph RunGraph failed."); // 1343242273 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PARSE_DYN_OP_FAILED, 34, "Parse dynamic node config file failed"); // 1343242274 +GE_ERRORNO_GRAPH(GE_GRAPH_MULTI_SUBGRAPH_BUILD, 35, "Save model with multiple sub graph"); // 1343242275 +GE_ERRORNO_GRAPH(GE_GRAPH_GRAPH_NODE_NULL, 36, "Graph get graph node failed."); // 1343242276 +GE_ERRORNO_GRAPH(GE_GRAPH_NOT_INIT, 37, "Graph do not init."); // 1343242277 +GE_ERRORNO_GRAPH(GE_GRAPH_NULL_INPUT, 38, "input graph is null"); // 1343242278 +GE_ERRORNO_GRAPH(GE_GRAPH_TOPO_SORT_FAILED, 39, "topological sorting an partition failed"); // 1343242279 +GE_ERRORNO_GRAPH(GE_GRAPH_EMPTY_PARTITION, 40, "accessing an empty partition"); // 1343242280 +GE_ERRORNO_GRAPH(GE_GRAPH_UNSUPPORTED, 41, "unsupported feature in partition"); // 1343242281 +GE_ERRORNO_GRAPH(GE_GRAPH_ASSIGN_ENGINE_FAILED, 42, "assign engine failed"); // 1343242282 +GE_ERRORNO_GRAPH(GE_GRAPH_ADD_PLC_END_FAILED, 43, "add placeholder end node failed"); // 1343242283 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PARSE_OUT_NODE_FAILED, 44, "Parse out node failed."); // 1343242284 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_OP_PARSE_FAILED, 45, + "OMG parse dynamic node config file failed."); // 1343242285 +GE_ERRORNO_GRAPH(GE_GRAPH_SAVE_WEIGHTS_FAILED, 46, "OMG Save Weights to Model failed."); // 1343242286 +GE_ERRORNO_GRAPH(GE_GRAPH_EMPTY_STRING_NAME, 47, "Empty string name."); // 1343242287 +GE_ERRORNO_GRAPH(GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE, 48, "Empty variable-tensor table."); // 1343242288 +GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_ALREADY_EXIST, 49, "Variable already exist."); // 1343242289 +GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_DOES_NOT_EXIST, 50, "Variable does not exist."); // 1343242290 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIONS_INVALID, 51, "Client session options is invalid."); // 1343242291 +GE_ERRORNO_GRAPH(GE_GRAPH_NO_OUTPUT_DESC_INFO, 52, "No output desc info."); // 1343242292 +GE_ERRORNO_GRAPH(GE_GRAPH_OUTPUT_DESCINFO_TENSOR_NUM_MISMATCH, 53, + "Number of output descinfo and tensor mismatch."); // 1343242293 +GE_ERRORNO_GRAPH(GE_GRAPH_FILENAMEPREFIX_INVALID, 54, "Graph Save Model fileNamePrefix is invalid."); // 1343242294 +GE_ERRORNO_GRAPH(GE_GRAPH_NOT_BUILT, 55, "Graph is not built before SaveModel."); // 1343242295 +GE_ERRORNO_GRAPH(GE_GRAPH_SAVEMODEL_FAILED, 56, "Graph SaveModel failed."); // 1343242296 +GE_ERRORNO_GRAPH(GE_GRAPH_MEMORY_ALLOC_FAILED, 57, "Failed allocating memory for model file header."); // 1343242297 +GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_REMOVE_GRAPH_FAILED, 58, "Failed remove graph in node seacher."); // 1343242298 +GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_ADD_GRAPH_FAILED, 59, "Failed add graph in node seacher."); // 1343242299 +GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_GET_GRAPH_REBUILD_FAILED, 60, + "Failed add graph in node seacher."); // 1343242300 +GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61, + "Failed set graph finish rebuild in node searcher."); // 1343242301 +GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302 + +// Engine_manager module error code definition +GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336 +GE_ERRORNO_ENGINE(GE_ENG_FINALIZE_FAILED, 1, "Engine finalize failed."); // 1343246337 +GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338 + +// Optimize errocode +GE_ERRORNO_GRAPH(TO_BE_DELETED, 63, "The node of the graph to be deleted."); // 1343242303 +GE_ERRORNO_GRAPH(NOT_CHANGED, 64, "The node of the graph no changed."); // 1343242304 + +// Ops module error code definition +GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432 +GE_ERRORNO_OPS(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, 1, "Failed to initialize GraphOptimizer."); // 1343250433 +GE_ERRORNO_OPS(GE_OPS_KERNEL_INFO_NOT_EXIST, 2, "OpsKernelInfo not exist."); // 1343250434 +GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_NOT_EXIST, 3, "OpsKernelInfoStore not exist."); // 1343250435 +GE_ERRORNO_OPS(GE_OPS_CALC_RUNNING_PARAM_FAILED, 4, "Failed to CalcOpRunningParam."); // 1343250436 +GE_ERRORNO_OPS(GE_OPS_GENERATE_TASK_FAILED, 5, "Failed to GenerateTask."); // 1343250437 +GE_ERRORNO_OPS(GE_OPS_OPTIMIZE_ORIGINAL_GRAPH_FAILED, 6, "Failed to OptimizeOriginalGraph."); // 1343250438 +GE_ERRORNO_OPS(GE_OPS_OPTIMIZE_FUSED_GRAPH_FAILED, 7, "Failed to OptimizeFusedGraph."); // 1343250439 +GE_ERRORNO_OPS(GE_OPS_ENGINE_IS_NOT_REGISTERED, 8, "Engine is not registered."); // 1343250440 +GE_ERRORNO_OPS(GE_OPS_GET_NO_VALID_SO, 9, + "There is no valid so about OpsKernelInfoStore or GraphOptimizer."); // 1343250441 +GE_ERRORNO_OPS(GE_OPS_GET_OPTIMIZE_BY_ENGINE_FAILED, 10, "Failed to get graphOptimizer by name."); // 1343250442 +GE_ERRORNO_OPS(GE_OPS_GET_OPTIMIZE_BY_PRIORITY_FAILED, 11, "Failed to get graphOptimizer by priority."); // 1343250443 +GE_ERRORNO_OPS(GE_OPS_LOAD_GE_OPTIMIZER_FAILED, 12, "Failed to load ge graphOptimizer."); // 1343250444 + +// Runtime module error code definition +GE_ERRORNO_RUNTIME(GE_RTI_DEVICE_ID_INVALID, 1, "device id is invalid"); +GE_ERRORNO_RUNTIME(GE_RTI_DEVICE_NOT_READY, 2, "set device failed, device not ready"); +GE_ERRORNO_RUNTIME(GE_RTI_MEMALLOC_FAILED, 3, "malloc memory failed"); +GE_ERRORNO_RUNTIME(GE_RTI_MODEL_NOT_LOADED, 4, "model has not been loaded"); +GE_ERRORNO_RUNTIME(GE_RTI_THREAD_POOL_IS_NULL, 5, "model excute failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_CREATE_HANDLE_FAILED, 6, "cce create handle failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_SET_STREAM_FAILED, 7, "cce set stream failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_CREATE_RTMODEL_FAILED, 8, "call runtime create rtModel failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_CREATE_STREAM_FAILED, 9, "call runtime create stream failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_BIND_STREAM_FAILED, 10, "call runtime bind stream to model failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_CREATE_LABLE_FAILED, 11, "call runtime create lable failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MODEL_LOAD_COMPLETE_FAILED, 12, "call runtime model load complete failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MODEL_GET_TASK_ID_FAILED, 14, "call runtime get task id failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_KERNEL_LAUNCH_FAILED, 13, "call runtime kernel launch failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_KERNEL_LAUNCHEX_FAILED, 15, "call runtime kernel launchex failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_KERNEL_FUSION_START_FAILED, 16, "call runtime kernel fusion start failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_KERNEL_FUSION_END_FAILED, 17, "call runtime kernel fusion end failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_LABEL_SET_FAILED, 18, "call runtime lable set failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_LABLE_GOTO_FAILED, 19, "call runtime lable goto failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_LABLE_SWITCH_FAILED, 20, "call runtime lable switch failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_ALLOC_MANAGED_FAILED, 21, "call runtime mem alloc managed failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_FREE_MANAGED_FAILED, 22, "call runtime mem free managed failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_FREE_FAILED, 23, "call runtime free failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_STREAM_SYNC_FAILED, 24, "call runtime sync stream failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MODEL_EXCUTE_FAILED, 25, "call runtime model excute failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_ASYNC_FAILED, 26, "call runtime mem async failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_ALLOC_HOST_FAILED, 27, "call runtime alloc host memory failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_FREE_HOST_FAILED, 28, "call runtime free host memory failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_ALLOC_DEVICE_FAILED, 29, "call runtime alloc device memory failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_FREE_DEVICE_FAILED, 30, "call runtime free device memory failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_FLUSH_CACHE_FAILED, 31, "call runtime flush cache failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_UNBIND_STREAM_FAILED, 32, "unbind rtstream from rtmodel failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_DESTORY_STREAM_FAILED, 33, "destory stream failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_DESTORY_LABEL_FAILED, 34, "destory label failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_DESTORY_MODEL_FAILED, 35, "destory model failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_TRANS_TENSOR_FAILED, 36, "call cce transfer tensor descriptor failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_TRANS_FILTER_FAILED, 37, "call cce transfer filter descriptor failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_UPDATE_KERNEL_ARGS_FAILED, 38, "call cce update kernel args failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_DESTORY_HANDLE_FAILED, 39, "destory handle failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_CREATE_EVENT_FAILED, 40, "call rutime create event failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_EVENT_RECORD_FAILED, 41, "call rutime event record failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_STREAM_WAIT_EVENT_FAILED, 42, "call rutime stream wait event failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_BROADCAST_FAILED, 43, "call hccl hcom broadcast failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_ALL_GATHER_FAILED, 44, "call hccl hcom all gather failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_ALL_REDUCE_FAILED, 45, "call hccl hcom all reduce failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_DESTORY_EVENT_FAILED, 46, "destory rt event failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_REDUCE_SCATTER_FAILED, 47, "call hccl hcom reduce scatter failed"); + +// Executor module error code definition +GE_ERRORNO_EXECUTOR(GE_EXEC_NOT_INIT, 1, "GE Executor is not yet initialized."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_PATH_INVALID, 2, "Model file path is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_KEY_PATH_INVALID, 3, "Key file path of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_ID_INVALID, 4, "Model id is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_DATA_SIZE_INVALID, 5, "Data size of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_PARTITION_NUM_INVALID, 6, "Partition number of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_QUEUE_ID_INVALID, 7, "Queue id of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, 8, "Model does not support encryption."); +GE_ERRORNO_EXECUTOR(GE_EXEC_READ_MODEL_FILE_FAILED, 9, "Failed to read model file."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_MODEL_REPEATED, 10, "The model is loaded repeatedly."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_MODEL_PARTITION_FAILED, 11, "Failed to load model partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED, 12, "Failed to load weight partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_TASK_PARTITION_FAILED, 13, "Failed to load task partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_KERNEL_PARTITION_FAILED, 14, "Failed to load kernel partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, 15, "Failed to allocate feature map memory."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_WEIGHT_MEM_FAILED, 16, "Failed to allocate weight memory."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_VAR_MEM_FAILED, 17, "Failed to allocate variable memory."); +GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 18, "GE AIPP is not exist."); +GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 19, "GE Dynamic AIPP is not support to query temporarily."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_P2P_MEM_FAILED, 20, "Failed to allocate P2P memory"); + +// Generator module error code definition +GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, 1, "Graph manager initialize failed."); +GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, 2, "Graph manager add graph failed."); +GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, 3, "Graph manager build graph failed."); +GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, 4, "Graph manager finalize failed."); +GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_SAVE_MODEL_FAILED, 5, "Graph manager save model failed."); + +#define RT_ERROR_TO_GE_STATUS(RT_ERROR) static_cast(RT_ERROR) +} // namespace ge + +#endif // INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/ge_types.h b/metadef/third_party/graphengine/inc/framework/common/ge_types.h new file mode 100644 index 00000000..40947f7a --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/ge_types.h @@ -0,0 +1,304 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_GE_TYPES_H_ +#define INC_FRAMEWORK_COMMON_GE_TYPES_H_ + +#include + +#include +#include + +#include "framework/common/fmk_error_codes.h" +#include "ge/ge_api_error_codes.h" +#include "external/graph/types.h" +#include "external/ge/ge_api_types.h" + +namespace ge { +enum RuntimeType { + HOST = 0, + DEVICE = 1 +}; + +enum PerfLevel { + GEN_TASK_WITH_FUSION = -1, + GEN_TASK_WITHOUT_L2FUSION = 3, + GEN_TASK_WITHOUT_FUSION = 4 +}; + +enum FrameworkType { + CAFFE = 0, + MINDSPORE = 1, + TENSORFLOW = 3, + ANDROID_NN, + ONNX, + FRAMEWORK_RESERVED, +}; + +enum OpEngineType { + ENGINE_SYS = 0, // default engine + ENGINE_AICORE = 1, + ENGINE_VECTOR = 2, + ENGINE_AICUBE = 3, // not support + ENGINE_AIVECTOR = 4 // not support +}; + +enum InputAippType{ + DATA_WITHOUT_AIPP = 0, + DATA_WITH_STATIC_AIPP, + DATA_WITH_DYNAMIC_AIPP, + DYNAMIC_AIPP_NODE +}; + +const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; +const char *const GE_OPTION_EXEC_PLACEMENT = "ge.exec.placement"; + +// Data cache, including data address and length +struct DataBuffer { + public: + void *data; // Data address + uint64_t length; // Data length + bool isDataSupportMemShare = false; + DataBuffer(void *dataIn, uint64_t len, bool isSupportMemShare) + : data(dataIn), length(len), isDataSupportMemShare(isSupportMemShare) {} + + DataBuffer() : data(nullptr), length(0), isDataSupportMemShare(false) {} +}; + +/// +/// @ingroup domi_ome +/// @brief External input data +/// +struct InputData { + uint32_t index; // Index of input data + uint32_t timestamp; // Data creation time + uint32_t timeout; // Processing timeout + uint32_t model_id; // Model ID required for data processing + uint64_t request_id = 0; // Request ID + std::vector blobs; // Actual input data, currently only supports one input + bool is_dynamic_batch = false; // Whether is dynamic batch size scene, default:false + std::string batch_label; // Gear used for current inference in dynamic batch scene +}; + +/// Output result structure definition +struct OutputData { + uint32_t index; // Index of input data + uint32_t model_id; // The model ID corresponding to the processing result + /// Output data cache, arranged in sequence of output operators. + /// If the operator has multiple outputs, + /// the data buffer order of the operator is the same as that defined in the + /// offline model + std::vector blobs; +}; + +// The definition of command data structure +struct Command { + std::string cmd_type; // Command type + std::vector cmd_params; // Command params + uint64_t module_index; // prof module +}; + +// The definition of I/O shape description +struct ShapeDescription { + int64_t num = 0; + int64_t channel = 0; + int64_t height = 0; + int64_t width = 0; + std::vector dims; +}; + +// Definition of input and output description information +struct InputOutputDescInfo { + std::string name; + uint64_t size; + uint32_t data_type; + ShapeDescription shape_info; +}; + +// Definition of model io dims +struct InputOutputDims { + std::string name; + size_t dim_num; + uint32_t size; + std::vector dims; +}; + +// Definition of model io dims +struct OriginInputInfo { + Format format; + DataType data_type; + uint32_t dim_num; +}; + +// The structure of AIPP info +struct AippConfigInfo { + int8_t aipp_mode; + int8_t input_format; + int32_t src_image_size_w; + int32_t src_image_size_h; + int8_t crop; + int32_t load_start_pos_w; + int32_t load_start_pos_h; + int32_t crop_size_w; + int32_t crop_size_h; + int8_t resize; + int32_t resize_output_w; + int32_t resize_output_h; + int8_t padding; + int32_t left_padding_size; + int32_t right_padding_size; + int32_t top_padding_size; + int32_t bottom_padding_size; + int8_t csc_switch; + int8_t rbuv_swap_switch; + int8_t ax_swap_switch; + int8_t single_line_mode; + int32_t matrix_r0c0; + int32_t matrix_r0c1; + int32_t matrix_r0c2; + int32_t matrix_r1c0; + int32_t matrix_r1c1; + int32_t matrix_r1c2; + int32_t matrix_r2c0; + int32_t matrix_r2c1; + int32_t matrix_r2c2; + int32_t output_bias_0; + int32_t output_bias_1; + int32_t output_bias_2; + int32_t input_bias_0; + int32_t input_bias_1; + int32_t input_bias_2; + int32_t mean_chn_0; + int32_t mean_chn_1; + int32_t mean_chn_2; + int32_t mean_chn_3; + float min_chn_0; + float min_chn_1; + float min_chn_2; + float min_chn_3; + float var_reci_chn_0; + float var_reci_chn_1; + float var_reci_chn_2; + float var_reci_chn_3; + int8_t support_rotation; + uint32_t related_input_rank; + uint32_t max_src_image_size; +}; + +// The structure of offline Modeldata +struct ModelData { + void *model_data = nullptr; // Model binary data start addr + uint32_t model_len = 0; // Model binary data length + int32_t priority = 0; // Model priority + std::string key; // Key path for encrypt model, Empty for unencrypt + std::string om_name; // om file name, used for data dump +}; + +// The definition of Model information +struct ModelInfo { + uint32_t version = 0; + std::string name; + bool is_encrypt = 0; // 0:unencrypt, 1:encrypt + std::vector input_desc; + std::vector output_desc; + uint8_t reserved[3] = {0}; // 3-byte reserved field +}; + +// Asynchronous callback interface, implemented by the caller +class ModelListener { + public: + virtual ~ModelListener() {} + /// + /// @brief Asynchronous callback interface + /// @param [in] model_id Model ID of the callback + /// @param [in] data_index Index of the input_data + /// @param [in] resultCode Execution results + /// + virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code, + std::vector &outputs) = 0; +}; + +// OMM configuration item +struct Options { + int64_t session_id; + int32_t device_id; + std::string job_id; + bool isUseHcom; + bool isUseHvd; + bool deployMode; + bool isAICPUMode; + bool enable_atomic; + std::string podName; + int64_t rankId; + std::string rankTableFile; + int32_t ge_hccl_flag = 0; + int32_t physical_device_id; + std::string profiling_mode; + std::string profiling_options; +}; + +// Profiling info of task +struct TaskDescInfo { + std::string model_name; + std::string op_name; + uint32_t block_dim; + uint32_t task_id; + uint32_t stream_id; +}; + +// Profiling info of graph +struct ComputeGraphDescInfo { + std::string model_name; + std::string op_name; + std::string op_type; + std::vector input_format; + std::vector> input_shape; + std::vector input_data_type; + std::vector output_format; + std::vector> output_shape; + std::vector output_data_type; +}; + +struct OpDescInfo { + std::string op_name; + std::string op_type; + uint32_t task_id; + uint32_t stream_id; + std::vector input_format; + std::vector> input_shape; + std::vector input_data_type; + std::vector input_addrs; + std::vector input_size; + std::vector output_format; + std::vector> output_shape; + std::vector output_data_type; + std::vector output_addrs; + std::vector output_size; +}; +struct ModelDumpConfig { + std::string model_name; + std::vector layers; +}; + +struct DumpConfig { + std::string dump_path; + std::string dump_mode; + std::string dump_status; + std::string dump_op_switch; + std::vector dump_list; +}; +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/op/attr_value_util.h b/metadef/third_party/graphengine/inc/framework/common/op/attr_value_util.h new file mode 100644 index 00000000..e3803b78 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/op/attr_value_util.h @@ -0,0 +1,160 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ +#define INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ + +#include +#include +#include + +#include "graph/debug/ge_attr_define.h" +#include "proto/om.pb.h" + +using domi::AttrDef; +using domi::AttrDef_ListValue; +using domi::ModelDef; +using domi::NamedAttrs; +using domi::OpDef; + +namespace ge { +using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; +using AttrDefPair = ::google::protobuf::MapPair; + +void AddOpAttr(const std::string &key, AttrDef &attr, OpDef *opdef); +// DEFINE_ADD_ATTR_VALUE +void AddOpAttr(const std::string &key, const std::string &value, AttrDefMap *attrs); +void AddOpAttr(const std::string &key, const char *value, AttrDefMap *attrs); +void AddOpAttr(const char *key, const char *value, AttrDefMap *attrs); +void AddOpAttr(const std::string &key, const uint32_t value, AttrDefMap *attrs); +void AddOpAttr(const std::string &key, const int32_t value, AttrDefMap *attrs); +void AddOpAttr(const std::string &key, const int64_t value, AttrDefMap *attrs); +void AddOpAttr(const std::string &key, const float value, AttrDefMap *attrs); +void AddOpAttr(const std::string &key, const double value, AttrDefMap *attrs); +void AddOpAttr(const std::string &key, const bool value, AttrDefMap *attrs); + +void AddOpAttr(const std::string &key, const AttrDef_ListValue &value, AttrDefMap *attrs); + +// DEFINE_ADD_ATTR_VALUE +void AddOpAttr(const std::string &key, const std::string &value, OpDef *opdef); +void AddOpAttr(const std::string &key, const char *value, OpDef *opdef); +void AddOpAttr(const char *key, const char *value, OpDef *opdef); +void AddOpAttr(const std::string &key, const uint32_t value, OpDef *opdef); +void AddOpAttr(const std::string &key, const int32_t value, OpDef *opdef); +void AddOpAttr(const std::string &key, const int64_t value, OpDef *opdef); +void AddOpAttr(const std::string &key, const float value, OpDef *opdef); +void AddOpAttr(const std::string &key, const double value, OpDef *opdef); +void AddOpAttr(const std::string &key, const bool value, OpDef *opdef); + +void AddOpAttr(const std::string &key, const AttrDef_ListValue &value, OpDef *opdef); + +void AddOpBytesAttr(const std::string &key, const void *value, size_t size, OpDef *opdef); + +// DEFINE_ADD_ATTR_VALUE_LIST +void AddOpAttrList(const std::string &key, const double value, AttrDefMap *attrs); +void AddOpAttrList(const std::string &key, const float value, AttrDefMap *attrs); +void AddOpAttrList(const std::string &key, const uint32_t value, AttrDefMap *attrs); +void AddOpAttrList(const std::string &key, const int32_t value, AttrDefMap *attrs); +void AddOpAttrList(const std::string &key, const std::string value, AttrDefMap *attrs); +void AddOpAttrList(const std::string &key, const double value, OpDef *opdef); +void AddOpAttrList(const std::string &key, const float value, OpDef *opdef); +void AddOpAttrList(const std::string &key, const uint32_t value, OpDef *opdef); +void AddOpAttrList(const std::string &key, const int32_t value, OpDef *opdef); +void AddOpAttrList(const std::string &key, const bool value, OpDef *opdef); +void AddOpAttrList(const std::string &key, const int64_t value, OpDef *opdef); + +void AddOpAttrList(const std::string &key, const std::string &value, OpDef *opdef); + +bool GetOpAttr(const std::string &key, std::string *value, const OpDef *opdef); +bool GetOpAttr(const std::string &key, int32_t *value, const OpDef *opdef); +bool GetOpAttr(const std::string &key, int64_t *value, const OpDef *opdef); +bool GetOpAttr(const std::string &key, uint32_t *value, const OpDef *opdef); +bool GetOpAttr(const std::string &key, float *value, const OpDef *opdef); +bool GetOpAttr(const std::string &key, double *value, const OpDef *opdef); +bool GetOpAttr(const std::string &key, bool *value, const OpDef *opdef); +bool GetOpAttr(const std::string &key, AttrDef_ListValue *value, const OpDef *opdef); + +uint32_t GetOpAttrListSize(const std::string &key, std::string value, const OpDef *opdef); +uint32_t GetOpAttrListSize(const std::string &key, int32_t value, const OpDef *opdef); +uint32_t GetOpAttrListSize(const std::string &key, int64_t value, const OpDef *opdef); +uint32_t GetOpAttrListSize(const std::string &key, uint32_t value, const OpDef *opdef); +uint32_t GetOpAttrListSize(const std::string &key, float value, const OpDef *opdef); +uint32_t GetOpAttrListSize(const std::string &key, double value, const OpDef *opdef); +uint32_t GetOpAttrListSize(const std::string &key, bool value, const OpDef *opdef); + +bool GetBytesAttr(const std::string &key, std::string *value, const OpDef *opdef); +bool GetBytesAttr(const std::string &key, std::string *value, const ModelDef *model_def); + +void AddModelAttr(const std::string &key, const std::string &value, ModelDef *model_def); +void AddModelAttr(const std::string &key, const char *value, ModelDef *model_def); +void AddModelAttr(const char *key, const char *value, ModelDef *model_def); +void AddModelAttr(const std::string &key, const uint32_t value, ModelDef *model_def); +void AddModelAttr(const std::string &key, const int32_t value, ModelDef *model_def); +void AddModelAttr(const std::string &key, const int64_t value, ModelDef *model_def); +void AddModelAttr(const std::string &key, const float value, ModelDef *model_def); +void AddModelAttr(const std::string &key, const double value, ModelDef *model_def); +void AddModelAttr(const std::string &key, const bool value, ModelDef *model_def); +void AddModelAttr(const std::string &key, const void *value, size_t size, ModelDef *model_def); +void AddModelAttr(const std::string &key, const AttrDef_ListValue &value, ModelDef *model_def); + +void AddModelAttrList(const std::string &key, const double value, ModelDef *model_def); +void AddModelAttrList(const std::string &key, const float value, ModelDef *model_def); +void AddModelAttrList(const std::string &key, const uint32_t value, ModelDef *model_def); +void AddModelAttrList(const std::string &key, const int32_t value, ModelDef *model_def); +void AddModelAttrList(const std::string &key, const std::string &value, ModelDef *model_def); + +bool GetModelAttr(const std::string &key, std::string *value, const ModelDef *model_def); +bool GetModelAttr(const std::string &key, int32_t *value, const ModelDef *model_def); +bool GetModelAttr(const std::string &key, int64_t *value, const ModelDef *model_def); +bool GetModelAttr(const std::string &key, uint32_t *value, const ModelDef *model_def); +bool GetModelAttr(const std::string &key, float *value, const ModelDef *model_def); +bool GetModelAttr(const std::string &key, double *value, const ModelDef *model_def); +bool GetModelAttr(const std::string &key, bool *value, const ModelDef *model_def); +bool GetModelAttr(const std::string &key, AttrDef_ListValue *value, const ModelDef *model_def); + +bool HasOpAttr(const OpDef *opdef, const std::string &attr_name); + +void SetAttrDef(const std::string &value, AttrDef *out); +void SetAttrDef(const char *value, AttrDef *out); +void SetAttrDef(const uint32_t value, AttrDef *out); +void SetAttrDef(const int32_t value, AttrDef *out); +void SetAttrDef(const float value, AttrDef *out); +void SetAttrDef(const double value, AttrDef *out); +void SetAttrDef(const bool value, AttrDef *out); +void SetAttrList(const std::string &value, AttrDef *out); +void SetAttrList(const bool value, AttrDef *out); +void SetAttrList(const float value, AttrDef *out); +void SetAttrList(const double value, AttrDef *out); +void SetAttrList(const uint32_t value, AttrDef *out); + +bool GetAttrDefValue(const std::string &key, std::string *value, const AttrDefMap &attr); +bool GetAttrDefValue(const std::string &key, int32_t *value, const AttrDefMap &attr); +bool GetAttrDefValue(const std::string &key, int64_t *value, const AttrDefMap &attr); +bool GetAttrDefValue(const std::string &key, uint32_t *value, const AttrDefMap &attr); +bool GetAttrDefValue(const std::string &key, float *value, const AttrDefMap &attr); +bool GetAttrDefValue(const std::string &key, double *value, const AttrDefMap &attr); +bool GetAttrDefValue(const std::string &key, bool *value, const AttrDefMap &attr); +bool GetAttrDefValue(const std::string &key, AttrDef_ListValue *value, const AttrDefMap &attr); +bool GetAttrDefValue(const std::string &key, NamedAttrs *&value, AttrDefMap *attr); +bool GetAttrDefValue(const std::string &key, const NamedAttrs *&value, const AttrDefMap &attr); + +bool GetAttrDefListValue(const std::string &key, int idx, int32_t *value, const AttrDefMap &attr); +bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); +bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); +bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); +} + +#endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/op/ge_op_utils.h b/metadef/third_party/graphengine/inc/framework/common/op/ge_op_utils.h new file mode 100644 index 00000000..4718b180 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/op/ge_op_utils.h @@ -0,0 +1,296 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ +#define INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ + +#include +#include +#include + +#include "common/op/attr_value_util.h" +#include "register/register_types.h" +#include "register/register_error_codes.h" +#include "common/util.h" +#include "graph/attr_value.h" +#include "graph/ge_tensor.h" +#include "graph/node.h" +#include "graph/op_desc.h" +#include "proto/insert_op.pb.h" + +namespace ge { +using namespace cce; +using domi::Status; + +// Add Sub Mul +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t ADD_INPUT_NUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SUB_INPUT_NUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MUL_INPUT_NUM; + +// Permute +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int32_t PERMUTE_ORDER_NUM; + +// Ssd PriroBox +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const double SSD_PRIORBOX_ASPECT_RATIO_VALUE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t STRIDEDSLICE_INPUT_NUM; + +// Switch +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_INPUT_NUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_OUTPUT_NUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_FALSE_OUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_TRUE_OUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT; + +// FunctionOp +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t IF_COND_INPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_START_INPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT_INPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; + +class OpUtils { + public: + /// + /// @ingroup domi_ome + /// @brief Check whether check_value is in [min_enum_value, max_enum_value] + /// @return true Within + /// @return false out of range + // + static inline bool CheckEnumValid(int32_t check_value, int32_t min_enum_value, int32_t max_enum_value) { + return check_value < min_enum_value ? false : (check_value >= max_enum_value ? false : true); + } + /// + /// @ingroup domi_omg + /// @brief Convert the dimension of array according to different format + /// @param [in] src_format src_shape format + /// @param [in] src Dimension array to be converted + /// @param [in] dst_format Target format after conversion + /// @param [out] dst Dimension array after conversion + /// @return SUCCESS success + /// @return FAILED fail + /// + static bool ConvertDim(ccTensorFormat_t src_format, const std::vector &src, ccTensorFormat_t dst_format, + std::vector &dst); + /// + /// @ingroup domi_omg + /// @brief Determine whether to manually calculate the tensor size based on the values of format and dim + /// @param [in] format, Format information of the tensor + /// @param [in] real_dim_cnt, Tensor dim + /// @return true Manually calculate the size based on dim and datatype + /// @return false skip + /// + static bool IsComputDimsSize(const int32_t format, const uint32_t real_dim_cnt); + /// + /// @ingroup domi_ome + /// @brief Initialize the tensor description, which is used for input and output. + /// @param [in] model_tensor Tensor information defined by the offline model + /// @param [out] cc_tensor Tensor definition used by CC + /// @return SUCCESS success + /// @return FAILED fail + /// + static Status InitTensorDescriptor(const ge::GeTensorDesc &model_tensor, ccTensorDescriptor_t &cc_tensor); + /// + /// @ingroup domi_ome + /// @brief Initialize the tensor description, which is used for input and output. + /// @param [in] model_tensor Tensor information defined by the offline model + /// @param [in] dst_data_type data_type of the target cc_tensor + /// @param [out] cc_tensor Tensor definition used by CC + /// @return SUCCESS success + /// @return FAILED fail + /// + static Status InitTensorDescriptor(const ge::GeTensorDesc &model_tensor, int32_t dst_data_type, + ccTensorDescriptor_t &cc_tensor); + /// + /// @ingroup domi_ome + /// @brief Initialize the tensor description for bias. + /// @param [in] model_tensor Tensor information defined by the offline model + /// @param [out] cc_tensor Tensor definition used by CC + /// @return SUCCESS success + /// @return FAILED fail + /// + /// + static Status InitTensorDescriptor(const ge::GeTensor &model_tensor, ccTensorDescriptor_t &cc_tensor); + /// + /// @ingroup domi_ome + /// @brief Initialize the tensor description for bias. + /// @param [in] model_tensor Tensor information defined by the offline model + /// @param [in] dst_data_type data_type of the target cc_tensor + /// @param [out] cc_tensor Tensor definition used by CC + /// @return SUCCESS success + /// @return FAILED fail + /// + static Status InitTensorDescriptor(const ge::GeTensor &model_tensor, int32_t dst_data_type, + ccTensorDescriptor_t &cc_tensor); + + static Status InitTensorDescriptor(int32_t format, int32_t data_type, const std::vector &dim, + ccTensorDescriptor_t &cc_tensor, uint32_t real_dim_cnt = 4); + /// + /// @ingroup domi_ome + /// @brief Destroys a tensor + /// @param [inout] cc_tensor Tensor definition used by CC + /// + static void DestroyTensorDescriptor(ccTensorDescriptor_t &cc_tensor) noexcept; + + /// + /// @ingroup domi_ome + /// @brief Destroys a tensor + /// @param [inout] cc_filter cc_filter Definition of the filter used by CC + /// + static void DestroyFilterDescriptor(ccFilterDescriptor_t &cc_filter); + + /// + /// @ingroup domi_ome + /// @brief Initializing Filter Description + /// @param [in] model_filter Filter information defined in the offline model + /// @param [out] cc_filter Definition of the filter used by CC + /// @return SUCCESS success + /// @return FAILED fail + /// + static Status InitFilterDescriptor(const ge::GeTensor &model_filter, ccFilterDescriptor_t &cc_filter); + + /// + /// @brief Extract AIPP parameters from AttrDefMap and splice them + /// @param [in] aipp_attr attr of operator + /// @param [out] aipp_params aipp parameters + /// @return enum of tagCCAippInputFormat + /// + static Status ConvertAippParams(const GeAttrValue::NamedAttrs &aipp_attr, domi::AippOpParams *aipp_params); + static Status TransferDim(const std::vector &dim, std::vector &dim_vector); + template + static void SliceData(const std::vector &input, int64_t chunk_size, std::vector &output, + int64_t begin, int64_t out_dim, int64_t stride); + template + static Status SetDataByDataType(size_t out_size, const std::vector &chunk_input, + const std::vector &chunk_output, GeTensor *output); + template + static Status SetOutputSliceDataByDataType(void *data, int64_t data_size, const std::vector &input_dims, + const std::vector &begin, const std::vector &output_dims, + ge::GeTensor *output, const std::vector &stride); + static Status SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector &input_dims, + std::vector &begin, std::vector &output_dims, ge::GeTensor *output, + std::vector &stride); + + /// + /// @ingroup domi_omg + /// @brief Convert the convolutional weight data from [h, w, c, k] to [k, c, h, w] + /// @param [in] input Weight data in HWCK format + /// @param [in] H value of H dimension + /// @param [in] W value of W dimension + /// @param [in] C value of C dimension + /// @param [in] K value of K dimension + /// @param [out] output Data pointer after conversion. The format is KCHW. + /// + static void TransDataHWCK2KCHW(const void *input, int64_t H, int64_t W, int64_t C, int64_t K, void **output); + /// + /// @ingroup domi_omg + /// @brief Converts the convolutional weight data from [k, c, h, w] to [h, w, c, k]. + /// @param [in] input Weight data in HWCK format + /// @param [in] K value of K dimension + /// @param [in] C value of C dimension + /// @param [in] H value of H dimension + /// @param [in] W value of W dimension + /// @param [out] output Data pointer after conversion. The format is HWCK + /// + static void TransDataKCHW2HWCK(const void *input, int64_t K, int64_t C, int64_t H, int64_t W, void *output); + /// + /// @ingroup domi_omg + /// @brief Initialize the input and output description of the data node which is applied to filter weight in the + /// training network + /// @param [in] model_tensor input and output tensor information + /// @param [out] cc_tensor Tensor in CCE format after conversion + /// + static Status InitFilterTensorDescriptor(const ge::GeTensorDesc &model_tensor, ccFilterDescriptor_t &cc_tensor); + + static void SetTensorDescriptorAllOffsetQuantizeInfo(const GeTensorDesc &tensor, ccTensorDescriptor_t cc_tensor); + static vector GetWeights(const ge::Node &node); + static vector GetWeights(ge::ConstNodePtr node); + static vector MutableWeights(const ge::Node &node); + static vector MutableWeights(const ge::NodePtr node); + static Status SetWeights(ge::Node &node, const vector &weights); + static Status SetWeights(ge::NodePtr node, const vector &weights); + static Status GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, DataType type, std::vector &dims); + + private: + friend class CceTensorDescriptor; + static uint32_t GetRealDimCnt(const GeTensorDesc &tensor_desc); +}; + +class CceTensorDescriptor; + +using CceTensorDescriptorPtr = std::shared_ptr; + +class CceTensorDescriptor { + public: + explicit CceTensorDescriptor(ccTensorDescriptor_t cc_tensor); + CceTensorDescriptor(const CceTensorDescriptor &) = delete; + CceTensorDescriptor &operator=(const CceTensorDescriptor &) = delete; + + ~CceTensorDescriptor(); + + ccTensorDescriptor_t GetPtr() { return cc_tensor_; } + + /// + /// @brief Initializes the tensor based on shape information. + /// @param[in] format data permutation format + /// @param[in] data_type Data Type + /// @param[in] dim dim information + /// @return return code + /// + Status InitTensor(int32_t format, int32_t data_type, const std::vector &dims); + + Status InitTensor(int32_t format, int32_t data_type, const ge::GeShape &shape); + + /// + /// @brief get format of tensor + /// @param[out] format format of the tensor + /// @return return code + /// + Status GetFormat(ccTensorFormat_t *format); + + /// + /// @brief Obtains the size of the tensor. + /// @param[out] size size of Tensor + /// @return return code + /// + Status GetTensorSizeInBytes(uint32_t *size); + + /// + /// @brief transform tensor between 4d(NCHW) and 5d(NC1HWC0) + /// @param [in] xDesc descriptor of input tensor + /// @param [in] x point to input data in host memory + /// @param [in] dataTypeTransmode mode of data type transform + /// @param [in] yDesc descriptor of output tensor + /// @param [in|out] y point to output data in host memory + /// @param [in] ySizeInBytes size of outputData + /// @return return code + /// + static Status TransTensor(const ccTensorDescriptor_t xDesc, const void *x, const CceTensorDescriptorPtr &yDesc, + void *y, uint32_t ySizeInBytes); + + /// + /// @brief CceTensorDescriptor Static Constructor + /// @return CceTensorDescriptor smart pointer + /// + static CceTensorDescriptorPtr Create(); + + ccTensorDescriptor_t cc_tensor_ = nullptr; +}; +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/op/op_parser_util.h b/metadef/third_party/graphengine/inc/framework/common/op/op_parser_util.h new file mode 100644 index 00000000..49b4350a --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/op/op_parser_util.h @@ -0,0 +1,425 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ +#define INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ + +#include +#include +#include +#include + +namespace ge { +// general +const float DEFAULT_ALPHA_VALUE = 1.0; +const float DEFAULT_BETA_VALUE = 0.0; +const uint32_t NORMAL_INPUT_NUM = 1; +const uint32_t NORMAL_OUTPUT_NUM = 1; +const uint32_t NORMAL_WORKSPACE_NUM = 0; +const int32_t NORMAL_1D_DIM_NUM = 1; +const int32_t NORMAL_SCALE_DIM_NUM = 0; +const int NORMAL_TENSOR_FORMAT = static_cast(cce::CC_TENSOR_NC1HWC0); +const int NORMAL_TENSOR_SIZE = 4; +const int NORMAL_DEVICE_DATA_TYPE = static_cast(cce::CC_DATA_HALF); +const int DEFAULT_POOLING_MODE = static_cast(cce::CC_POOLING_MAX); +const uint32_t DEFAULT_REAL_DIM_CNT = 4; + +// const +const uint32_t CONST_OP_INPUT_NUM = 0; +const uint32_t CONST_OP_NORMAL_WEIGHT_SIZE = 1; + +// MatMul +const uint32_t MATMUL_INPUT_NUM = 2; + +// ActivationGrad +const int32_t ACTIVATIONGRAD_INPUT_NUM = 2; + +// FusedBatchNorm +const int32_t FUSED_BATCH_NORM_WORKSPACE_NUM = 1; +const int32_t FUSED_BATCH_NORM_INPUT_NUM = 5; +const int32_t FUSED_BATCH_NORM_OUTPUT_NUM = 5; +// FusedBatchNormGrad +const int32_t FUSEDBATCHNORMGRAD_WORKSPACE_NUM = 1; +const int32_t FUSEDBATCHNORMGRAD_INPUT_NUM = 5; +const int32_t FUSEDBATCHNORMGRAD_OUTPUT_NUM = 3; + +// conv +const uint32_t CONVOLUTION_WORKSPACE_NUM = 1; +const uint32_t CONVOLUTION_PAD_SIZE = 4; +const uint32_t CONVOLUTION_STRIDE_SIZE = 2; +const uint32_t CONVOLUTION_DILATION_SIZE = 2; +const int32_t CONVOLUTION_ADJ_SIZE = 2; +const int32_t CONVOLUTION_TARGET_SHAPE_SIZE = 2; + +// ConvGradFilter +const uint32_t CONVGRADFILTER_WORKSPACE_NUM = 1; +const uint32_t CONVGRADFILTER_INPUT_NUM = 3; + +// Pooling +const uint32_t POOLING_WINDOW_SIZE = 2; +const uint32_t POOLING_STRIDE_SIZE = 2; +const uint32_t POOLING_PAD_SIZE = 4; + +// Add Sub Mul +const uint32_t ADD_INPUT_NUM = 2; +const uint32_t SUB_INPUT_NUM = 2; +const uint32_t MUL_INPUT_NUM = 2; +const uint32_t DIV_INPUT_NUM = 2; +const uint32_t ADD_WORKSPACE_NUM = 1; +const uint32_t SUB_WORKSPACE_NUM = 1; +const uint32_t MUL_WORKSPACE_NUM = 1; +const uint32_t DIV_WORKSPACE_NUM = 1; + +const int32_t DEFAULT_AXIS_VALUE = -1; + +const int32_t RESHAPE_AXIS_DEFAULT_VALUE = 0; +const int32_t RESHAPE_NUM_AXES_DEFAULT_VALUE = -1; +const uint32_t RESHAPE_WORKSPACE_NUM = 1; + +const uint32_t FLATTEN_WORKSPACE_NUM = 1; + +const int32_t CONCAT_MIN_INPUT_SIZE = 1; +const int32_t CONCAT_DEFAULT_AXIS = 1; +const uint32_t CONCAT_WORKSPACE_NUM = 1; + +// The value for LRN parameters +const uint32_t LRN_DEFAULT_NORM_REGION = 0; +const float LRN_DEFAULT_K = 1.0; +const uint32_t LRN_DEFAULT_LOCAL_SIZE = 5; +const float LRN_DEFAULT_ALPHA = 1.0; +const float LRN_DEFAULT_BETA = 0.75; + +/// +/// @ingroup domi_common +/// @brief roipooling default value +/// +const uint32_t ROIPOOLING_DEFAULT_POOLED_H = 0; +const uint32_t ROIPOOLING_DEFAULT_POOLED_W = 0; +const float ROIPOOLING_DEFAULT_SPATIAL_SCALE = 1; +const int32_t ROIPOOLING_DEFAULT_SAMPLING_RATIO = -1; + +// DetectionOutput +const int32_t DETECTIONOUTPUT_INPUT_SIZE = 3; +const int32_t DETECTIONOUTPUT_OUTPUT_SIZE = 2; +const int32_t DETECTIONOUTPUT_WORKSPACE_NUM = 1; +const int DETECTIONOUTPUT_CLASS_NUM = 20; // Number of background categories +const int DETECTIONOUTPUT_NUM_CLASSES_DEFAULT_VALUE = 21; +const float DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3; +const float DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.8; + +// Proposal +const int32_t PROPOSAL_INPUT_SIZE = 3; +const int32_t PROPOSAL_OUTPUT_MAX_SIZE = 2; +const int32_t PROPOSAL_WORKSPACE_NUM = 1; +const float PROPOSAL_BASE_SIZE_DEFAULT_VALUE = 16; +const float PROPOSAL_RATIO_DIM_0_DEFAULT_VALUE = 0.5; +const float PROPOSAL_RATIO_DIM_1_DEFAULT_VALUE = 1; +const float PROPOSAL_RATIO_DIM_2_DEFAULT_VALUE = 2; +const float PROPOSAL_SCALE_DIM_0_DEFAULT_VALUE = 8; +const float PROPOSAL_SCALE_DIM_1_DEFAULT_VALUE = 16; +const float PROPOSAL_SCALE_DIM_2_DEFAULT_VALUE = 32; +const float PROPOSAL_MIN_SIZE_DEFAULT_VALUE = 16; +const int PROPOSAL_PRE_NMS_TOPN_DEFAULT_VALUE = 6000; +const int PROPOSAL_POST_NMS_TOPN_DEFAULT_VALUE = 304; +const float PROPOSAL_NMS_THRESH_DEFAULT_VALUE = 0.7; +const float PROPOSAL_FILTER_THRESH_DEFAULT_VALUE = 0; + +// TVM OP +const uint32_t DEFAULT_KERNEL_BLOCK_DIM = 1; + +// Softmax +const int32_t SOFTMAX_WORKSPACE_NUM = 1; + +// SoftmaxCrossEntropy +const int32_t SOFTMAXCROSSENTROPY_INPUT_NUM = 2; +const int32_t SOFTMAXCROSSENTROPY_OUTPUT_NUM = 2; + +// Permute +const int32_t PERMUTE_INPUT_NUM = 1; +const int32_t PERMUTE_OUTPUT_NUM = 1; +const int32_t PERMUTE_WORKSPACE_NUM = 1; +const int32_t PERMUTE_ORDER_NUM = 4; + +// Ssd normalize +const int SSD_NORMALIZE_INPUT_SIZE = 1; +const float SSD_NORMALIZE_EPS_DEFAULT_VALUE = 2e-7; + +// SsdPriroBox +const int32_t SSD_PRIOR_BOX_WORKSPACE_NUM = 1; +const int32_t SSD_PRIOR_BOX_INPUT_NUM = 2; +const bool SSD_PRIOR_BOX_FLIP_VALUE = true; +const bool SSD_PRIOR_BOX_CLIP_VALUE = false; +const double SSD_PRIOR_BOX_ASPECT_OFFSET_VALUE = 0.5; +const double SSD_PRIORBOX_VARIANCE_VALUE = 0.1; +const double SSD_PRIORBOX_VARIANCE_SIZE_ONE = 1; +const double SSD_PRIORBOX_VARIANCE_SIZE_FOUR = 4; +const double SSD_PRIORBOX_ASPECT_RATIO_VALUE = 1.0; +const int SSD_PRIOR_BOX_CODETYPE_CORNER_VALUE = 1; +const int SSD_PRIOR_BOX_CODETYPE_CENTER_SIZE_VALUE = 2; +const int SSD_PRIOR_BOX_CODETYPE_CORNER_SIZE_VALUE = 3; + +// Ssd DetectionOutput +const int32_t SSD_DETECTIONOUTPUT_INPUT_SIZE = 3; +const int32_t SSD_DETECTIONOUTPUT_INPUT_SIZE_AFTER_FUSION = 2; +const int32_t SSD_DETECTIONOUTPUT_OUTPUT_SIZE = 2; +const int32_t SSD_DETECTIONOUTPUT_OUTPUT_SIZE_AFTER_FUSION = 3; +const int32_t SSD_DETECTIONOUTPUT_WORKSPACE_NUM = 1; +const int32_t SSD_DETECTIONOUTPUT_WORKSPACE_NUM_AFTER_FUSION = 0; +const bool SSD_DETECTIONOUTPUT_SHARED_LOCATION_DEFAULT_VALUE = true; +const int32_t SSD_DETECTIONOUTPUT_BACKGROUND_LABEL_ID_DEFAULT_VALUE = 0; +const float SSD_DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3; +const int32_t SSD_DETECTIONOUTPUT_TOP_K_DEFAULT_VALUE = 200; +const float SSD_DETECTIONOUTPUT_ETA_DEFAULT_VALUE = 1.0; +const int SSD_DETECTIONOUTPUT_CODE_TYPE_DEFAULT_VALUE = static_cast(cce::CC_BOX_CENTER_SIZE); +const int32_t SSD_DETECTIONOUTPUT_KEEP_TOP_K_DEFAULT_VALUE = 200; +const bool SSD_DETECTIONOUTPUT_VARIANCE_ENCODED_IN_TARGET_DEFAULT_VALUE = false; +const float SSD_DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.1; + +// Refinedet DetectionOutput +const int32_t REFINEDET_DETECTIONOUTPUT_INPUT_SIZE = 5; +const int32_t REFINEDET_DETECTIONOUTPUT_INPUT_SIZE_AFTER_FUSION = 2; +const int32_t REFINEDET_DETECTIONOUTPUT_OUTPUT_SIZE = 2; +const int32_t REFINEDET_DETECTIONOUTPUT_OUTPUT_SIZE_AFTER_FUSION = 3; +const int32_t REFINEDET_DETECTIONOUTPUT_WORKSPACE_NUM = 1; +const bool REFINEDET_DETECTIONOUTPUT_SHARED_LOCATION_DEFAULT_VALUE = true; +const int32_t REFINEDET_DETECTIONOUTPUT_BACKGROUND_LABEL_ID_DEFAULT_VALUE = 0; +const float REFINEDET_DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3; +const int32_t REFINEDET_DETECTIONOUTPUT_TOP_K_DEFAULT_VALUE = 200; +const float REFINEDET_DETECTIONOUTPUT_ETA_DEFAULT_VALUE = 1.0; +const bool REFINEDET_DETECTIONOUTPUT_VARIANCE_ENCODED_IN_TARGET_DEFAULT_VALUE = false; +const int REFINEDET_DETECTIONOUTPUT_CODE_TYPE_DEFAULT_VALUE = static_cast(cce::CC_BOX_CENTER_SIZE); +const int32_t REFINEDET_DETECTIONOUTPUT_KEEP_TOP_K_DEFAULT_VALUE = 200; +const float REFINEDET_DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.1; +const float REFINEDET_DETECTIONOUTPUT_OBJECTNESS_SCORE_DEFAULT_VALUE = 0; + +// Channel axpy +const int32_t CHANNEL_AXPY_INPUT_NUM = 3; +const int32_t CHANNEL_AXPY_INPUT_DIM_SIZE = 4; +const int32_t CHANNEL_AXPY_WORKSPACE_NUM = 1; + +// Psroi pooling +const int PSROI_POOLING_INPUT_COUNT = 2; +const int PSROI_POOLING_WORKSPACE_NUM = 1; + +// MaxPoolWithArgmax +const uint32_t MAX_POOL_WITH_ARGMAX_OUTPUT_NUM = 2; +const uint32_t MAX_POOL_GRAD_WITH_ARGMAX_INPUT_NUM = 3; + +// AvgPoolGrad +const uint32_t AVG_POOL_GRAD_INPUT_NUM = 2; + +// ROIAlign +const int32_t ROIALIGN_INPUT_SIZE = 2; +const int32_t ROIALIGN_WORKSPACE_NUM = 1; +const int32_t ROIALIGN_DEFAULT_POOLED_H = 1; +const int32_t ROIALIGN_DEFAULT_POOLED_W = 1; + +// Correlation +const uint32_t CORRELATION_INPUT_NUM = 2; +const int CORRELATION_WORKSPACE_NUM = 1; + +// Detectionpostprocess +const int32_t POSTPROCESS_INPUT_SIZE = 4; +const int32_t POSTPROCESS_OUTPUT_SIZE = 2; +const int32_t POSTPROCESS_WORKSPACE_NUM = 1; +const uint32_t POSTPROCESS_CLS_NUM_DEFAULT_VALUE = 12; +const uint32_t POSTPROCESS_POST_NMS_TOPN_DEFAULT_VALUE = 100; +const float POSTPROCESS_NMS_THRESH_DEFAULT_VALUE = 0.3; +const float POSTPROCESS_CONF_THRESH_DEFAULT_VALUE = 0.5; +const float POSTPROCESS_BBOX_REG_WEIGHT_DIM_DEFAULT_VALUE = 1.0; +const int32_t POSTPROCESS_BBOX_REG_WEIGHT_SIZE_DEFAULT_VALUE = 4; + +// Split +const int32_t SPLIT_INPUT_NUM = 2; +const int32_t SPLIT_DEFAULT_AXIS_VALUE = 1; +const int32_t SPLIT_MIN_OUTPUT_SIZE = 1; + +const uint32_t STRIDEDSLICE_INPUT_NUM = 4; +// Slice +const int32_t SLICE_INPUT_NUM = 3; +const int32_t SLICE_WEIGHT_NUM = 2; + +// GatherNd +const int32_t GATHERND_INPUT_NUM = 2; +// ArgMax +const int32_t ARGMAX_INPUT_NUM = 2; +const int32_t ARGMAX_REAL_INPUT_NUM = 1; + +// HighWay +const int32_t HIGHWAY_INPUT_NUM = 4; +const int32_t HIGHWAY_WORKSPACE_NUM = 1; +// RealDiv +const int32_t REALDIV_INPUT_NUM = 2; + +// Range +const int32_t RANGE_INPUT_NUM = 3; +const int32_t RANGE_OUTPUT_NUM = 1; +const int32_t RANGE_INPUT_DIM_SIZE = 0; + +// Pad +const int32_t PAD_WEIGHT_NUM = 1; +const int32_t PAD_DIM_SIZE = 2; +const int32_t PAD_DIM0 = 4; +const int32_t PAD_DIM1 = 2; +const int32_t PAD_WEIGHT_WITH_CONSTANT_NUM = 2; +const int32_t PAD_CONSTATNT_DEFAULT_VALUE = 0; +const int32_t PAD_PADDINGS_SIZE = 8; + +// Tile +const int32_t TILE_WEIGHT_NUM = 1; +const int32_t TILE_MULTIPLES_DIM_SIZE = 1; + +// DecodeBbox +const int32_t DECODE_BBOX_INPUT_NUM = 2; + +// GenerateRpnProposals +const int32_t GENERATE_RPN_PROPOSAL_INPUT_SIZE = 2; +const int32_t GENERATE_RPN_PROPOSAL_OUTPUT_SIZE = 3; + +// Decode_BBox +const int32_t DECODE_BBOX_INPUT_SIZE = 2; +const int32_t DEFAULT_DECODE_CLIP_VALUE = 0; + +// FastRcnnPredictions +const int32_t FASTRCNN_PREDICTIONS_INPUT_SIZE = 2; +const int32_t FASTRCNN_PREDICTIONS_OUTPUT_SIZE = 4; + +const int32_t CLIP_BOXES_INPUT_NUM = 1; +const int32_t CLIP_BOXES_WEIGHT_SIZE = 1; +const int32_t CLIP_BOXES_WEIGHT_ITEM_SIZE = 2; +const int32_t CLIP_BOXES_OUTPUT_NUM = 1; + +const int32_t FLOORDIV_INPUT_NUM = 2; +// Mean +const int32_t MEAN_WEIGHT_SIZE = 1; +const int32_t MEAN_WEIGHT_DIM_SIZE = 1; +const int32_t MEAN_WEIGHT_DIM = 2; +const int32_t MEAN_FIRST_AXIS = 2; +const int32_t MEAN_SECOND_AXIS = 3; +const int32_t MEAN_STRIDE_PLACE_HOLD = 1; +// Switch +const uint32_t SWITCH_INPUT_NUM = 2; +const uint32_t SWITCH_OUTPUT_NUM = 2; +// Merge +const uint32_t MERGE_INPUT_NUM = 2; +// Greater +const uint32_t GREATER_OUTPUT_NUM = 1; +const uint32_t GREATER_INPUT_NUM = 0; +const uint32_t GREATER_WEIGHT_NUM = 2; + +// Yolo region +const uint32_t YOLO_REGION_OUTPUT_NUM = 3; +const uint32_t YOLO_REGION_WORKSPACE_NUM = 1; +const uint32_t YOLO_REGION_COORDS = 4; +const uint32_t YOLO_REGION_CLASSES = 20; +const uint32_t YOLO_REGION_BOXES = 1; +const bool YOLO_REGION_BACKGROUND = false; +const bool YOLO_REGION_SOFTMAX = false; +const bool YOLO_REGION_SOFTMAX_TREE = false; + +// Yolo detectionoutput +const uint32_t YOLO_DETECTIONOUTPUT_INPUT_SIZE = 4; +const uint32_t YOLO_DETECTIONOUTPUT_OUTPUT_SIZE = 2; +const uint32_t YOLO_DETECTION_OUTPUT_WORKSPACE_NUM = 1; +const uint32_t YOLO_DETECTION_OUTPUT_CLASSES = 20; +const uint32_t YOLO_DETECTION_OUTPUT_BOXES_V2 = 5; +const uint32_t YOLO_DETECTION_OUTPUT_BOXES_V3 = 3; +const bool YOLO_DETECTION_OUTPUT_RELATIVE = true; +const float YOLO_DETECTION_OUTPUT_OBJECTNESS_THRESHOLD = 0.5; +const float YOLO_DETECTION_OUTPUT_CLASS_THRESHOLD = 0.5; +const uint32_t YOLO_DETECTION_OUTPUT_POST_TOP_K = UINT_MAX; +const float YOLO_DETECTION_OUTPUT_NMS_THRESHOLD = 0; +const float YOLO_DETECTION_OUTPUT_IOU_THRESHOLD_DECAY = 1.0; +const float YOLO_DETECTION_OUTPUT_COOR_SCALE_FACTOR = 1.0; + +// Reorg +const int32_t REORG_DEFAULT_STRIDE = 2; +const uint32_t REORG_INPUT_COUNT = 1; +// Reshape +const int32_t RESHAPE_INPUT_NUM = 2; +// Maximum +const int32_t MAXIMUM_INPUT_NUM = 2; + +// Spatialtf +const int32_t SPATIALTF_WORKSPACE_NUM = 1; + +const int32_t REVERSE_DEFAULT_AXIS = 1; +// Crop +const int32_t CROP_AXIS = 2; +const int32_t CROP_INPUT_NUM = 2; + +// ConvGradInput +const uint32_t CONVGRADINPUT_WORKSPACE_NUM = 1; +const uint32_t CONVGRADINPUT_INPUT_NUM = 3; + +// RNN +const uint32_t RNN_WORKSPACE_NUM = 1; + +// Cropandresize +const int32_t CROPANDRESIZE_WEIGHT_NUM = 1; +const int32_t CROPANDRESIZE_CROP_DIM_SIZE = 1; +const int32_t CROP_DIM0 = 2; + +// Attention decoder weight index +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENW0 = 0; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION0_KERNEL = 1; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTNOUTPUTPROJECTION_KERNEL = 2; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION_DECODER_KERNEL = 3; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_GATES_KERNEL = 4; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_CANDIDATE_KERNEL = 5; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_GATES_KERNEL = 6; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_CANDIDATE_KERNEL = 7; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION0_BIAS = 8; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTNOUTPUTPROJECTION_BIAS = 9; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION_DECODER_BIAS = 10; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_GATES_BIAS = 11; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_CANDIDATE_BIAS = 12; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_GATES_BIAS = 13; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_CANDIDATE_BIAS = 14; +const uint32_t ATTENTION_DECODER_WEIGHT_EMBEDDING = 15; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENVA = 16; +const uint32_t ATTENTION_DECODER_WEIGHT_DECODER_INITIAL = 17; +// Attention decoder weight size +const uint32_t ATTENTION_DECODER_WEIGHT_SIZE = 18; + +const uint32_t ATTENTION_DECODER_INPUT_SIZE = 2; +const uint32_t ATTENTION_DECODER_WORKSPACE_NUM = 1; +const uint32_t ATTENTION_DECODER_INPUT_DECODER_INPUTS = 0; +const uint32_t ATTENTION_DECODER_INPUT_DECODER_INITIAL_HIDDEN = 1; + +const int ATTENTION_DECODER_ALGO_NORMAL = 0; +const int ATTENTION_DECODER_SYMBOLS = 10000; +const int ATTENTION_DECODER_EMBEDDING_SIZE = 128; +const int ATTENTION_DECODER_ATTENTION_NUM_HIDDEN = 256; +const int ATTENTION_DECODER_DECODER_NUM_HIDDEN = 128; +const int ATTENTION_DECODER_DECODER_NUM_LAYERS = 2; +const int ATTENTION_DECODER_RNN_UNBIDIRECTIONAL = 0; +const int ATTENTION_DECODER_SEQLEN_VALUE = 57; +const int ATTENTION_DECODER_GRU = 3; + +// Logicaland +const int32_t LOGICAL_AND_INPUT_NUM = 2; +const int32_t EQUAL_INPUT_NUM = 2; + +static const int32_t OP_WEIGHT_MEM_BASE_OFFSET = 512; + +// MultiShape +const uint32_t MULTI_SHAPE_INPUT_NUM = 2; + +// Shufflechannel +const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/op_types.h b/metadef/third_party/graphengine/inc/framework/common/op_types.h new file mode 100644 index 00000000..4555d5c3 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/op_types.h @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_OP_TYPES_H_ +#define INC_FRAMEWORK_COMMON_OP_TYPES_H_ + +#include +#include + +namespace ge { +class OpTypeContainer { + public: + static OpTypeContainer *Instance() { + static OpTypeContainer instance; + return &instance; + } + ~OpTypeContainer() = default; + + void Register(const std::string &op_type) { op_type_list_.insert(op_type); } + + bool IsExisting(const std::string &op_type) { + auto iter_find = op_type_list_.find(op_type); + return iter_find != op_type_list_.end(); + } + + protected: + OpTypeContainer() {} + + private: + std::set op_type_list_; +}; + +class OpTypeRegistrar { + public: + explicit OpTypeRegistrar(const std::string &op_type) { OpTypeContainer::Instance()->Register(op_type); } + ~OpTypeRegistrar() {} +}; + +#define REGISTER_OPTYPE_DECLARE(var_name, str_name) \ + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *var_name; + +#define REGISTER_OPTYPE_DEFINE(var_name, str_name) \ + const char *var_name = str_name; \ + const OpTypeRegistrar g_##var_name##_reg(str_name); + +#define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) +} // namespace ge + +#endif // INC_FRAMEWORK_COMMON_OP_TYPES_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/scope_guard.h b/metadef/third_party/graphengine/inc/framework/common/scope_guard.h new file mode 100644 index 00000000..001a0e75 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/scope_guard.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ +#define INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ + +#include +#include + +/// Usage: +/// Acquire Resource 1 +/// MAKE_GUARD([&] { Release Resource 1 }) +/// Acquire Resource 2 +// MAKE_GUARD([&] { Release Resource 2 }) +#define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) +#define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() + +namespace ge { +class ScopeGuard { + public: + // Noncopyable + ScopeGuard(ScopeGuard const &) = delete; + ScopeGuard &operator=(ScopeGuard const &) = delete; + + explicit ScopeGuard(const std::function &on_exit_scope) : on_exit_scope_(on_exit_scope), dismissed_(false) {} + + ~ScopeGuard() { + if (!dismissed_) { + if (on_exit_scope_ != nullptr) { + try { + on_exit_scope_(); + } catch (std::bad_function_call &e) { } + catch (...) { } + } + } + } + + void Dismiss() { dismissed_ = true; } + + private: + std::function on_exit_scope_; + bool dismissed_; +}; +} // namespace ge + +#endif // INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/string_util.h b/metadef/third_party/graphengine/inc/framework/common/string_util.h new file mode 100644 index 00000000..de19807c --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/string_util.h @@ -0,0 +1,157 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_STRING_UTIL_H_ +#define INC_FRAMEWORK_COMMON_STRING_UTIL_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace ge { +class StringUtils { + public: + static std::string &Ltrim(std::string &s) { +#if __cplusplus >= 201103L + (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); +#else + (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), std::not1(std::ptr_fun(std::isspace)))); +#endif + return s; + } + // lint -esym(551,*) + static std::string &Rtrim(std::string &s) { /*lint !e618*/ +#if __cplusplus >= 201103L + (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); +#else + (void)s.erase(std::find_if(s.rbegin(), s.rend(), std::not1(std::ptr_fun(std::isspace))).base(), s.end()); +#endif + return s; + } + // lint -esym(551,*) + /// + /// @ingroup domi_common + /// @brief delete spaces at the beginning and end of a string + /// @param [in] string to be trimmed + /// @return string after trim + /// + static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); } + + /// + /// @ingroup domi_common + /// @brief string splitting + /// @param [in] str string to be trimmed + /// @param [in] delim separator + /// @return string array after segmentation + /// + static std::vector Split(const std::string &str, char delim) { + std::vector elems; + + if (str.empty()) { + elems.emplace_back(""); + return elems; + } + + std::stringstream ss(str); + std::string item; + + while (getline(ss, item, delim)) { + elems.push_back(item); + } + + auto str_size = str.size(); + if (str_size > 0 && str[str_size - 1] == delim) { + elems.emplace_back(""); + } + + return elems; + } + /// + /// @ingroup domi_common + /// @brief obtain the file name + /// @param [in] s path name + /// @return file name + /// + static std::string GetFileName(std::string &s) { + if (s.empty()) { + return ""; + } + std::vector files = StringUtils::Split(s, '/'); + + return files.empty() ? "" : files[files.size() - 1]; + } + /// + /// @ingroup domi_common + /// @brief full replacement + /// @link + /// @param [in] str str string to be replaced + /// @param [in] old_value old Characters Before Replacement + /// @param [in] new_value new Characters Before Replacement + /// @return string after replacement + /// + static std::string ReplaceAll(std::string str, const std::string &old_value, const std::string &new_value) { + std::string::size_type cur_pos = 0; + std::string::size_type old_length = old_value.length(); + std::string::size_type new_length = new_value.length(); + // cycle replace + for (; cur_pos != std::string::npos; cur_pos += new_length) { + if ((cur_pos = str.find(old_value, cur_pos)) != std::string::npos) { + (void)str.replace(cur_pos, old_length, new_value); + } else { + break; + } + } + return str; + } + + /// + /// @ingroup domi_common + /// @brief checks whether a character string starts with a character string (prefix) + /// @link + /// @param [in] str string to be compared + /// @param [in] str_x prefix + /// @return if the value is a prefix, true is returned. Otherwise, false is returned + /// + static bool StartWith(const std::string &str, const std::string str_x) { + return ((str.size() >= str_x.size()) && (str.compare(0, str_x.size(), str_x) == 0)); + } + + /// + /// @ingroup domi_common + /// @brief format string + /// @link + /// @param [in] format specifies the character string format + /// @param [in] ... format Filling Content + /// @return formatted string + /// + static std::string FormatString(const char *format, ...) { + const uint32_t MAX_BUFFER_LEN = 1024; // the stack memory plint check result must be less than 1024 + va_list args; + va_start(args, format); + char buffer[MAX_BUFFER_LEN] = {0}; + int32_t ret = vsnprintf_s(buffer, MAX_BUFFER_LEN, MAX_BUFFER_LEN - 1, format, args); + va_end(args); + return ret > 0 ? buffer : ""; + } +}; +} // namespace ge + +#endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/types.h b/metadef/third_party/graphengine/inc/framework/common/types.h new file mode 100644 index 00000000..22e85e0b --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/types.h @@ -0,0 +1,1106 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_TYPES_H_ +#define INC_FRAMEWORK_COMMON_TYPES_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "framework/common/fmk_error_codes.h" +#include "framework/common/fmk_types.h" +#include "framework/common/op_types.h" +#include "register/register_types.h" + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define DOMI_DYNAMIC_CAST static_cast +#define DOMI_DYNAMIC_POINTER_CAST std::static_pointer_cast +#else +#define DOMI_DYNAMIC_CAST static_cast +#define DOMI_DYNAMIC_POINTER_CAST std::static_pointer_cast +#endif + +namespace ge { +// dump +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODEL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_ALL_MODEL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_AICORE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ATOMIC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ALL; + +// Supported public properties name +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_LOG_PATH; // Log path + +// Profile-related constants +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t CCE_PROFILE_ON; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t CCE_PROFILE_OFF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OME_PROFILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string CCE_PROFILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string RTS_PROFILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILER_JOBCTX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILER_TARGET_PATH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string RTS_PROFILE_PATH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_STOP_KEY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_STOP_VALUE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map PROFILE_COMPONENT_MAP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_MODEL_ID; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_FUSION_MODEL_DEF; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int MODEL_MAX_SIZE; // Max size of 2 GB minus 1 byte. +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint64_t FILE_HEADER_MAX_SIZE; // Max size of 3 GB. + +#if !defined(__ANDROID__) && !defined(ANDROID) +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint64_t ALLOC_MEMORY_MAX_SIZE; // Max size of 8 GB. +#else +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint64_t ALLOC_MEMORY_MAX_SIZE; // Max size of 512M. +#endif + +template +static std::pair flip_pair(const std::pair &p) { + return std::pair(p.second, p.first); +} + +template +static std::map flip_map(std::map src) { + std::map dst; + std::transform(src.begin(), src.end(), std::inserter(dst, dst.begin()), flip_pair); + return dst; +} + +REGISTER_OPTYPE_DECLARE(DATA, "Data"); +REGISTER_OPTYPE_DECLARE(AIPPDATA, "AippData"); +REGISTER_OPTYPE_DECLARE(CONVOLUTION, "Convolution"); +REGISTER_OPTYPE_DECLARE(CORRELATION, "Correlation"); +REGISTER_OPTYPE_DECLARE(CORRELATIONV2, "Correlation_V2"); +REGISTER_OPTYPE_DECLARE(DECONVOLUTION, "Deconvolution"); +REGISTER_OPTYPE_DECLARE(POOLING, "Pooling"); +REGISTER_OPTYPE_DECLARE(ELTWISE, "Eltwise"); +REGISTER_OPTYPE_DECLARE(RELU, "ReLU"); +REGISTER_OPTYPE_DECLARE(RELU6, "ReLU6"); +REGISTER_OPTYPE_DECLARE(SIGMOID, "Sigmoid"); +REGISTER_OPTYPE_DECLARE(ABSVAL, "AbsVal"); +REGISTER_OPTYPE_DECLARE(TANH, "TanH"); +REGISTER_OPTYPE_DECLARE(PRELU, "PReLU"); +REGISTER_OPTYPE_DECLARE(BATCHNORM, "BatchNorm"); +REGISTER_OPTYPE_DECLARE(FUSIONBATCHNORM, "FusionBatchNorm"); +REGISTER_OPTYPE_DECLARE(SCALE, "Scale"); +REGISTER_OPTYPE_DECLARE(FULL_CONNECTION, "FullConnection"); +REGISTER_OPTYPE_DECLARE(SOFTMAX, "Softmax"); +REGISTER_OPTYPE_DECLARE(PLUS, "Plus"); +REGISTER_OPTYPE_DECLARE(ACTIVATION, "Activation"); +REGISTER_OPTYPE_DECLARE(FLATTEN, "Flatten"); +REGISTER_OPTYPE_DECLARE(ADD, "Add"); +REGISTER_OPTYPE_DECLARE(SUB, "Sub"); +REGISTER_OPTYPE_DECLARE(MUL, "Mul"); +REGISTER_OPTYPE_DECLARE(MATMUL, "MatMul"); +REGISTER_OPTYPE_DECLARE(RSQRT, "Rsqrt"); +REGISTER_OPTYPE_DECLARE(BIASADD, "BiasAdd"); +REGISTER_OPTYPE_DECLARE(RESHAPE, "Reshape"); +REGISTER_OPTYPE_DECLARE(REFORMAT, "ReFormat"); +REGISTER_OPTYPE_DECLARE(DEPCONVOLUTION, "ConvolutionDepthwise"); +REGISTER_OPTYPE_DECLARE(DROPOUT, "Dropout"); +REGISTER_OPTYPE_DECLARE(DROPOUTDOMASK, "DropOutDoMask"); +REGISTER_OPTYPE_DECLARE(DROPOUTGENMASK, "DropOutGenMask"); +REGISTER_OPTYPE_DECLARE(CONCAT, "Concat"); +REGISTER_OPTYPE_DECLARE(ROIPOOLING, "ROIPooling"); +REGISTER_OPTYPE_DECLARE(PROPOSAL, "Proposal"); +REGISTER_OPTYPE_DECLARE(FSRDETECTIONOUTPUT, "FSRDetectionOutput"); +REGISTER_OPTYPE_DECLARE(DETECTIONPOSTPROCESS, "Detectpostprocess"); +REGISTER_OPTYPE_DECLARE(LRN, "LRN"); +REGISTER_OPTYPE_DECLARE(TRANSDATA, "TransData"); +REGISTER_OPTYPE_DECLARE(PERMUTE, "Permute"); +REGISTER_OPTYPE_DECLARE(SSDNORMALIZE, "SSDNormalize"); +REGISTER_OPTYPE_DECLARE(SSDPRIORBOX, "SSDPriorBox"); +REGISTER_OPTYPE_DECLARE(NETOUTPUT, "NetOutput"); +REGISTER_OPTYPE_DECLARE(SSDDETECTIONOUTPUT, "SSDDetectionOutput"); +REGISTER_OPTYPE_DECLARE(REFINEDETDETECTIONOUTPUT, "RefinedetDetectionOutput"); +REGISTER_OPTYPE_DECLARE(CHANNELAXPY, "ChannelAxpy"); +REGISTER_OPTYPE_DECLARE(PSROIPOOLING, "PSROIPooling"); +REGISTER_OPTYPE_DECLARE(POWER, "Power"); +REGISTER_OPTYPE_DECLARE(POW, "Pow"); +REGISTER_OPTYPE_DECLARE(ROIALIGN, "ROIAlign"); +REGISTER_OPTYPE_DECLARE(PYTHON, "Python"); +REGISTER_OPTYPE_DECLARE(FREESPACEEXTRACT, "FreespaceExtract"); +REGISTER_OPTYPE_DECLARE(SPATIALTF, "SpatialTransform"); +REGISTER_OPTYPE_DECLARE(SHAPE, "Shape"); +REGISTER_OPTYPE_DECLARE(SHAPEN, "ShapeN"); +REGISTER_OPTYPE_DECLARE(ARGMAX, "ArgMax"); +REGISTER_OPTYPE_DECLARE(GATHERND, "GatherNd"); +REGISTER_OPTYPE_DECLARE(GATHER, "Gather"); +REGISTER_OPTYPE_DECLARE(REALDIV, "RealDiv"); +REGISTER_OPTYPE_DECLARE(PACK, "Pack"); +REGISTER_OPTYPE_DECLARE(SLICE, "Slice"); +REGISTER_OPTYPE_DECLARE(SLICED, "SliceD"); +REGISTER_OPTYPE_DECLARE(FLOORDIV, "FloorDiv"); +REGISTER_OPTYPE_DECLARE(SQUEEZE, "Squeeze"); +REGISTER_OPTYPE_DECLARE(UNSQUEEZE, "Unsqueeze"); +REGISTER_OPTYPE_DECLARE(STRIDEDSLICE, "StridedSlice"); +REGISTER_OPTYPE_DECLARE(RANGE, "Range"); +REGISTER_OPTYPE_DECLARE(RPNPROPOSALS, "GenerateRpnProposals"); +REGISTER_OPTYPE_DECLARE(DECODEBBOX, "DecodeBBox"); +REGISTER_OPTYPE_DECLARE(PAD, "Pad"); +REGISTER_OPTYPE_DECLARE(PADV2, "PadV2"); +REGISTER_OPTYPE_DECLARE(MIRRORPAD, "MirrorPad"); +REGISTER_OPTYPE_DECLARE(TILE, "Tile"); +REGISTER_OPTYPE_DECLARE(SIZE, "Size"); +REGISTER_OPTYPE_DECLARE(CLIPBOXES, "Clipboxes"); +REGISTER_OPTYPE_DECLARE(FASTRCNNPREDICTIONS, "FastrcnnPredictions"); +REGISTER_OPTYPE_DECLARE(SPLIT, "Split"); +REGISTER_OPTYPE_DECLARE(SPLITV, "SplitV"); +REGISTER_OPTYPE_DECLARE(EXPANDDIMS, "ExpandDims"); +REGISTER_OPTYPE_DECLARE(EMPTY, "Empty"); +REGISTER_OPTYPE_DECLARE(MEAN, "Mean"); +REGISTER_OPTYPE_DECLARE(GREATER, "Greater"); +REGISTER_OPTYPE_DECLARE(SWITCH, "Switch"); +REGISTER_OPTYPE_DECLARE(SWITCHN, "SwitchN"); +REGISTER_OPTYPE_DECLARE(REFSWITCH, "RefSwitch"); +REGISTER_OPTYPE_DECLARE(MERGE, "Merge"); +REGISTER_OPTYPE_DECLARE(REFMERGE, "RefMerge"); +REGISTER_OPTYPE_DECLARE(ENTER, "Enter"); +REGISTER_OPTYPE_DECLARE(REFENTER, "RefEnter"); +REGISTER_OPTYPE_DECLARE(LOOPCOND, "LoopCond"); +REGISTER_OPTYPE_DECLARE(NEXTITERATION, "NextIteration"); +REGISTER_OPTYPE_DECLARE(REFNEXTITERATION, "RefNextIteration"); +REGISTER_OPTYPE_DECLARE(EXIT, "Exit"); +REGISTER_OPTYPE_DECLARE(REFEXIT, "RefExit"); +REGISTER_OPTYPE_DECLARE(CONTROLTRIGGER, "ControlTrigger"); +REGISTER_OPTYPE_DECLARE(SYMBOLICGRADIENT, "SymbolicGradient"); +REGISTER_OPTYPE_DECLARE(REMOTECALL, "RemoteCall"); +REGISTER_OPTYPE_DECLARE(_IF, "_If"); +REGISTER_OPTYPE_DECLARE(STATELESSIF, "StatelessIf"); +REGISTER_OPTYPE_DECLARE(IF, "If"); +REGISTER_OPTYPE_DECLARE(CASE, "Case"); +REGISTER_OPTYPE_DECLARE(_WHILE, "_While"); +REGISTER_OPTYPE_DECLARE(WHILE, "While"); +REGISTER_OPTYPE_DECLARE(STATELESSWHILE, "StatelessWhile"); +REGISTER_OPTYPE_DECLARE(FOR, "For"); +REGISTER_OPTYPE_DECLARE(PARTITIONEDCALL, "PartitionedCall"); +REGISTER_OPTYPE_DECLARE(STATEFULPARTITIONEDCALL, "StatefulPartitionedCall"); +REGISTER_OPTYPE_DECLARE(FAKEPARAM, "FakeParam"); +REGISTER_OPTYPE_DECLARE(TRANSPOSE, "Transpose"); +REGISTER_OPTYPE_DECLARE(TRANSPOSED, "TransposeD"); +REGISTER_OPTYPE_DECLARE(CAST, "Cast"); +REGISTER_OPTYPE_DECLARE(REGION, "Region"); +REGISTER_OPTYPE_DECLARE(YOLO, "Yolo"); +REGISTER_OPTYPE_DECLARE(YOLODETECTIONOUTPUT, "YoloDetectionOutput"); +REGISTER_OPTYPE_DECLARE(FILL, "Fill"); +REGISTER_OPTYPE_DECLARE(RANK, "Rank"); +REGISTER_OPTYPE_DECLARE(REVERSE, "Reverse"); +REGISTER_OPTYPE_DECLARE(UNPACK, "Unpack"); +REGISTER_OPTYPE_DECLARE(YOLO2REORG, "Yolo2Reorg"); +REGISTER_OPTYPE_DECLARE(REDUCESUM, "ReduceSum"); +REGISTER_OPTYPE_DECLARE(SUM, "Sum"); +REGISTER_OPTYPE_DECLARE(CONSTANT, "Const"); +REGISTER_OPTYPE_DECLARE(RESIZEBILINEAR, "ResizeBilinear"); +REGISTER_OPTYPE_DECLARE(RESIZEBILINEARGRAD, "ResizeBilinearGrad"); +REGISTER_OPTYPE_DECLARE(MAXIMUM, "Maximum"); +REGISTER_OPTYPE_DECLARE(FRAMEWORKOP, "FrameworkOp"); +REGISTER_OPTYPE_DECLARE(ARG, "_Arg"); +REGISTER_OPTYPE_DECLARE(FUSEDBATCHNORMGRAD, "FusedBatchNormGrad"); +REGISTER_OPTYPE_DECLARE(LSTM, "LSTM"); +REGISTER_OPTYPE_DECLARE(HIGHWAY, "HighWay"); +REGISTER_OPTYPE_DECLARE(RNN, "RNN"); +REGISTER_OPTYPE_DECLARE(ATTENTIONDECODER, "AttentionDecoder"); +REGISTER_OPTYPE_DECLARE(LOGICAL_NOT, "LogicalNot"); +REGISTER_OPTYPE_DECLARE(LOGICAL_AND, "LogicalAnd"); +REGISTER_OPTYPE_DECLARE(LOGICAL_OR, "LogicalOr"); +REGISTER_OPTYPE_DECLARE(EQUAL, "Equal"); +REGISTER_OPTYPE_DECLARE(NOTEQUAL, "NotEqual"); +REGISTER_OPTYPE_DECLARE(INTERP, "Interp"); +REGISTER_OPTYPE_DECLARE(SHUFFLECHANNEL, "ShuffleChannel"); +REGISTER_OPTYPE_DECLARE(AIPP, "Aipp"); +REGISTER_OPTYPE_DECLARE(MULTISHAPE, "MultiShape"); +REGISTER_OPTYPE_DECLARE(RECIPROCAL, "Reciprocal"); +REGISTER_OPTYPE_DECLARE(SELU, "Selu"); +REGISTER_OPTYPE_DECLARE(ELU, "Elu"); +REGISTER_OPTYPE_DECLARE(ACOSH, "Acosh"); +REGISTER_OPTYPE_DECLARE(ASINH, "Asinh"); +REGISTER_OPTYPE_DECLARE(MINIMUM, "Minimum"); +REGISTER_OPTYPE_DECLARE(CLIP, "Clip"); +REGISTER_OPTYPE_DECLARE(L2NORMALIZE, "L2Normalize"); +REGISTER_OPTYPE_DECLARE(CROPANDRESIZE, "CropAndResize"); +REGISTER_OPTYPE_DECLARE(UNUSEDCONST, "UnusedConst"); +REGISTER_OPTYPE_DECLARE(SPARSETODENSE, "SparseToDense"); +REGISTER_OPTYPE_DECLARE(NONMAXSUPPRESSION, "NonMaxSuppression"); +REGISTER_OPTYPE_DECLARE(TOPKV2, "TopKV2"); +REGISTER_OPTYPE_DECLARE(INVERTPERMUTATION, "InvertPermutation"); +REGISTER_OPTYPE_DECLARE(MULTINOMIAL, "Multinomial"); +REGISTER_OPTYPE_DECLARE(REVERSESEQUENCE, "ReverseSequence"); +REGISTER_OPTYPE_DECLARE(REDUCEPROD, "ReduceProd"); +REGISTER_OPTYPE_DECLARE(REDUCEMAX, "ReduceMax"); +REGISTER_OPTYPE_DECLARE(REDUCEMIN, "ReduceMin"); +REGISTER_OPTYPE_DECLARE(EXTRACTIMAGEPATCHES, "ExtractImagePatches"); +REGISTER_OPTYPE_DECLARE(SQRT, "Sqrt"); +REGISTER_OPTYPE_DECLARE(REDUCEALL, "ReduceAll"); +REGISTER_OPTYPE_DECLARE(RESIZENEARESTNEIGHBOR, "ResizeNearestNeighbor"); +REGISTER_OPTYPE_DECLARE(SPACETOBATCHND, "SpaceToBatchND"); +REGISTER_OPTYPE_DECLARE(BATCHTOSPACEND, "BatchToSpaceND"); +REGISTER_OPTYPE_DECLARE(ASSERT, "Assert"); +REGISTER_OPTYPE_DECLARE(GREATEREQUAL, "GreaterEqual"); +REGISTER_OPTYPE_DECLARE(FLOOR, "Floor"); +REGISTER_OPTYPE_DECLARE(RANDOMUNIFORM, "RandomUniform"); +REGISTER_OPTYPE_DECLARE(BATCHMATMUL, "BatchMatMul"); +REGISTER_OPTYPE_DECLARE(LESSEQUAL, "LessEqual"); +REGISTER_OPTYPE_DECLARE(ONEHOT, "OneHot"); +REGISTER_OPTYPE_DECLARE(LAYERNORM, "LayerNorm"); +REGISTER_OPTYPE_DECLARE(SPACETODEPTH, "SpaceToDepth"); +REGISTER_OPTYPE_DECLARE(DEPTHTOSPACE, "DepthToSpace"); +REGISTER_OPTYPE_DECLARE(RINT, "Rint"); +REGISTER_OPTYPE_DECLARE(ATAN, "Atan"); +REGISTER_OPTYPE_DECLARE(ATAN2, "Atan2"); +REGISTER_OPTYPE_DECLARE(ATANH, "Atanh"); +REGISTER_OPTYPE_DECLARE(ACOS, "Acos"); +REGISTER_OPTYPE_DECLARE(ASIN, "Asin"); +REGISTER_OPTYPE_DECLARE(NEG, "Neg"); +REGISTER_OPTYPE_DECLARE(LOG, "Log"); +REGISTER_OPTYPE_DECLARE(TAN, "Tan"); +REGISTER_OPTYPE_DECLARE(ROUND, "Round"); +REGISTER_OPTYPE_DECLARE(UPSAMPLE, "Upsample"); +REGISTER_OPTYPE_DECLARE(FLOORMOD, "FloorMod"); +REGISTER_OPTYPE_DECLARE(LESS, "Less"); +REGISTER_OPTYPE_DECLARE(ZEROSLIKE, "ZerosLike"); +REGISTER_OPTYPE_DECLARE(EXP, "Exp"); +REGISTER_OPTYPE_DECLARE(WHERE, "Where"); +REGISTER_OPTYPE_DECLARE(FAKEQUANTWITHMINMAXVARS, "FakeQuantWithMinMaxVars"); +REGISTER_OPTYPE_DECLARE(SOFTPLUS, "Softplus"); +REGISTER_OPTYPE_DECLARE(SOFTSIGN, "Softsign"); +REGISTER_OPTYPE_DECLARE(COSH, "Cosh"); +REGISTER_OPTYPE_DECLARE(SINH, "Sinh"); +REGISTER_OPTYPE_DECLARE(RETINAMULTIANCHORS, "RetinaMultiAnchor"); +REGISTER_OPTYPE_DECLARE(SQUAREDDIFFERENCE, "SquaredDifference"); +REGISTER_OPTYPE_DECLARE(REQUIREDSPACETOBATCHPADDINGS, "RequiredSpaceToBatchPaddings"); // for retinanet scope fusion +REGISTER_OPTYPE_DECLARE(SSDPOSTPROCESSOR, "SSDPostProcessor"); +REGISTER_OPTYPE_DECLARE(SSDANCHORGENERATOR, "SSDAnchorGenerator"); +REGISTER_OPTYPE_DECLARE(RETINANETBOXES, "RetinanetBoxes"); +REGISTER_OPTYPE_DECLARE(RETINANETCLIPPEDBOXES, "RetinanetClippedBoxes"); +REGISTER_OPTYPE_DECLARE(RETINANETFILTEREDDETECTIONS, "RetinanetFilteredDetections"); +REGISTER_OPTYPE_DECLARE(RETINANETPOSTPROCESSOR, "RetinanetPostProcessor"); +REGISTER_OPTYPE_DECLARE(RETINANETANCHORS, "RetinanetAnchors"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNMAP, "FasterRCNNMap"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNMAP1, "FasterRCNNMap1"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNSECONDSTAGEPOSTPROCESSOR, "FasterRCNNSecondStagePostprocessor"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNROIINTERPOOLING, "FasterRCNNROIInterPooling"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNFIRSTSTAGEPOSTPROCESSOR, "FasterRCNNFirstStagePostprocessor"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNGRIDANCHORGENERATOR, "FasterRCNNGridAnchorGenerator"); +REGISTER_OPTYPE_DECLARE(ROIINTERPOOLING, "ROIInterPooling"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNCLIPTOWINDOW, "FasterRCNNClipToWindow"); +REGISTER_OPTYPE_DECLARE(EMBEDLOOKUP, "EmbedLookup"); +REGISTER_OPTYPE_DECLARE(HASHLOOKUP, "HashLookup"); +REGISTER_OPTYPE_DECLARE(LSH_PROJ, "LshProject"); +REGISTER_OPTYPE_DECLARE(SVDF, "SVDF"); +REGISTER_OPTYPE_DECLARE(IDENTITY, "Identity"); +REGISTER_OPTYPE_DECLARE(PLACEHOLDERWITHDEFAULT, "PlaceholderWithDefault"); +REGISTER_OPTYPE_DECLARE(IDENTITYN, "IdentityN"); +REGISTER_OPTYPE_DECLARE(GETSPAN, "GetSpan"); +REGISTER_OPTYPE_DECLARE(STOPGRADIENT, "StopGradient"); +REGISTER_OPTYPE_DECLARE(PREVENTGRADIENT, "PreventGradient"); +REGISTER_OPTYPE_DECLARE(GUARANTEECONST, "GuaranteeConst"); +REGISTER_OPTYPE_DECLARE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs"); +REGISTER_OPTYPE_DECLARE(BROADCASTARGS, "BroadcastArgs"); +REGISTER_OPTYPE_DECLARE(CONCATV2, "ConcatV2"); +REGISTER_OPTYPE_DECLARE(CONCATOFFSET, "ConcatOffset"); +REGISTER_OPTYPE_DECLARE(LESSEQUAL, "LessEqual"); +REGISTER_OPTYPE_DECLARE(SELECT, "Select"); +REGISTER_OPTYPE_DECLARE(CONFUSIONMATRIX, "ConfusionMatrix"); +REGISTER_OPTYPE_DECLARE(PLACEHOLDER, "PlaceHolder"); +REGISTER_OPTYPE_DECLARE(END, "End"); +REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); +REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); +REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); +REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape") +REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity"); +REGISTER_OPTYPE_DECLARE(BITCAST, "Bitcast"); + +// ANN dedicated operator +REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); +REGISTER_OPTYPE_DECLARE(ANN_CONVOLUTION, "AnnConvolution"); +REGISTER_OPTYPE_DECLARE(ANN_DEPCONVOLUTION, "AnnDepthConv"); +REGISTER_OPTYPE_DECLARE(ANN_FULLCONNECTION, "AnnFullConnection"); +REGISTER_OPTYPE_DECLARE(ANN_NETOUTPUT, "AnnNetOutput"); +REGISTER_OPTYPE_DECLARE(ANN_DATA, "AnnData"); +REGISTER_OPTYPE_DECLARE(ANN_RESHAPE, "AnnReshape"); +REGISTER_OPTYPE_DECLARE(ANN_ADD, "AnnAdd"); +REGISTER_OPTYPE_DECLARE(ANN_MUL, "AnnMul"); +REGISTER_OPTYPE_DECLARE(ANN_SUB, "AnnSub"); +REGISTER_OPTYPE_DECLARE(ANN_DIV, "AnnDiv"); +REGISTER_OPTYPE_DECLARE(ANN_DEQUANTIZE, "AnnDequant"); +REGISTER_OPTYPE_DECLARE(ANN_QUANTIZE, "AnnQuant"); +REGISTER_OPTYPE_DECLARE(ANN_PAD, "AnnPad"); +REGISTER_OPTYPE_DECLARE(ANN_RESIZE_BILINEAR, "AnnResizeBilinear"); + +// Training operator +REGISTER_OPTYPE_DECLARE(GATHERV2, "GatherV2"); +REGISTER_OPTYPE_DECLARE(CONVGRADFILTER, "Conv2DBackpropFilter"); +REGISTER_OPTYPE_DECLARE(CONV2D, "Conv2D"); +REGISTER_OPTYPE_DECLARE(CONV2DBACKPROPINPUT, "Conv2DBackpropInput"); +REGISTER_OPTYPE_DECLARE(FUSEDBATCHNORM, "FusedBatchNorm"); +REGISTER_OPTYPE_DECLARE(BIASADDGRAD, "BiasAddGrad"); +REGISTER_OPTYPE_DECLARE(ACTIVATIONGRAD, "ReluGrad"); +REGISTER_OPTYPE_DECLARE(MAXPOOLWITHARGMAX, "MaxPoolWithArgmax"); +REGISTER_OPTYPE_DECLARE(MAXPOOLGRADWITHARGMAX, "MaxPoolGradWithArgmax"); +REGISTER_OPTYPE_DECLARE(SPARSESOFTMAXCROSSENTROPYWITHLOGITS, "SparseSoftmaxCrossEntropyWithLogits"); +REGISTER_OPTYPE_DECLARE(SNAPSHOT, "Snapshot"); +REGISTER_OPTYPE_DECLARE(LAYERNORM, "LayerNorm"); +REGISTER_OPTYPE_DECLARE(HUBERLOSSGRAD, "HuberLossGrad"); +REGISTER_OPTYPE_DECLARE(HUBERLOSS, "HuberLoss"); +REGISTER_OPTYPE_DECLARE(NEGATIVE, "Negative"); +REGISTER_OPTYPE_DECLARE(SSDCAST, "SSDCast"); +REGISTER_OPTYPE_DECLARE(SSDSQUEEZEFUSION, "SsdSqueezeFusion"); +REGISTER_OPTYPE_DECLARE(SPARSESOFTMAXCROSSENTROPY, "SsdSparseSoftmaxCrossEntropy"); +REGISTER_OPTYPE_DECLARE(SPARSESOFTMAXCROSSENTROPYGRAD, "SsdSparseSoftmaxCrossEntropyGrad"); +REGISTER_OPTYPE_DECLARE(CONCATFIVE2FOUR, "ConcatFive2Four"); +REGISTER_OPTYPE_DECLARE(CONCATFOUR2FIVE, "ConcatFour2Five"); +REGISTER_OPTYPE_DECLARE(SSDREALDIVTILEMUL, "SSDRealdivTileMul"); +REGISTER_OPTYPE_DECLARE(SSDSUMMULREALDIVMEAN, "SSDSumMulRealdivMean"); + +REGISTER_OPTYPE_DECLARE(MEANGRAD, "MeanGrad"); +REGISTER_OPTYPE_DECLARE(TRANSLATE, "Translate"); +REGISTER_OPTYPE_DECLARE(ADDN, "AddN"); +REGISTER_OPTYPE_DECLARE(L2LOSS, "L2Loss"); +REGISTER_OPTYPE_DECLARE(MULTIPLY, "Multiply"); +REGISTER_OPTYPE_DECLARE(RELU6GRAD, "Relu6Grad"); +REGISTER_OPTYPE_DECLARE(AVGPOOLGRAD, "AvgPoolGrad"); +REGISTER_OPTYPE_DECLARE(DEPTHWISECONV2DBACKPROPFILTER, "DepthwiseConv2dNativeBackpropFilter"); +REGISTER_OPTYPE_DECLARE(DEPTHWISECONV2DBACKPORPINPUT, "DepthwiseConv2dNativeBackpropInput"); +REGISTER_OPTYPE_DECLARE(DEPTHWISECONV2DFORWARDNATIVE, "DepthwiseConv2dNative"); +REGISTER_OPTYPE_DECLARE(DROPOUTGRAD, "DropOutGrad"); +REGISTER_OPTYPE_DECLARE(APPLYRMSPROPMIXEDPRECISION, "apply_rms_prop_mixed_precision"); +REGISTER_OPTYPE_DECLARE(APPLYRMSPROP, "ApplyRMSProp"); +REGISTER_OPTYPE_DECLARE(LARS, "Lars"); +REGISTER_OPTYPE_DECLARE(DYNAMICSTITCH, "DynamicStitch"); + +// Variable sink related +REGISTER_OPTYPE_DECLARE(VARIABLEV2, "VariableV2"); +REGISTER_OPTYPE_DECLARE(VARHANDLEOP, "VarHandleOp"); +REGISTER_OPTYPE_DECLARE(TEMPORARYVARIABLE, "TemporaryVariable"); +REGISTER_OPTYPE_DECLARE(DESTROYTEMPORARYVARIABLE, "DestroyTemporaryVariable"); +REGISTER_OPTYPE_DECLARE(VARIABLE, "Variable"); + +REGISTER_OPTYPE_DECLARE(READVARIABLEOP, "ReadVariableOp"); + +REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); +REGISTER_OPTYPE_DECLARE(ISVARIABLEINITIALIZED, "IsVariableInitialized"); + +REGISTER_OPTYPE_DECLARE(ASSIGN, "Assign"); +REGISTER_OPTYPE_DECLARE(ASSIGNVARIABLEOP, "AssignVariableOp"); + +REGISTER_OPTYPE_DECLARE(ASSIGNADD, "AssignAdd"); +REGISTER_OPTYPE_DECLARE(ASSIGNADDVARIABLEOP, "AssignAddVariableOp"); + +REGISTER_OPTYPE_DECLARE(ASSIGNSUB, "AssignSub"); +REGISTER_OPTYPE_DECLARE(ASSIGNSUBVARIABLEOP, "AssignSubVariableOp"); + +REGISTER_OPTYPE_DECLARE(APPLYMOMENTUM, "ApplyMomentum"); +REGISTER_OPTYPE_DECLARE(RESOURCEAPPLYMOMENTUM, "ResourceApplyMomentum"); +REGISTER_OPTYPE_DECLARE(SGD, "SGD"); +REGISTER_OPTYPE_DECLARE(NOOP, "NoOp"); +REGISTER_OPTYPE_DECLARE(LAYERNORMGRAD, "LayerNormGrad"); + +REGISTER_OPTYPE_DECLARE(SQUARE, "Square"); +REGISTER_OPTYPE_DECLARE(HCOMBROADCAST, "HcomBroadcast"); +REGISTER_OPTYPE_DECLARE(HCOMALLGATHER, "HcomAllGather"); +REGISTER_OPTYPE_DECLARE(HCOMALLREDUCE, "HcomAllReduce"); +REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTER, "HcomReduceScatter"); +REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend"); +REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); +REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); +REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); +REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); + +REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); +REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); +REGISTER_OPTYPE_DECLARE(LogTimeStamp, "LogTimeStamp"); +REGISTER_OPTYPE_DECLARE(PARALLELCONCATSTART, "_ParallelConcatStart"); +REGISTER_OPTYPE_DECLARE(CONSTANTOP, "Constant"); +REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch"); +REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); +REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); +REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); +REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); +REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); +REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); +REGISTER_OPTYPE_DECLARE(MODELEXIT, "ModelExit"); +REGISTER_OPTYPE_DECLARE(SEND, "Send"); +REGISTER_OPTYPE_DECLARE(RECV, "Recv"); +REGISTER_OPTYPE_DECLARE(ENDOFSEQUENCE, "EndOfSequence"); + +REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); +REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); +REGISTER_OPTYPE_DECLARE(LABELGOTOEX, "LabelGotoEx"); +REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); +REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); + +REGISTER_OPTYPE_DECLARE(ATOMICADDRCLEAN, "AtomicAddrClean"); + +REGISTER_OPTYPE_DECLARE(ABS_GRAD, "AbsGrad"); +REGISTER_OPTYPE_DECLARE(ACCUMULATE_N_V2, "AccumulateNV2"); +REGISTER_OPTYPE_DECLARE(ACOS_GRAD, "AcosGrad"); +REGISTER_OPTYPE_DECLARE(ACOSH_GRAD, "AcoshGrad"); +REGISTER_OPTYPE_DECLARE(ANY, "Any"); +REGISTER_OPTYPE_DECLARE(APPROXIMATE_EQUAL, "ApproximateEqual"); +REGISTER_OPTYPE_DECLARE(ASIN_GRAD, "AsinGrad"); +REGISTER_OPTYPE_DECLARE(ASINH_GRAD, "AsinhGrad"); +REGISTER_OPTYPE_DECLARE(ATAN_GRAD, "AtanGrad"); +REGISTER_OPTYPE_DECLARE(BROADCAST_TO, "BroadcastTo"); +REGISTER_OPTYPE_DECLARE(ELU_GRAD, "EluGrad"); +REGISTER_OPTYPE_DECLARE(ADD_V2, "AddV2"); +REGISTER_OPTYPE_DECLARE(DATAFORMATDIMMAP, "DataFormatDimMap"); +REGISTER_OPTYPE_DECLARE(DATAFORMATVECPERMUTE, "DataFormatVecPermute"); +REGISTER_OPTYPE_DECLARE(BESSELI0e, "BesselI0e"); +REGISTER_OPTYPE_DECLARE(BESSELI1e, "BesselI1e"); +REGISTER_OPTYPE_DECLARE(DEQUANTIZE, "Dequantize"); +REGISTER_OPTYPE_DECLARE(APPLYADADELTA, "ApplyAdadelta"); +REGISTER_OPTYPE_DECLARE(APPLYADAGRAD, "ApplyAdagrad"); +REGISTER_OPTYPE_DECLARE(APPLYADAGRADDA, "ApplyAdagradDA"); +REGISTER_OPTYPE_DECLARE(APPLYADAM, "ApplyAdam"); +REGISTER_OPTYPE_DECLARE(APPLYADAMAX, "ApplyAdaMax"); +REGISTER_OPTYPE_DECLARE(APPLYADDSIGN, "ApplyAddSign"); +REGISTER_OPTYPE_DECLARE(APPLYCENTEREDRMSPROP, "ApplyCenteredRMSProp"); +REGISTER_OPTYPE_DECLARE(APPLYFTRL, "ApplyFtrl"); +REGISTER_OPTYPE_DECLARE(APPLYFTRLV2, "ApplyFtrlv2"); +REGISTER_OPTYPE_DECLARE(APPLYGRADIENTDESCENT, "ApplyGradientDescent"); +REGISTER_OPTYPE_DECLARE(APPLYPOWERSIGN, "ApplyPowerSign"); +REGISTER_OPTYPE_DECLARE(APPLYPROXIMALADAGRAD, "ApplyProximalAdagrad"); +REGISTER_OPTYPE_DECLARE(APPLYPROXIMALGRADIENTDESCENT, "ApplyProximalGradientDescent"); + +REGISTER_OPTYPE_DECLARE(FOCAL_LOSS, "FocalLoss"); +REGISTER_OPTYPE_DECLARE(FOCAL_LOSS_GRAD, "FocalLossGrad"); +REGISTER_OPTYPE_DECLARE(SMOOTHL1_LOSS, "SmoothL1Loss"); +REGISTER_OPTYPE_DECLARE(SMOOTHL1_LOSS_grad, "SmoothL1LossGrad"); +REGISTER_OPTYPE_DECLARE(REDUCEMEAN, "ReduceMean"); +REGISTER_OPTYPE_DECLARE(CONCAT_V2, "ConcatV2"); +REGISTER_OPTYPE_DECLARE(ONEHOT_V2, "OneHotV2"); +REGISTER_OPTYPE_DECLARE(SLICE_V2, "SliceV2"); +REGISTER_OPTYPE_DECLARE(TILE_V2, "TileV2"); +REGISTER_OPTYPE_DECLARE(SUM_V2, "SumV2"); +// Common operator type when operators have the same name +REGISTER_OPTYPE_DECLARE(DETECTIONOUTPUT, "DetectionOutput"); + +// custom operator +REGISTER_OPTYPE_DECLARE(CUSTOMOP, "CustomOp"); +REGISTER_OPTYPE_DECLARE(CUSTOMOP_NCHW, "CustomOpNchw"); +REGISTER_OPTYPE_DECLARE(CUSTOMOP_NHWC, "CustomOpNhwc"); +REGISTER_OPTYPE_DECLARE(CUSTOMOP_NC1HWC0, "CustomOpNc1hwc0"); + +// Depthwise 4d_2_6d,6d_2_4d +REGISTER_OPTYPE_DECLARE(DEPTHWISEWEIGHT4D26D, "depthwise_weight_4d_2_6d"); +REGISTER_OPTYPE_DECLARE(DEPTHWISEWEIGHT6D24D, "depthwise_weight_6d_2_4d"); + +REGISTER_OPTYPE_DECLARE(SQRTGRAD, "SqrtGrad"); +REGISTER_OPTYPE_DECLARE(SIGMOIDGRAD, "SigmoidGrad"); + +// Horovod operator +REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLREDUCE, "HorovodAllreduce"); +REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLGATHER, "HorovodAllgather"); +REGISTER_OPTYPE_DECLARE(HVDCALLBACKBROADCAST, "HorovodBroadcast"); +REGISTER_OPTYPE_DECLARE(HVDWAIT, "HorovodWait"); + +// aicpu op for online_infer dynamic_dims +REGISTER_OPTYPE_DECLARE(GETDYNAMICDIMS, "GetDynamicDims"); + +enum InputMode { INPUT = 0, CONST_INPUT}; + +// Definition of the processing status enum of the process module +enum ModelProcessState { + INIT_STATE = 0, // init status + WAIT_EVENT_STATE, // Wait for the event status + IND_RSLT_STATE, // The model execution result is being output to the high level + STOPPED_STATE, // Model execution completed. The model enters this state after Model Manager::Stop + RESERVED_STATE, // reserved +}; + +// Indicates the enun definition of the execution mode of the access module +enum SysMode { + INFERENCE = 0, // Normal, that is, Inference mode + DEBUG, // Debug mode + TIME, // Model execution time mode, including the execution time of each OP + STOP, // STOP mode + RESET, // RESET mode + PERFORMANCE, // Impact of enabling the performance model: 1. The input data of the model is considered ready and does + // not need to be converted + ANDROID_DEBUG, // Exports Android platform computing data + RESERVED, // reserved +}; + +// @brief encryption type of the model file +enum ModelEncryptType { + UNENCRYPTED, // not encrypted + ENCRYPTED // encrypted +}; + +/// +/// @brief signature verification +/// +enum ModelCheckType { + CHECK, // signature verification + UNCHECK // no verification +}; + +/// +/// @brief dynamic input type +/// +enum DynamicInputType { + FIXED = 0, // default mode + DYNAMIC_BATCH = 1, + DYNAMIC_IMAGE = 2, + DYNAMIC_DIMS = 3 +}; + +/// +/// @brief magic number of the model file +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_FILE_MAGIC_NUM; + +/// +/// @brief model header length +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_FILE_HEAD_LEN; + +/// +/// @brief model name length +/// +static constexpr uint32_t MODEL_NAME_LENGTH = 32; + +/// +/// @brief length of user-defined information +/// +static constexpr uint32_t USER_DEFINE_INFO_LENGTH = 32; + +/// +/// @brief length of the model file signature +/// +static constexpr uint32_t MODEL_FILE_CHECKSUM_LENGTH = 64; + +/// +/// @brief length of the reserved field in the model file header +/// +static constexpr uint32_t MODEL_FILE_RESERVED_LENGTH = 79; + +/// +/// @ingroup domi_omg +/// @brief INPUT node type +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string INPUT_TYPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMMY_DATA; + +/// +/// @ingroup domi_omg +/// @brief AIPP flag, indicating the aipp conv operator +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string AIPP_CONV_FLAG; + +/// +/// @ingroup domi_omg +/// @brief AIPP flag, indicating the aipp data operator +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string AIPP_DATA_FLAG; + +// flag of the Data operator, indicating that the input will be input to the dynamic AIPP operator +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string INPUT_TO_DYNAMIC_AIPP; + +// records the W dimension of the model input corresponding to the dynamic AIPP +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string AIPP_RELATED_DATA_DIM_W; + +// H dimension of the model input corresponding to the dynamic AIPP +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string AIPP_RELATED_DATA_DIM_H; + +// DATA node type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DATA_TYPE; + +// DATA Operator Type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string AIPP_DATA_TYPE; + +// framework Operator Type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string FRAMEWORK_OP_TYPE; + +// DATA node type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ANN_DATA_TYPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ANN_NETOUTPUT_TYPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ANN_DEPTHCONV_TYPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ANN_CONV_TYPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ANN_FC_TYPE; +// convolution node type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_NET_OUTPUT; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_END_GRAPH; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_OP_DEBUG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_OP_DEBUG; + +// convolution node type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_CONVOLUTION; +// adds a convolutional node name for the hard AIPP +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string AIPP_CONV_OP_NAME; +// delimiter of operator configuration items +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_CONF_DELIMITER; + +// op attr name +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ATTR_NAME_VALUE1; + +// op attr name, used to 6d_2_4d C channel +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ATTR_NAME_INPUT_CVALUE; + +// op attr name +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ATTR_NAME_VALUE1; + +// alpha default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const float ALPHA_DEFAULT_VALUE; + +// beta default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const float BETA_DEFAULT_VALUE; + +// coef default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const float COEF_DEFAULT_VALUE; + +// coef value of Relu6 +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const float RELU6_COEF; + +// stride default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t STRIDE_DEFAULT_VALUE; + +// pad default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t PAD_DEFAULT_VALUE; + +// dilation default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int DILATION_DEFAULT_VALUE; + +// kernel default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t KERNEL_DEFAULT_VALUE; + +// default conv Group Size +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t DEFAULT_CONV_GROUP; + +// default deconv adj +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t DEFAULT_DECONV_ADJ; + +// indicate num 1 +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NUM_ONE; + +// dim default size value +static const int32_t DIM_DEFAULT_SIZE = 4; + +// the shape of c must be the mutiply of 16 for depthwise +static const uint32_t DEPTHWISE_DIM_C_BASE_NUM = 16; + +// C1HWNCoC0 dim size +static const int32_t DIM_C1HWNCoC0_SIZE = 6; +// C1HWNCoC0 C0 value +static const int C1HWCOC_C0_VALUE = 16; +// spatial default dim size +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int32_t SPATIAL_DIM_DEFAULT_SIZE; + +// dim extension default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int32_t DIM_DEFAULT_VALUE; + +// the first item in the weight list of opdef is filter +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int32_t WEIGHT_FILTER_INDEX; + +// the second item in the weight list of opdef is bias. +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int32_t WEIGHT_BIAS_INDEX; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int32_t TENSOR_ND_SUPPORT_SIZE; + +// default NCHW index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_N; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_C; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_W; + +// default C1HWNCoC0 index +static const uint32_t C1HWNCoC0_DIM_C1 = 0; +static const uint32_t C1HWNCoC0_DIM_H = 1; +static const uint32_t C1HWNCoC0_DIM_W = 2; +static const uint32_t C1HWNCoC0_DIM_N = 3; +static const uint32_t C1HWNCoC0_DIM_Co = 4; +static const uint32_t C1HWNCoC0_DIM_C0 = 5; + +// default KCHW index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t KCHW_DIM_K; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t KCHW_DIM_C; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t KCHW_DIM_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t KCHW_DIM_W; + +// default HWCK index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t HWCK_DIM_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t HWCK_DIM_W; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t HWCK_DIM_C; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t HWCK_DIM_K; + +// default NHWC index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_N; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_W; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_C; + +// default CHWN index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t CHWN_DIM_N; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t CHWN_DIM_C; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t CHWN_DIM_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t CHWN_DIM_W; + +// default CHW index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t CHW_DIM_C; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t CHW_DIM_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t CHW_DIM_W; + +// default HWC index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t HWC_DIM_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t HWC_DIM_W; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t HWC_DIM_C; +// default Pad index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t PAD_H_HEAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t PAD_H_TAIL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t PAD_W_HEAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t PAD_W_TAIL; + +// default window index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t WINDOW_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t WINDOW_W; + +// default stride index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t STRIDE_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t STRIDE_W; + +// default dilation index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t DILATION_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t DILATION_W; + +// the num of XRBG channel +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t XRGB_CHN_NUM; + +// default tensor format +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int DEFAULT_FORMAT; + +// default global pooling +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const bool DEFAULT_GLOBAL_POOLING; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_VERSION; // model version 1.0 + +// Number of inputs of the Eltwise operator +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int ELTWISE_MIN_INPUT_SIZE; + +// flowctrl +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_STREAM_SWITCH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_STREAM_ACTIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_PER_ITER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_COND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_INCREMENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_RESETVALUE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_ASSIGNADD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_ASSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_ATOMIC_ADDR_CLEAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t TRUE_STREAM_ID; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t STREAM_SWITCH_INPUT_NUM; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_GLOBAL_STEP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_GLOBAL_STEP_ASSIGNADD; + +static const int PLATFORM_VERSION_LEN = 20; + +// Definition of the file header of the model file +struct ModelFileHeader { + uint32_t magic = MODEL_FILE_MAGIC_NUM; // magic number of DOMI + uint32_t headsize = MODEL_FILE_HEAD_LEN; // length of the model header. The value is fixed at 256 + uint32_t version = MODEL_VERSION; // version 1.0 + uint8_t checksum[MODEL_FILE_CHECKSUM_LENGTH] = {0}; // signature + uint32_t length = 0; // Ciphertext length. In the non-encryption model, the length is the plaintext length. + uint8_t is_encrypt = ModelEncryptType::UNENCRYPTED; // whether encrypted 0:not encrypt, 1:encrypt + uint8_t is_checksum = ModelCheckType::CHECK; // whether to check the checksum + uint8_t modeltype = 0; // 0:IR model 1:standard model 2: OM Tiny model + uint8_t genmode = 0; // 0:offline generate 1:online generate + uint8_t name[MODEL_NAME_LENGTH] = {0}; // Model name, which contains 32 characters + uint32_t ops = 0; // Computing power (Kops) + uint8_t userdefineinfo[USER_DEFINE_INFO_LENGTH] = {0}; // User-defined information. The value contains 32 characters + uint32_t om_ir_version = 0; + uint8_t platform_version[PLATFORM_VERSION_LEN] = {0}; + uint8_t platform_type = {0}; + uint8_t reserved[MODEL_FILE_RESERVED_LENGTH] = {0}; // Reserved field 79 +}; + +static constexpr uint8_t TARGET_TYPE_LTTE_8BIT = 0; +static constexpr uint8_t TARGET_TYPE_MINI_8BIT = 1; +static constexpr uint8_t TARGET_TYPE_TINY_8BIT = 2; + +static constexpr int32_t PARTITION_TYPE_MODEL_DEF = 0; +static constexpr int32_t PARTITION_TYPE_WEIGHTS = 1; +static constexpr int32_t PARTITION_TYPE_TASK_INFO = 2; + +// number of partitions in the current model +static constexpr uint32_t PARTITION_SIZE = 5; + +enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS, CUST_AICPU_KERNELS }; + +struct ModelPartitionMemInfo { + ModelPartitionType type; + uint32_t mem_offset; + uint32_t mem_size; +}; + +struct ModelPartitionTable { + uint32_t num; + ModelPartitionMemInfo partition[0]; +}; + +#define SIZE_OF_MODEL_PARTITION_TABLE(table) (sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * (table).num) + +static constexpr int32_t PTHREAD_CREAT_SUCCESS = 0; // pthread_creat success + +// Filter format +typedef enum tagDomiFilterFormat { + DOMI_FILTER_KCHW, // KCHW + DOMI_FILTER_HWCK, // HWCK + DOMI_FILTER_RESERVED +} domiFilterFormat_t; + +// Const data trans type +typedef enum tagDomiConstDataTransType { + DOMI_CONST_DATA_NOT_CHANGE = 0, // No action is required + DOMI_CONST_DATA_TRANS_MATMUL, // The const input to MatMul and needs to be transposed + DOMI_CONST_DATA_RESERVED +} domiConstDataTransType_t; + +// mode of activation +typedef enum tagDomiActivationMode { + DOMI_ACTIVATION_SIGMOID = 0, // sigmoid + DOMI_ACTIVATION_RELU, // ReLU + DOMI_ACTIVATION_TANH, // tanh + DOMI_ACTIVATION_CLIPPED_RELU, // clipped ReLU + DOMI_ACTIVATION_ELU, // ELU + DOMI_ACTIVATION_LEAKY_RELU, + DOMI_ACTIVATION_ABS, // Abs + DOMI_ACTIVATION_RELU1, // relu1 + DOMI_ACTIVATION_SOFTSIGN, // softsign + DOMI_ACTIVATION_SOFTPLUS, // softplus + DOMI_ACTIVATION_HARDSIGMOID, // hardsigmoid + DOMI_ACTIVATION_THRESHOLD_RELU, // threshold + DOMI_ACTIVATION_SELU, // selu + DOMI_ACTIVATION_LINEAR, // linear + DOMI_ACTIVATION_RESERVED +} domiActivationMode_t; + +// mode of batchnorm +typedef enum tagDomiBatchNormMode { + DOMI_BATCHNORM_PER_ACTIVATION = 0, // bnScale, bnBias tensor dims are 1xCxHxW + DOMI_BATCHNORM_SPATIAL, // bnScale, bnBias tensor dims are 1xCx1x1 + DOMI_BATCHNORM_RESERVED +} domiBatchNormMode_t; + +// eltwise mode +typedef enum tagDomiEltwiseMode { + DOMI_ELTWISE_PROD = 0, // prod + DOMI_ELTWISE_SUM, // sum + DOMI_ELTWISE_MAX, // max + DOMI_ELTWISE_RESERVED +} domiEltwiseMode_t; + +// mode of padding +typedef enum tagDomiPaddingMode { + DOMI_PADDING_CEIL = 0, // Default padding mode + DOMI_PADDING_DIRECTASSIGN, // Default padding mode: NOTSET + DOMI_PADDING_VALID, // VALID padding mode + DOMI_PADDING_SAME, // Padding values of 0 are always used + DOMI_PADDING_CEIL_NEW, // Padding values of 0 are always used + DOMI_PADDING_VALID_NEW, // Padding values of 0 are always used + DOMI_PADDING_SAME_NEW, // Padding values of 0 are always used + DOMI_PADDING_RESERVED +} domiPaddingMode_t; + +// algorithm of convolution forward +typedef enum tagDomiConvolutionFwdAlgo { + DOMI_CONVOLUTION_FWD_ALGO_GEMM = 0, // matrix gemm algo + DOMI_CONVOLUTION_FWD_ALGO_WINOGRAD, // Winograd Transform algo + DOMI_CONVOLUTION_FWD_ALGO_GEMM_ACCU_FLOAT32, // accumulate in L0c with FP32 + DOMI_CONVOLUTION_FWD_ALGO_RESERVED +} domiConvolutionFwdAlgo_t; + +typedef enum tagDomiFullConnectFwdAlgo { + DOMI_FULLCONNECT_FWD_ALGO_HALF = 0, // accumulate in L0c with FP16 + DOMI_FULLCONNECT_FWD_ALGO_FLOAT32 // accumulate in L0c with FP32 +} domiFullConnectFwdAlgo_t; + +typedef enum tagDomiPooingFwdAlgo { + DOMI_POOLING_FWD_ALGO_HALF = 0, // accumulate in L0c with FP16 + DOMI_POOLING_FWD_ALGO_FLOAT32 // accumulate in L0c with FP32 +} domiPooingFwdAlgo_t; + +// mode of convolution +typedef enum tagDomiConvolutionMode { + DOMI_CONV_CONVOLUTION = 0, // math convolution + DOMI_CONV_CROSS_CORRELATION, // cross-correlation convolution + DOMI_CONV_DECONVOLUTION, // deconvolution, also named transposed convolution + DOMI_CONV_MODE_DEPTHWISE, // depthwise convolution + DOMI_CONV_MODE_RESERVED +} domiConvolutionMode_t; + +// softmax mode +typedef enum tagDomiSoftmaxMode { + DOMI_SOFTMAX_MODE_INSTANCE = 0, // compute the softmax over all C, H, W for each N + DOMI_SOFTMAX_MODE_CHANNEL, // compute the softmax over all C for each H, W, N + DOMI_SOFTMAX_MODE_HEIGHT, // compute the softmax over all H for each N, C, W + DOMI_SOFTMAX_MODE_WIDTH, // compute the softmax over all W for each N, C, H + DOMI_SOFTMAX_MODE_RESERVED +} domiSoftmaxMode_t; + +// softmax algorithm +typedef enum tagDomiSoftmaxAlgo { + DOMI_SOFTMAX_FAST = 0, // straightforward implementation + DOMI_SOFTMAX_ACCURATE, // subtract max from every point to avoid overflow + DOMI_SOFTMAX_LOG, // perform the Log softmax operation to avoid overflow + DOMI_SOFTMAX_ACCURATE_FP32, + DOMI_SOFTMAX_RESERVED +} domiSoftmaxAlgo_t; + +// algorithm of convolution backward +typedef enum tagDomiConvolutionBwdAlgo { + DOMI_CONVOLUTION_BWD_ALGO_GEMM = 0, // matrix gemm algo + DOMI_CONVOLUTION_BWD_ALGO_WINOGRAD, // Winograd Transform algo + DOMI_CONVOLUTION_BWD_ALGO_RESERVED +} domiConvolutionBwdAlgo_t; + +// mode of pooling +typedef enum tagDomiPoolingMode { + DOMI_POOLING_MAX = 0, // max pooling + DOMI_POOLING_AVG, // average pooling + DOMI_POOLING_L2, // L2 pooling + DOMI_POOLING_RESERVED +} domiPoolingMode_t; + +// propagate Nan +typedef enum tagDomiNanPropagation { + DOMI_NAN_NOT_PROPAGATE = 0, // Nan numbers are not propagated + DOMI_NAN_PROPAGATE, // Nan numbers are propagated + DOMI_NAN_PROPAGATE_RESERVED +} domiNanPropagation_t; + +// mode of cropandresize +typedef enum tagDomiCropAndResizeMode { + DOMI_RESIZE_METHOD_BILINEAR = 0, // resize bilinear + DOMI_RESIZE_METHOD_NEAREST, // resize nearest + DOMI_RESIZE_RESERVED +} domiCropAndResizeMode_t; + +// yolo version +typedef enum tagDomiYoloVersion { DOMI_YOLO_V2 = 1, DOMI_YOLO_V3, DOMI_YOLO_TRSERVED } domiYoloVersion_t; + +typedef enum tagDomiRNNScopePassType { + DOMI_STATIC_BIDIRECTIONAL_RNN_GENERAL_PASS = 0, + DOMI_DYNAMIC_BIDIRECTIONAL_RNN_GENERAL_PASS, + DOMI_DYNAMIC_BIDIRECTIONAL_RNN_BIDAF_PASS +} domiRNNScopePassType; + +// RNNDataLayout +typedef enum tagDomiRNNDataLayout { + DOMI_RNN_ND_TBX = 0, // data[max_time,batch_size,Xt] + DOMI_RNN_ND_BTX, // data[batch_size,max_time,Xt] + DOMI_RNN_5D_TX1BX, // data[max_time,Xt,1,batch_size,Xt] + DOMI_RNN_5D_BX1TX, // dataa[batch_size,Xt,1,max_time,Xt] + DOMI_RNN_4DTBX1, + DOMI_ENN_DL_RESERVED +} domiRNNDataLayout_t; + +// RNNInputMode +typedef enum tagDomiRNNInputMode { DOMI_RNN_LINEAR_INPUT = 0, DOMI_RNN_SKIP_INPUT } domiRNNInputMode_t; + +// RNNDirectionMode +typedef enum tagDomiRNNDirectionMode { DOMI_RNN_UNIDIRECTIONAL = 0, DOMI_RNN_BIDIRECTIONAL } domiDirectionMode_t; + +typedef enum tagDomiPoolingCeilMode { DOMI_POOLING_FLOOR = 0, DOMI_POOLING_CEIL } domiPoolingCeilMode_t; + +// RNNMode +typedef enum tagDomiRNNActivationMode { + DOMI_RNN_ACTIVATION_SIGMOID = 0, // sigmoid + DOMI_RNN_ACTIVATION_TANH, // tanh + DOMI_RNN_ACTIVATION_RELU, // ReLU + DOMI_RNN_ACTIVATION_RELU1, // ReLU1 + DOMI_RNN_ACTIVATION_RELU6, // ReLU6 + DOMI_RNN_ACTIVATION_RESERVED +} domiRNNActivationMode_t; + +typedef enum tagDomiRNNLSTMOutMode { + DOMI_RNN_LSTM_OUT_SEPARATE = 0, + DOMI_RNN_LSTM_OUT_CONCAT, + DOMI_RNN_LSTM_OUT_RESERVED +} domiRNNLSTMOutPutMode_t; +typedef enum tagDomiRNNLSTMStateOutMode { + DOMI_RNN_LSTM_STATE_OUT_SEPARATE = 0, + DOMI_RNN_LSTM_STATE_OUT_CONCAT_ALL, + DOMI_RNN_LSTM_STATE_OUT_RESERVED +} domiRNNLSTMStateOutMode_t; + +typedef enum tagDomiRNNMode { + DOMI_RNN_RELU = 0, + DOMI_RNN_TANH, + DOMI_LSTM, + DOMI_GRU, + DOMI_RNN_MODE_RESERVED +} domiRNNMode_t; + +typedef enum tagDomiResizeBilinearMode { + DOMI_RESIZE_OUTPUT_DIM_BY_ZOOM_FACTOR = 0, // Output dimension specified by zoom factor + DOMI_RESIZE_OUTPUT_DIM_BY_SHRINK_FACTOR, // specified by shrink factor + DOMI_RESIZE_OUTPUT_DIM_EXPLICIT, // specified explicitly + DOMI_RESIZE_OUTPUT_DIM_RESERVED +} domiResizeOutputDimMode_t; + +#pragma pack(1) // single-byte alignment +// DUMP file struct +struct FileHeader { + int32_t Version; // version + int32_t Output_Offset; // output offset address + char Reserved[24] = {0}; // 24 bytes reserved +}; + +struct BasicInfo { + struct FileHeader header; // file header + int32_t stream_id; // stread id + uint64_t start_time; // start time + uint64_t end_time; // end time + uint32_t input_size; // input memory size + uint32_t output_size; // output memory size + uint32_t weight_size; // weight Memory Size + uint32_t workspace_size; // workspace + uint32_t total_size; // total memory size +}; +#pragma pack() // Cancels single-byte alignment +} // namespace ge + +namespace domi { +/// @brief Data structure definition related to task sinking +enum BuildMode { + GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) + GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) + GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) +}; +} // namespace domi + +#endif // INC_FRAMEWORK_COMMON_TYPES_H_ diff --git a/metadef/third_party/graphengine/inc/framework/common/util.h b/metadef/third_party/graphengine/inc/framework/common/util.h new file mode 100644 index 00000000..42ab3868 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/common/util.h @@ -0,0 +1,421 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_COMMON_UTIL_H_ +#define INC_FRAMEWORK_COMMON_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/scope_guard.h" +#include "framework/common/ge_inner_error_codes.h" +#include "mmpa/mmpa_api.h" + +#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ + do { \ + if (size <= 0) { \ + DOMI_LOGE("param[%s] is not a positive number", #size); \ + return PARAM_INVALID; \ + } \ + } while (0) + +#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ + { \ + bool b = (expr); \ + if (!b) { \ + exec_expr; \ + } \ + } + +// new ge marco +// Encapsulate common resource releases +#define GE_MAKE_GUARD_RTMEM(var) \ + GE_MAKE_GUARD(var, [&] { \ + if (var) GE_CHK_RT(rtFreeHost(var)); \ + }); + +#define GE_MAKE_GUARD_RTSTREAM(var) \ + GE_MAKE_GUARD(var, [&] { \ + if (var) GE_CHK_RT(rtStreamDestroy(var)); \ + }); + +// For propagating errors when calling a function. +#define GE_RETURN_IF_ERROR(expr) \ + do { \ + const ::ge::Status _status = (expr); \ + if (_status) return _status; \ + } while (0) + +#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ + do { \ + const ::ge::Status _status = (expr); \ + if (_status) { \ + DOMI_LOGE(__VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// check whether the parameter is true. If it is, return FAILED and record the error log +#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ + do { \ + if (condition) { \ + DOMI_LOGE(__VA_ARGS__); \ + return ge::FAILED; \ + } \ + } while (0) + +// Check if the parameter is false. If yes, return FAILED and record the error log +#define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ + do { \ + bool _condition = (condition); \ + if (!_condition) { \ + DOMI_LOGE(__VA_ARGS__); \ + return ge::FAILED; \ + } \ + } while (0) + +// Checks whether the parameter is true. If so, returns PARAM_INVALID and records the error log +#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ + do { \ + if (condition) { \ + DOMI_LOGE(__VA_ARGS__); \ + return ge::PARAM_INVALID; \ + } \ + } while (0) + +// Check if the parameter is false. If yes, return PARAM_INVALID and record the error log +#define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \ + do { \ + bool _condition = (condition); \ + if (!_condition) { \ + DOMI_LOGE(__VA_ARGS__); \ + return ge::PARAM_INVALID; \ + } \ + } while (0) + +// Check if the parameter is null. If yes, return PARAM_INVALID and record the error +#define GE_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + DOMI_LOGE("param[%s] must not be null.", #val); \ + return ge::PARAM_INVALID; \ + } \ + } while (0) + +// Check if the parameter is null. If yes, just return and record the error +#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ + do { \ + if (val == nullptr) { \ + DOMI_LOGE("param[%s] must not be null.", #val); \ + return; \ + } \ + } while (0) + +// Check whether the parameter is null. If so, execute the exec_expr expression and record the error log +#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ + do { \ + if (val == nullptr) { \ + DOMI_LOGE("param[%s] must not be null.", #val); \ + exec_expr; \ + } \ + } while (0) + +// Check whether the parameter is null. If yes, return directly and record the error log +#define GE_RT_VOID_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + DOMI_LOGE("param[%s] must not be null.", #val); \ + return; \ + } \ + } while (0) + +// Check if the parameter is null. If yes, return false and record the error log +#define GE_RT_FALSE_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + DOMI_LOGE("param[%s] must not be null.", #val); \ + return false; \ + } \ + } while (0) + +// Check if the parameter is out of bounds +#define GE_CHECK_SIZE(size) \ + do { \ + if (size == 0) { \ + DOMI_LOGE("param[%s] is out of range", #size); \ + return ge::PARAM_INVALID; \ + } \ + } while (0) + +// Check if the container is empty +#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ + do { \ + if (vector.empty()) { \ + DOMI_LOGE("param[%s] is empty!", #vector); \ + return ge::FAILED; \ + } \ + } while (0) + +// Check if the value on the left is greater than or equal to the value on the right +#define GE_CHECK_GE(lhs, rhs) \ + do { \ + if (lhs < rhs) { \ + DOMI_LOGE("param[%s] is less than[%s]", #lhs, #rhs); \ + return ge::PARAM_INVALID; \ + } \ + } while (0) + +// Check if the value on the left is less than or equal to the value on the right +#define GE_CHECK_LE(lhs, rhs) \ + do { \ + if (lhs > rhs) { \ + DOMI_LOGE("param[%s] is greater than[%s]", #lhs, #rhs); \ + return ge::PARAM_INVALID; \ + } \ + } while (0) + +#define GE_DELETE_NEW_SINGLE(var) \ + do { \ + if (var != nullptr) { \ + delete var; \ + var = nullptr; \ + } \ + } while (0) + +#define GE_DELETE_NEW_ARRAY(var) \ + do { \ + if (var != nullptr) { \ + delete[] var; \ + var = nullptr; \ + } \ + } while (0) + +/** + * @ingroup domi_common + * @brief version of om.proto file + */ +static constexpr int32_t OM_PROTO_VERSION = 2; + +/** + * Finding an Integer Ceiling Value Without Precision Loss + */ +#define CEIL(N, n) (((N) + (n)-1) / (n)) + +namespace ge { +using google::protobuf::Message; + +/// +/// @ingroup domi_common +/// @brief Maximum file path length +/// +const int32_t DOMI_MAX_PATH_LEN = 256; + +/// +/// @ingroup domi_common +/// @brief proto file in bianary format +/// @param [in] file path of proto file +/// @param [out] proto memory for storing the proto file +/// @return true success +/// @return false fail +/// +bool ReadProtoFromBinaryFile(const char *file, Message *proto); + +/// +/// @ingroup domi_common +/// @brief Reads the proto structure from an array. +/// @param [in] data proto data to be read +/// @param [in] size proto data size +/// @param [out] proto Memory for storing the proto file +/// @return true success +/// @return false fail +/// +bool ReadProtoFromArray(const void *data, int size, Message *proto); + +/// +/// @ingroup domi_proto +/// @brief Reads the proto file in the text format. +/// @param [in] file path of proto file +/// @param [out] message Memory for storing the proto file +/// @return true success +/// @return false fail +/// +bool ReadProtoFromText(const char *file, google::protobuf::Message *message); + +bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message); + +/// +/// @ingroup: domi_common +/// @brief: get length of file +/// @param [in] input_file: path of file +/// @return long: File length. If the file length fails to be obtained, the value -1 is returned. +/// +extern long GetFileLength(const std::string &input_file); + +/// +/// @ingroup domi_common +/// @brief Reads all data from a binary file. +/// @param [in] file_name path of file +/// @param [out] buffer Output memory address, which needs to be released by the caller. +/// @param [out] length Output memory size +/// @return false fail +/// @return true success +/// +bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length); + +bool ReadBytesFromBinaryFile(const char *file_name, std::vector &buffer); + +/// +/// @ingroup domi_common +/// @brief Recursively Creating a Directory +/// @param [in] directory_path Path, which can be a multi-level directory. +/// @return 0 success +/// @return -1 fail +/// +extern int CreateDirectory(const std::string &directory_path); + +/// +/// @ingroup domi_common +/// @brief Obtains the current time string. +/// @return Time character string in the format : %Y%m%d%H%M%S, eg: 20171011083555 +/// +std::string CurrentTimeInStr(); + +/// +/// @ingroup domi_common +/// @brief onverts Vector of a number to a string. +/// @param [in] v Vector of a number +/// @return string +/// +template +std::string ToString(std::vector &v) { + std::stringstream ss; + ss << "["; + for (T x : v) { + ss << x; + ss << ", "; + } + std::string strRet = + ss.str().substr(0, ss.str().length() - 2); // Delete the two extra characters at the end of the line. + strRet += "]"; + return strRet; +} + +/// +/// @ingroup domi_common +/// @brief Converts RepeatedField to String. +/// @param [in] rpd_field RepeatedField +/// @return string +/// +template +std::string ToString(const google::protobuf::RepeatedField &rpd_field) { + std::stringstream ss; + ss << "["; + for (T x : rpd_field) { + ss << x; + ss << ", "; + } + std::string strRet = + ss.str().substr(0, ss.str().length() - 2); // Delete the two extra characters at the end of the line. + strRet += "]"; + return strRet; +} + +/// +/// @ingroup domi_common +/// @brief Obtains the absolute time (timestamp) of the current system. +/// @return Timestamp, in microseconds (US) +/// +/// +uint64_t GetCurrentTimestamp(); + +/// +/// @ingroup domi_common +/// @brief Obtains the absolute time (timestamp) of the current system. +/// @return Timestamp, in seconds (US) +/// +/// +uint32_t GetCurrentSecondTimestap(); + +/// +/// @ingroup domi_common +/// @brief Check whether the product of two int64 numbers exceeds the int64 range. +/// @param [in] a +/// @param [in] b +/// @return false: true: The result is within the normal int64 range. +/// +bool CheckInt64MulOverflow(int64_t a, int64_t b); + +/// +/// @ingroup domi_common +/// @brief Absolute path for obtaining files. +/// @param [in] path of input file +/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned +/// +std::string RealPath(const char *path); + +/// +/// @ingroup domi_common +/// @brief Check whether the specified input file path is valid. +/// 1. The specified path cannot be empty. +/// 2. The path can be converted to an absolute path. +/// 3. The file path exists and is readable. +/// @param [in] file_path path of input file +/// @param [out] result +/// +bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param = ""); + +/// +/// @ingroup domi_common +/// @brief Checks whether the specified output file path is valid. +/// @param [in] file_path path of output file +/// @param [out] result +/// +bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param = ""); + +/// +/// @ingroup domi_common +/// @brief Check whether the file path meets the whitelist verification requirements. +/// @param [in] filePath file path +/// @param [out] result +/// +bool ValidateStr(const std::string &filePath, const std::string &mode); + +/// +/// @ingroup domi_common +/// @brief Check whether the file is normal file. +/// @param [in] file_path file path +/// @param [out] result +/// +bool IsValidFile(const char *file_path); + +/// +/// @ingroup domi_common +/// @brief Check path invalid +/// @param [in] path, path to be checked +/// @param [in] length, length of path +/// @return 0 success +/// @return -1 fail +/// +Status CheckPath(const char *path, size_t length); +} // namespace ge + +#endif // INC_FRAMEWORK_COMMON_UTIL_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/omg.h b/metadef/third_party/graphengine/inc/framework/omg/omg.h new file mode 100644 index 00000000..e7ca05f7 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/omg.h @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_OMG_H_ +#define INC_FRAMEWORK_OMG_OMG_H_ + +#include +#include +#include +#include +#include "framework/omg/omg_inner_types.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "proto/ge_ir.pb.h" +#include "proto/om.pb.h" + +#include "graph/compute_graph.h" +#include "graph/graph.h" +#include "graph/model.h" +#include "runtime/kernel.h" + +using domi::Status; +using std::pair; +using std::string; +using std::unordered_map; +using std::vector; + +namespace ge { +/** + * @ingroup domi_omg + * @brief init omg context + * @return void + */ +Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, + bool is_dynamic_input); + +/** + * @ingroup domi_omg + * @brief generate graph based on the input model file and weight file + * @param [out] graph graph + * @param [in] model_file path of model file + * @param [in] weights_file path of weight file + * @param [in] type type of the input model + * @param [in] op_conf op mapping configuration + * @param [in] target type of platform. If a tiny model is generated, set target to tiny + * @param [in] run_mode run model + * @param [in] enable_l2dynamic enable l2dynamic + * @param [in] is_dynamic_input dynamic input, true of false + * @param [in] atc_params multiply atc params + * @return Status result code + */ +Status ParseGraph(ge::Graph &graph, const std::map &atc_params, const char *model_file, + const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr, + const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false); + +/** + * @ingroup domi_omg + * @brief generates a simplified JSON file based on the key value of the offline model file in protobuf format + * @param [in] model_file path of offline model file + * @param [out] json_file path of json file + * @param [key] encrypted key + * @return Status result code + */ +Status ConvertOmModelToJson(const char *model_file, const char *json_file); + +Status ConvertPbtxtToJson(const char *model_file, const char *json_file); +/** + * @ingroup domi_omg + * @brief convert the model file in protobuf format into a JSON file. + * @param [in] framework type of model + * @param [in] om model_file path of offline model file + * @param [out] json_file path of json file + * @param [key] encrypted key + * @return Status result code + */ +Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file); + +void GetGroupName(ge::proto::ModelDef &model); + +void FindParserSo(const string &path, vector &fileList, string &caffe_parser_path); + +Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); + +Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); + +Status GetOutputLeaf(ge::NodePtr node, std::vector> &output_nodes_info); + +void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, + std::vector &output_nodes_name); + +void UpdateOmgCtxWithParserCtx(); + +void UpdateParserCtxWithOmgCtx(); +} // namespace ge + +namespace domi { +/** + * @ingroup domi_omg + * @brief get omg context + * @return reference of OmgContext + */ +ge::OmgContext &GetContext(); +} // namespace domi + +#endif // INC_FRAMEWORK_OMG_OMG_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/omg_inner_types.h b/metadef/third_party/graphengine/inc/framework/omg/omg_inner_types.h new file mode 100644 index 00000000..454890aa --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/omg_inner_types.h @@ -0,0 +1,142 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_OMG_INNER_TYPES_H_ +#define INC_FRAMEWORK_OMG_OMG_INNER_TYPES_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "framework/common/fmk_error_codes.h" +#include "register/register_fmk_types.h" + +using domi::DOMI_TENSOR_ND; +using domi::DOMI_TENSOR_RESERVED; +using domi::domiTensorFormat_t; +using domi::FRAMEWORK_RESERVED; +using domi::FrameworkType; +using std::map; +using std::string; +using std::unordered_map; +using std::vector; + +namespace ge { +/** + * @ingroup domi_omg + * @brief run model + */ +enum RunMode { + GEN_OM_MODEL = 0, // generate offline model file + MODEL_TO_JSON = 1, // convert to JSON file + ONLY_PRE_CHECK = 3, // only for pre-check + PBTXT_TO_JSON = 5 // pbtxt to json +}; + +/// +/// @ingroup domi_omg +/// @brief high-precision mode +/// +enum HighPrecisionMode { + // the FP16 high-precision function is disabled in common mode + HIGH_PRECISION_DEFAULT = 0, + + // high-precision mode, enabling FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved) + HIGH_PRECISION_FP16 = 1 +}; + +/// +/// @ingroup domi_omg +/// @brief description buffer data +/// +struct OMGBufferData { + void *data; + uint32_t length; +}; + +struct OmgContext { + OmgContext() { format = DOMI_TENSOR_ND; } + domiTensorFormat_t format; + + // format of the input specified by the command line + std::unordered_map input_nodes_format_map; + std::vector output_formats; + + // user-designate input dims + std::vector>> user_input_dims; + // global input dims + std::unordered_map> input_dims; + + // resolve the mapping between operators with the same name and corresponding network. format e.g. + // Detectionoutput:SsdDetectiontOutput + std::map op_conf_map; + // save the output node of the network. key = operator name, value = index, index indicates the output index of the + // operator + std::map> out_nodes_map; + // user-designate out nodes (this is used for determing the orders) + std::vector> user_out_nodes; + // default out nodes (this is used for determing the orders) + std::vector> default_out_nodes; + // save the output node of the network, value = topName, + // topName indicates the output name of the operator. + std::vector user_out_nodes_top_vec; + // net out nodes (where user_out_nodes or leaf nodes) + std::vector net_out_nodes; + // net out nodes top names(only caffe has top) + std::vector out_top_names; + // net data nodes top names(only caffe has top) + std::vector data_top_names; + // preferential format used by the entire network + domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; + domi::FrameworkType type = domi::FRAMEWORK_RESERVED; + RunMode run_mode = ONLY_PRE_CHECK; + bool train_flag = false; + + std::string output_type; + + // Whether to use dynamic batch size or dynamic image size + bool is_dynamic_input = false; + std::string dynamic_batch_size; + std::string dynamic_image_size; + std::string dynamic_dims; + std::string dynamic_node_type; + std::vector> user_real_input_dims; + std::vector cur_dynamic_dims; + bool need_multi_batch = false; +}; +} // namespace ge + +namespace domi { +/** + * @ingroup domi_omg + * @brief get OMG context + * @return OmgContext context + */ +ge::OmgContext &GetContext(); + +struct TEBinInfo { + // It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. + // To be compatible with use cases written by previous users, fields are not deleted.(2018.11.21) + std::string bin_file_path; + std::string json_file_path; + std::string ddk_version; +}; +} // namespace domi + +#endif // INC_FRAMEWORK_OMG_OMG_INNER_TYPES_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/omg_types.h b/metadef/third_party/graphengine/inc/framework/omg/omg_types.h new file mode 100644 index 00000000..771a53a4 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/omg_types.h @@ -0,0 +1,22 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_OMG_TYPES_H_ +#define INC_FRAMEWORK_OMG_OMG_TYPES_H_ + +#include "register/register_fmk_types.h" + +#endif // INC_FRAMEWORK_OMG_OMG_TYPES_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/parser/model_parser.h b/metadef/third_party/graphengine/inc/framework/omg/parser/model_parser.h new file mode 100644 index 00000000..20bfcef4 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/parser/model_parser.h @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ +#define INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ + +#include +#include "framework/omg/parser/parser_types.h" +#include "framework/omg/omg_inner_types.h" +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/ge_tensor.h" +#include "graph/graph.h" +#include "graph/op_desc.h" +#include "graph/operator.h" +#include "graph/range_vistor.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" + +using Status = domi::Status; + +namespace domi { +using GetGraphCallback = std::function( + const google::protobuf::Message *root_proto, const std::string &graph)>; +class ModelParser { + public: + ModelParser() {} + + virtual ~ModelParser() {} + + /** + * @ingroup domi_omg + * @brief Analyze network model data + * @param [in] file Network model file path + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status Parse(const char *file, ge::Graph &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Parse relevant data from memory and save it to graph + * @param [in] input Model file memory data + * @param [in|out] graph A graph for saving the model information after analysis + * @return SUCCESS + * @return FAILED + * @author + */ + virtual Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Analyze network model data + * @param [in] proto network model + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Analyze callback model data in subgraph + * @param [in] proto network model + * @param [in] callback callback of subgraph + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProtoWithSubgraph(const google::protobuf::Message *proto, + GetGraphCallback callback, + ge::ComputeGraphPtr &graph) = 0; + /** + * @ingroup domi_omg + * @brief Convert model files to JSON format + * @param [in] model_file Model file path to be converted + * @param [out] json_file Converted JSON file path + * @return SUCCESS + * @return Others failed + */ + virtual Status ToJson(const char *model_file, const char *json_file) { return domi::SUCCESS; } + + /* + * @ingroup domi_omg + * @brief Convert network data type + * @param [in] type Data type to be converted + * @return ge::DataType + */ + virtual ge::DataType ConvertToGeDataType(const uint32_t type) = 0; + + virtual Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) = 0; +}; +} // namespace domi + +#endif // INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/parser/op_parser.h b/metadef/third_party/graphengine/inc/framework/omg/parser/op_parser.h new file mode 100644 index 00000000..087bad32 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/parser/op_parser.h @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ +#define INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ + +#include +#include "framework/omg/parser/parser_types.h" +#include "omg/omg_inner_types.h" +#include "proto/om.pb.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/utils/op_desc_utils.h" + +using google::protobuf::Message; +using Status = domi::Status; + +namespace ge { +/** + * @ingroup domi_omg + * @brief Used to analyze operator information + * + */ +class OpParser { + public: + /** + * @ingroup domi_omg + * @brief Deconstructor + */ + virtual ~OpParser() {} + + /** + * @ingroup domi_omg + * @brief Analytic operator parameters + * @param [in] op_src Parameter data to be resolved + * @param [out] graph Parsed parameter data + * @return SUCCESS + * @return FAILED + */ + virtual Status ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) = 0; + + /** + * @ingroup domi_omg + * @brief Analytic operator parameters + * @param [in] op_src Parameter data to be resolved + * @param [out] Operator parameter data + * @return SUCCESS + * @return FAILED + */ + virtual Status ParseParams(const Message *op_src, ge::Operator &op_dest) = 0; + + /** + * @ingroup domi_omg + * @brief Analytic operator weight information + * @param [in] op_src Weight data to be resolved + * @param [out] op_dest Weight data after analysis + * @return SUCCESS + * @return FAILED + */ + virtual Status ParseWeights(const Message *op_src, ge::NodePtr &node) = 0; + + /** + * @ingroup domi_omg + * @brief Get the format information according to the parameters in the operator + * @param [in] op_src Parameter data to be resolved + * @param [out] format Output the parsed format + * @return SUCCESS + * @return FAILED + */ + virtual Status GetFormat(const Message *op_src, domi::domiTensorFormat_t &format) { + (void)op_src; + // Indicates that the op does not provide a value for format + format = domi::DOMI_TENSOR_RESERVED; + return domi::SUCCESS; + } +}; +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/parser/parser_api.h b/metadef/third_party/graphengine/inc/framework/omg/parser/parser_api.h new file mode 100644 index 00000000..382bdfde --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/parser/parser_api.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ +#define INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ + +#include +#include +#include +#include "ge/ge_api_error_codes.h" + +namespace ge { +// Initialize parser +Status ParserInitialize(const std::map& options); +// Finalize parser, release all resources +Status ParserFinalize(); +} // namespace ge +#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/parser/parser_factory.h b/metadef/third_party/graphengine/inc/framework/omg/parser/parser_factory.h new file mode 100644 index 00000000..4845606f --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/parser/parser_factory.h @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ +#define INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ + +#include +#include +#include +#include +#include "framework/omg/omg_inner_types.h" +#include "framework/omg/parser/parser_types.h" + +using Status = domi::Status; + +namespace domi { +class WeightsParser; +class ModelParser; + +typedef std::shared_ptr (*MODEL_PARSER_CREATOR_FUN)(void); + +// Create modelparser for different frameworks +class ModelParserFactory { + public: + static ModelParserFactory *Instance(); + + /** + * @ingroup domi_omg + * @brief Create a modelparser based on the type entered + * @param [in] type Framework type + * @return Created modelparser + */ + std::shared_ptr CreateModelParser(const domi::FrameworkType type); + + /** + * @ingroup domi_omg + * @brief Register create function + * @param [in] type Framework type + * @param [in] fun ModelParser's create function + */ + void RegisterCreator(const domi::FrameworkType type, MODEL_PARSER_CREATOR_FUN fun); + + protected: + ModelParserFactory() {} + ~ModelParserFactory(); + + private: + std::map creator_map_; +}; // end class ModelParserFactory + +class ModelParserRegisterar { + public: + ModelParserRegisterar(const domi::FrameworkType type, MODEL_PARSER_CREATOR_FUN fun) { + ModelParserFactory::Instance()->RegisterCreator(type, fun); + } + ~ModelParserRegisterar() {} +}; + +// Registration macros for model parsers +#define REGISTER_MODEL_PARSER_CREATOR(type, clazz) \ + std::shared_ptr Creator_##type##_Model_Parser() { \ + std::shared_ptr ptr = nullptr; \ + try { \ + ptr = make_shared(); \ + } catch (...) { \ + ptr = nullptr; \ + } \ + return std::shared_ptr(ptr); \ + } \ + ModelParserRegisterar g_##type##_Model_Parser_Creator(type, Creator_##type##_Model_Parser) + +typedef std::shared_ptr (*WEIGHTS_PARSER_CREATOR_FUN)(void); + +// Create weightsparser for different frameworks +class WeightsParserFactory { + public: + static WeightsParserFactory *Instance(); + + /** + * @ingroup domi_omg + * @brief Create weightsparser based on the type entered + * @param [in] type Framework type + * @return Created weightsparser + */ + std::shared_ptr CreateWeightsParser(const domi::FrameworkType type); + + /** + * @ingroup domi_omg + * @brief Register create function + * @param [in] type Framework type + * @param [in] fun WeightsParser's create function + */ + void RegisterCreator(const domi::FrameworkType type, WEIGHTS_PARSER_CREATOR_FUN fun); + + protected: + WeightsParserFactory() {} + ~WeightsParserFactory(); + + private: + std::map creator_map_; +}; // end class WeightsParserFactory + +class WeightsParserRegisterar { + public: + WeightsParserRegisterar(const domi::FrameworkType type, WEIGHTS_PARSER_CREATOR_FUN fun) { + WeightsParserFactory::Instance()->RegisterCreator(type, fun); + } + ~WeightsParserRegisterar() {} +}; + +// Register macro of weight resolver +#define REGISTER_WEIGHTS_PARSER_CREATOR(type, clazz) \ + std::shared_ptr Creator_##type##_Weights_Parser() { \ + std::shared_ptr ptr = nullptr; \ + try { \ + ptr = make_shared(); \ + } catch (...) { \ + ptr = nullptr; \ + } \ + return std::shared_ptr(ptr); \ + } \ + WeightsParserRegisterar g_##type##_Weights_Parser_Creator(type, Creator_##type##_Weights_Parser) +}; // namespace domi + +#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/parser/parser_inner_ctx.h b/metadef/third_party/graphengine/inc/framework/omg/parser/parser_inner_ctx.h new file mode 100644 index 00000000..f28f2c30 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/parser/parser_inner_ctx.h @@ -0,0 +1,77 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ +#define INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ + +#include +#include +#include +#include +#include +#include + +#include "external/register/register_fmk_types.h" +#include "external/register/register_types.h" +#include "framework/omg/omg_inner_types.h" + +namespace ge +{ +struct ParserContext +{ + // format of the input specified by the command line + std::unordered_map input_nodes_format_map; + std::vector output_formats; + // user-designate input dims + std::vector>> user_input_dims; + std::unordered_map> input_dims; + // resolve the mapping between operators with the same name and corresponding network. format e.g. + // Detectionoutput:SsdDetectiontOutput + std::map op_conf_map; + // user-designate out nodes (this is used for determing the orders) + std::vector> user_out_nodes; + // default out nodes (this is used for determing the orders) + std::vector> default_out_nodes; + // save the output node of the network. key = operator name, value = index, index indicates the output index of the + // operator + std::map> out_nodes_map; + // save the output node of the network, value = topName, + // topName indicates the output name of the operator. + std::vector user_out_nodes_top_vec; + // net out nodes (where user_out_nodes or leaf nodes) + std::vector net_out_nodes; + // net data nodes top names(only caffe has top) + std::vector data_top_names; + // net out nodes top names(only caffe has top) + std::vector out_top_names; + // Whether to use dynamic batch size or dynamic image size + bool is_dynamic_input = false; + bool train_flag = false; + domi::domiTensorFormat_t format = domi::DOMI_TENSOR_ND; + domi::FrameworkType type = domi::FRAMEWORK_RESERVED; + RunMode run_mode = GEN_OM_MODEL; + // save caffe custom proto path, used by caffe parse + std::string custom_proto_path; + // save caffe proto path, used by caffe parse + std::string caffe_proto_path; + // name of the pass that needs to take effect + std::string enable_scope_fusion_passes; +}; + +ParserContext &GetParserContext(); +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/parser/parser_types.h b/metadef/third_party/graphengine/inc/framework/omg/parser/parser_types.h new file mode 100644 index 00000000..62c9c750 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/parser/parser_types.h @@ -0,0 +1,508 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_COMMON_TYPES_H_ +#define PARSER_COMMON_TYPES_H_ + +#include +#include + +#include "register/register_types.h" + +#if !defined(__ANDROID__) && !defined(ANDROID) +#ifndef DOMI_DYNAMIC_CAST +#define DOMI_DYNAMIC_CAST static_cast +#endif +#ifndef DOMI_DYNAMIC_POINTER_CAST +#define DOMI_DYNAMIC_POINTER_CAST std::static_pointer_cast +#endif +#else +#ifndef DOMI_DYNAMIC_CAST +#define DOMI_DYNAMIC_CAST static_cast +#endif +#ifndef DOMI_DYNAMIC_POINTER_CAST +#define DOMI_DYNAMIC_POINTER_CAST std::static_pointer_cast +#endif +#endif + +namespace ge { +namespace parser { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DATA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *AIPPDATA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONVOLUTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CORRELATION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CORRELATIONV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DECONVOLUTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *POOLING; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ELTWISE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RELU; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RELU6; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SIGMOID; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ABSVAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TANH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PRELU; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BATCHNORM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FUSIONBATCHNORM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SCALE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FULL_CONNECTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SOFTMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PLUS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ACTIVATION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FLATTEN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ADD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SUB; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MUL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MATMUL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RSQRT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BIASADD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RESHAPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REFORMAT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DEPCONVOLUTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DROPOUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DROPOUTGENMASK; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DROPOUTDOMASK; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONCAT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ROIPOOLING; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PROPOSAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FSRDETECTIONOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DETECTIONPOSTPROCESS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LRN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TRANSDATA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PERMUTE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SSDNORMALIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SSDPRIORBOX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *NETOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SSDDETECTIONOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REFINEDETDETECTIONOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CHANNELAXPY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PSROIPOOLING; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *POWER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *POW; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ROIALIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PYTHON; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FREESPACEEXTRACT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SPATIALTF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SHAPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SHAPEN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ARGMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *GATHERND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *GATHER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REALDIV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PACK; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SLICE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SLICED; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FLOORDIV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SQUEEZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *UNSQUEEZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *STRIDEDSLICE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RANGE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RPNPROPOSALS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DECODEBBOX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PADV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MIRRORPAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CLIPBOXES; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FASTRCNNPREDICTIONS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SPLIT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SPLITV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *EXPANDDIMS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *EMPTY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MEAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *GREATER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SWITCH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SWITCHN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MERGE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SYMBOLICGRADIENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REMOTECALL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *_IF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *STATELESSIF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *IF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CASE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *_WHILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *WHILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *STATELESSWHILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PARTITIONEDCALL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *STATEFULPARTITIONEDCALL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FAKEPARAM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TRANSPOSE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TRANSPOSED; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CAST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REGION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *YOLO; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *YOLODETECTIONOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FILL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REVERSE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *UNPACK; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *YOLO2REORG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REDUCESUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONSTANT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RESIZEBILINEAR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RESIZEBILINEARGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MAXIMUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FRAMEWORKOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ARG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FUSEDBATCHNORMGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LSTM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HIGHWAY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RNN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ATTENTIONDECODER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LOGICAL_NOT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LOGICAL_AND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LOGICAL_OR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *EQUAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *NOTEQUAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *INTERP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SHUFFLECHANNEL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *AIPP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MULTISHAPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RECIPROCAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SELU; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ELU; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ACOSH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASINH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MINIMUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CLIP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *L2NORMALIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CROPANDRESIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *UNUSEDCONST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SPARSETODENSE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *NONMAXSUPPRESSION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TOPKV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *INVERTPERMUTATION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MULTINOMIAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REVERSESEQUENCE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REDUCEPROD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REDUCEMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REDUCEMIN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *EXTRACTIMAGEPATCHES; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SQRT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REDUCEALL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RESIZENEARESTNEIGHBOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SPACETOBATCHND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BATCHTOSPACEND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASSERT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *GREATEREQUAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FLOOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RANDOMUNIFORM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BATCHMATMUL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SPACETODEPTH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DEPTHTOSPACE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RINT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ATAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ATAN2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ATANH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ACOS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASIN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *NEG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LOG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ROUND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *UPSAMPLE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FLOORMOD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LESS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LESSEQUAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ONEHOT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REFSWITCH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REFMERGE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ENTER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REFENTER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LOOPCOND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *NEXTITERATION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REFNEXTITERATION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *EXIT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REFEXIT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONTROLTRIGGER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ZEROSLIKE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *EXP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *WHERE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FAKEQUANTWITHMINMAXVARS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SOFTPLUS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SOFTSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *COSH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SINH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SQUAREDDIFFERENCE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char + *REQUIREDSPACETOBATCHPADDINGS; // for retinanet scope fusion +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SSDPOSTPROCESSOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RETINANETBOXES; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RETINAMULTIANCHORS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RETINANETCLIPPEDBOXES; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RETINANETFILTEREDDETECTIONS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RETINANETPOSTPROCESSOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RETINANETANCHORS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FASTERRCNNMAP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FASTERRCNNMAP1; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FASTERRCNNSECONDSTAGEPOSTPROCESSOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FASTERRCNNROIINTERPOOLING; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FASTERRCNNFIRSTSTAGEPOSTPROCESSOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FASTERRCNNGRIDANCHORGENERATOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ROIINTERPOOLING; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FASTERRCNNCLIPTOWINDOW; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *EMBEDLOOKUP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HASHLOOKUP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LSH_PROJ; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SVDF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SSDANCHORGENERATOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *IDENTITY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *IDENTITYN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PLACEHOLDERWITHDEFAULT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SELECT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *GETSPAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *STOPGRADIENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PREVENTGRADIENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *GUARANTEECONST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BROADCASTGRADIENTARGS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BROADCASTARGS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONFUSIONMATRIX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RANK; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PLACEHOLDER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *END; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BASICLSTMCELL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *GETNEXT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *INITDATA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REFIDENTITY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BITCAST; + +/***************Ann special operator*************************/ +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_MEAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_CONVOLUTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_DEPCONVOLUTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_FULLCONNECTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_NETOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_DATA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_RESHAPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_ADD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_MUL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_SUB; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_DIV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_DEQUANTIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_QUANTIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_PAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANN_RESIZE_BILINEAR; + +/***************************************************/ +/******************Training operator*************************/ +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *GATHERV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONVGRADFILTER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONV2D; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONV2DBACKPROPINPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FUSEDBATCHNORM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BIASADDGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ACTIVATIONGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MAXPOOLWITHARGMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MAXPOOLGRADWITHARGMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SPARSESOFTMAXCROSSENTROPYWITHLOGITS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SNAPSHOT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VAR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MEANGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TRANSLATE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ADDN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *L2LOSS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MULTIPLY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HUBERLOSSGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HUBERLOSS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *NEGATIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SSDCAST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SPARSESOFTMAXCROSSENTROPY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SPARSESOFTMAXCROSSENTROPYGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SSDSQUEEZEFUSION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONCATFOUR2FIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONCATFIVE2FOUR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SSDREALDIVTILEMUL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SSDSUMMULREALDIVMEAN; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARIABLEV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARHANDLEOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TEMPORARYVARIABLE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DESTROYTEMPORARYVARIABLE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARIABLE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASSIGNVARIABLEOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASSIGNADD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASSIGNADDVARIABLEOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASSIGNSUB; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASSIGNSUBVARIABLEOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYMOMENTUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RESOURCEAPPLYMOMENTUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SGD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *NOOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *READVARIABLEOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *PARALLELCONCATSTART; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONSTANTOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DEPTHWISECONV2DBACKPROPFILTER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DEPTHWISECONV2DBACKPORPINPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DEPTHWISECONV2DFORWARDNATIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DROPOUTGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYRMSPROPMIXEDPRECISION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYRMSPROP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RELU6GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *AVGPOOLGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONCATV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONCATOFFSET; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LAYERNORMGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LAYERNORM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LARS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DYNAMICSTITCH; + +/***************************************************/ +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SQUARE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMBROADCAST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMALLGATHER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMALLREDUCE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREDUCESCATTER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMSEND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMRECEIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTEREAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTEWRITE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARASSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARISINITIALIZEDOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LogTimeStamp; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ISVARIABLEINITIALIZED; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *STREAMSWITCH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *STREAMSWITCHN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *STREAMACTIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MEMCPYASYNC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *MEMCPYADDRASYNC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *STREAMMERGE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ENDGRAPH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SEND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RECV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ENDOFSEQUENCE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LABELSET; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LABELGOTO; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LABELGOTOEX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LABELSWITCH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *LABELSWITCHBYINDEX; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ATOMICADDRCLEAN; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ABS_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ACCUMULATE_N_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ACOS_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ACOSH_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ANY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPROXIMATE_EQUAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASIN_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ASINH_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ATAN_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BROADCAST_TO; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ELU_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ADD_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DATAFORMATDIMMAP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DATAFORMATVECPERMUTE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BESSELI0E; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *BESSELI1E; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYADADELTA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYADAGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYADAGRADDA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYADAM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYADAMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYADDSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYCENTEREDRMSPROP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYFTRL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYFTRLV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYGRADIENTDESCENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYPOWERSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYPROXIMALADAGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *APPLYPROXIMALGRADIENTDESCENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DEQUANTIZE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FOCAL_LOSS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *FOCAL_LOSS_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SMOOTHL1_LOSS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SMOOTHL1_LOSS_grad; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REDUCEMEAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CONCAT_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *ONEHOT_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SLICE_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TILE_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SUM_V2; +// Common type when the operator has the same name +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DETECTIONOUTPUT; +// Custom operator +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CUSTOMOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CUSTOMOP_NCHW; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CUSTOMOP_NHWC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *CUSTOMOP_NC1HWC0; + +// Depthwise 4d_2_6d,6d_2_4d +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DEPTHWISEWEIGHT4D26D; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *DEPTHWISEWEIGHT6D24D; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SQRTGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SIGMOIDGRAD; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *TRANSSHAPE; + +// Horovod operator +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HVDCALLBACKALLREDUCE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HVDCALLBACKALLGATHER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HVDCALLBACKBROADCAST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HVDWAIT; + +/// +/// @brief Magic number of model file +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_FILE_MAGIC_NUM; // magic number + +/// +/// @brief Model head length +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_FILE_HEAD_LEN; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_VERSION; ///< Model version 1.0/// + +// alpha default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const float ALPHA_DEFAULT_VALUE; + +// beta default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const float BETA_DEFAULT_VALUE; + +/// +/// @ingroup domi_omg +/// @brief INPUT node type +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string INPUT_TYPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMMY_DATA; + +// dim default size value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY static const int32_t DIM_DEFAULT_SIZE = 4; + +// for fusion op plugin +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ATTR_NAME_INPUT_TENSOR_DESC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; + +// DATA node type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DATA_TYPE; + +// framework Operator Type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string FRAMEWORK_OP_TYPE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_NET_OUTPUT; + +#pragma pack() // Cancels single-byte alignment +} // namespace parser +} // namespace ge + +#endif // PARSER_COMMON_TYPES_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/parser/weights_parser.h b/metadef/third_party/graphengine/inc/framework/omg/parser/weights_parser.h new file mode 100644 index 00000000..1b5216b3 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/parser/weights_parser.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_ +#define INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_ + +#include "graph/graph.h" +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/operator.h" +#include "graph/range_vistor.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" + +namespace domi { +/** + * @ingroup domi_omg + * @brief Weight information resolver + * + */ +class WeightsParser { + public: + /** + * @ingroup domi_omg + * @brief Constructor + */ + WeightsParser() {} + + /** + * @ingroup domi_omg + * @brief Deconstructor + */ + virtual ~WeightsParser() {} + + /** + * @ingroup domi_omg + * @brief Analyze weight data + * @param [in] file Path of weight file after training + * @param [in|out] graph Graph for saving weight information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status Parse(const char *file, ge::Graph &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Parse relevant data from memory and save it to graph + * @param [in] input Model file memory data + * @param [in|out] graph A graph for saving the model information after analysis + * @return SUCCESS + * @return FAILED + * @author + */ + virtual Status ParseFromMemory(const char *input, uint32_t lengt, ge::ComputeGraphPtr &graph) = 0; +}; +} // namespace domi + +#endif // INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_ diff --git a/metadef/third_party/graphengine/inc/framework/omg/version.h b/metadef/third_party/graphengine/inc/framework/omg/version.h new file mode 100644 index 00000000..ac649d83 --- /dev/null +++ b/metadef/third_party/graphengine/inc/framework/omg/version.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_OMG_VERSION_H_ +#define INC_FRAMEWORK_OMG_VERSION_H_ + +#include +#include +#include +#include + +#include "common/debug/log.h" +#include "common/string_util.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +class PlatformVersionManager { + public: + PlatformVersionManager() = delete; + ~PlatformVersionManager() = delete; + static Status GetPlatformVersion(std::string &ver) { + ver = "1.11.z"; + std::vector version_splits = StringUtils::Split(ver, '.'); + GE_IF_BOOL_EXEC(version_splits.size() < 3, GELOGW("Read platform version error!"); return FAILED;); + + GELOGI("Read current platform version: %s.", ver.c_str()); + return SUCCESS; + } +}; // class PlatformManager +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_VERSION_H_ diff --git a/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch b/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch new file mode 100644 index 00000000..0fcf50c4 --- /dev/null +++ b/metadef/third_party/patch/securec/0001-add-securec-cmake-script.patch @@ -0,0 +1,105 @@ +From 455c9812d70646fe725896d597d6c953bf5a09ac Mon Sep 17 00:00:00 2001 +From: taoxiangdong +Date: Wed, 14 Oct 2020 22:14:01 +0800 +Subject: [PATCH] add securec cmake script + +--- + CMakeLists.txt | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++ + 1 file changed, 86 insertions(+) + create mode 100755 CMakeLists.txt + +diff --git a/CMakeLists.txt b/CMakeLists.txt +new file mode 100755 +index 0000000..9b91fb2 +--- /dev/null ++++ b/CMakeLists.txt +@@ -0,0 +1,86 @@ ++cmake_minimum_required(VERSION 3.14) ++project(Securec) ++file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} ++ "src/vsprintf_s.c" ++ "src/wmemmove_s.c" ++ "src/strncat_s.c" ++ "src/vsnprintf_s.c" ++ "src/fwscanf_s.c" ++ "src/scanf_s.c" ++ "src/strcat_s.c" ++ "src/sscanf_s.c" ++ "src/secureprintoutput_w.c" ++ "src/wmemcpy_s.c" ++ "src/wcsncat_s.c" ++ "src/secureprintoutput_a.c" ++ "src/secureinput_w.c" ++ "src/memcpy_s.c" ++ "src/fscanf_s.c" ++ "src/vswscanf_s.c" ++ "src/secureinput_a.c" ++ "src/sprintf_s.c" ++ "src/memmove_s.c" ++ "src/swscanf_s.c" ++ "src/snprintf_s.c" ++ "src/vscanf_s.c" ++ "src/vswprintf_s.c" ++ "src/wcscpy_s.c" ++ "src/vfwscanf_s.c" ++ "src/memset_s.c" ++ "src/wscanf_s.c" ++ "src/vwscanf_s.c" ++ "src/strtok_s.c" ++ "src/wcsncpy_s.c" ++ "src/vfscanf_s.c" ++ "src/vsscanf_s.c" ++ "src/wcstok_s.c" ++ "src/securecutil.c" ++ "src/gets_s.c" ++ "src/swprintf_s.c" ++ "src/strcpy_s.c" ++ "src/wcscat_s.c" ++ "src/strncpy_s.c" ++ ) ++ ++include_directories(./include) ++include_directories(./src) ++add_library(shared_c_sec SHARED ${SRC_LIST}) ++ ++target_compile_options(shared_c_sec PRIVATE ++ -I/usr/local/include ++ -Werror ++ -Wall ++ -O1 ++) ++target_compile_definitions(shared_c_sec PRIVATE ++ NDEBUG ++ SECUREC_SUPPORT_STRTOLD=1 ++ ) ++ ++add_library(static_c_sec STATIC ${SRC_LIST}) ++ ++target_compile_options(static_c_sec PRIVATE ++ -I/usr/local/include ++ -Werror ++ -Wall ++ -O1 ++) ++ ++target_compile_definitions(static_c_sec PRIVATE ++ NDEBUG ++ SECUREC_SUPPORT_STRTOLD=1 ++ ) ++ ++set_target_properties(static_c_sec ++ PROPERTIES ++ OUTPUT_NAME c_sec ++) ++set_target_properties(shared_c_sec ++ PROPERTIES ++ OUTPUT_NAME c_sec ++) ++install(TARGETS shared_c_sec static_c_sec OPTIONAL ++ DESTINATION lib) ++install(FILES "./include/securec.h" ++ "./include/securectype.h" ++ DESTINATION include) +-- +2.17.1 + diff --git a/metadef/third_party/patch/securec/securec.patch001 b/metadef/third_party/patch/securec/securec.patch001 new file mode 100644 index 00000000..01c2d769 --- /dev/null +++ b/metadef/third_party/patch/securec/securec.patch001 @@ -0,0 +1,22 @@ +diff -Npur -x .git libboundscheck/CMakeLists.txt securec/CMakeLists.txt +--- libboundscheck/CMakeLists.txt 1970-01-01 08:00:00.000000000 +0800 ++++ securec/CMakeLists.txt 2020-09-19 16:53:48.689460700 +0800 +@@ -0,0 +1,18 @@ ++cmake_minimum_required(VERSION 3.14) ++project(Securec) ++set(CMAKE_C_FLAGS_DEBUG "$ENV{CFLAGS} -fPIC -O0 -Wall -Wno-deprecated-declarations -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -D_LIBCPP_INLINE_VISIBILITY='' -D'_LIBCPP_EXTERN_TEMPLATE(...)='") ++set(CMAKE_C_FLAGS_RELEASE "$ENV{CFLAGS} -fPIC -Wall -D_FORTIFY_SOURCE=2 -O2 -Wno-deprecated-declarations -fstack-protector-all -Wl,-z,relro,-z,now") ++set(CMAKE_EXPORT_COMPILE_COMMANDS ON) ++ ++#add flags ++set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -I/usr/local/include -Werror") ++ ++include_directories(./include) ++aux_source_directory(./src SECUREC_SRCS) ++add_library(c_sec SHARED ${SECUREC_SRCS}) ++ ++install(TARGETS c_sec ++ DESTINATION lib) ++install(FILES "./include/securec.h" ++ "./include/securectype.h" ++ DESTINATION include) diff --git a/metadef/third_party/transformer/inc/axis_util.h b/metadef/third_party/transformer/inc/axis_util.h new file mode 100644 index 00000000..81dab321 --- /dev/null +++ b/metadef/third_party/transformer/inc/axis_util.h @@ -0,0 +1,180 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file axis_util.h + * \brief get the axis value + */ +#ifndef COMMON_UTILS_TRANSFER_AXIS_UTIL_H_ +#define COMMON_UTILS_TRANSFER_AXIS_UTIL_H_ + +#include +#include +#include + +#include "external/graph/ge_error_codes.h" +#include "external/graph/types.h" +#include "framework/common/debug/ge_log.h" + +namespace common { +namespace transformer { + +const int32_t DIM_DEFAULT_SIZE = 4; +const uint32_t NCHW_DIMENSION_NUM = 4; + +const int32_t AXIS_NCHW_DIM_N = 0; +const int32_t AXIS_NCHW_DIM_C = 1; +const int32_t AXIS_NCHW_DIM_H = 2; +const int32_t AXIS_NCHW_DIM_W = 3; + +const int32_t AXIS_NHWC_DIM_N = 0; +const int32_t AXIS_NHWC_DIM_H = 1; +const int32_t AXIS_NHWC_DIM_W = 2; +const int32_t AXIS_NHWC_DIM_C = 3; + +const int32_t AXIS_NC1HWC0_DIM_N = 0; +const int32_t AXIS_NC1HWC0_DIM_C1 = 1; +const int32_t AXIS_NC1HWC0_DIM_C0 = 4; +const int32_t AXIS_NC1HWC0_DIM_H = 2; +const int32_t AXIS_NC1HWC0_DIM_W = 3; + +const int32_t AXIS_HWCN_DIM_H = 0; +const int32_t AXIS_HWCN_DIM_W = 1; +const int32_t AXIS_HWCN_DIM_C = 2; +const int32_t AXIS_HWCN_DIM_N = 3; + +const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0; +const int32_t AXIS_C1HWNCoC0_DIM_H = 1; +const int32_t AXIS_C1HWNCoC0_DIM_W = 2; +const int32_t AXIS_C1HWNCoC0_DIM_N = 3; +const int32_t AXIS_C1HWNCoC0_DIM_Co = 4; +const int32_t AXIS_C1HWNCoC0_DIM_C0 = 5; + +const int32_t NDHWC_DIM_N = 0; +const int32_t NDHWC_DIM_D = 1; +const int32_t NDHWC_DIM_H = 2; +const int32_t NDHWC_DIM_W = 3; +const int32_t NDHWC_DIM_C = 4; + +const int32_t NCDHW_DIM_N = 0; +const int32_t NCDHW_DIM_C = 1; +const int32_t NCDHW_DIM_D = 2; +const int32_t NCDHW_DIM_H = 3; +const int32_t NCDHW_DIM_W = 4; + +const int32_t DHWCN_DIM_D = 0; +const int32_t DHWCN_DIM_H = 1; +const int32_t DHWCN_DIM_W = 2; +const int32_t DHWCN_DIM_C = 3; +const int32_t DHWCN_DIM_N = 4; + +const int32_t DHWNC_DIM_D = 0; +const int32_t DHWNC_DIM_H = 1; +const int32_t DHWNC_DIM_W = 2; +const int32_t DHWNC_DIM_N = 3; +const int32_t DHWNC_DIM_C = 4; + + +#define CHECK_NOTNULL(val) \ + do { \ + if ((val) == nullptr) { \ + GELOGE(GRAPH_FAILED, "[ERROR]Parameter[%s] must not be null.", #val); \ + return false; \ + } \ + } while (0) + +#define CHECK(cond, log_func, return_expr) \ + do { \ + if (cond) { \ + log_func; \ + return_expr; \ + } \ + } while (0) + +enum AxisValueType { + AXIS_N = 0, + AXIS_C = 1, + AXIS_H = 2, + AXIS_W = 3, + AXIS_C1 = 4, + AXIS_C0 = 5, + AXIS_Co = 6, + AXIS_D = 7, + AXIS_BOTTOM = 8 +}; + +int64_t DivisionCeiling(int64_t dividend, int64_t divisor); + +/* Axis value is arranged as {N,C,H,W,C1,C0,...} */ +/* The first parameter is old shape's dimension, + * second is c0 and third is axis value. */ +using GetAxisValueInfoByFormat = + std::function&, const uint32_t&, std::vector&, std::vector&)>; + +using GetAxisValueInfoByFormatPtr = std::shared_ptr; + +class AxisUtil { + public: + AxisUtil(); + ~AxisUtil(){}; + bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector& dimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + bool HasAxisValueFunc(const ge::Format& format); + + private: + static bool CheckParams(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByNCHW(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByNHWC(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByNC1HWC0(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByFz(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByHWCN(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByND(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByC1HWNCoC0(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByNDHWC(const std::vector& original_dim_vec, const uint32_t& c0, + std::vector& axis_value, std::vector& nd_value); + + static bool GetAxisValueByNCDHW(const std::vector& original_dim_vec, const uint32_t& c0, + std::vector& axis_value, std::vector& nd_value); + + static bool GetAxisValueByDHWCN(const std::vector& original_dim_vec, const uint32_t& c0, + std::vector& axis_value, std::vector& nd_value); + + static bool GetAxisValueByDHWNC(const std::vector& original_dim_vec, const uint32_t& c0, + std::vector& axis_value, std::vector& nd_value); + /* map of GetAxisValueInfoByFormat, get axis value by different original + * formats. */ + std::map getAxisValueFuncMap; +}; +} // namespace transformer +} // namespace common + +#endif // COMMON_UTILS_TRANSFER_AXIS_UTIL_H_ diff --git a/metadef/third_party/transformer/inc/transfer_shape_according_to_format.h b/metadef/third_party/transformer/inc/transfer_shape_according_to_format.h new file mode 100644 index 00000000..0e5657de --- /dev/null +++ b/metadef/third_party/transformer/inc/transfer_shape_according_to_format.h @@ -0,0 +1,125 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file transfer_shape_according_to_format.h + * \brief set shape according to original format and current format + */ +#ifndef COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ +#define COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ + +#include "transformer/inc/axis_util.h" + +#include +#include +#include + +#include "graph/types.h" +#include "graph/utils/op_desc_utils.h" + +namespace common { +namespace transformer { + +enum OpImplType { + EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op + EN_IMPL_CUSTOM_TIK, // custom tik op + EN_IMPL_CUSTOM_TBE, // custom tbe op + EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op + EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op + EN_IMPL_HW_TIK, // Huawei built-in tik op + EN_IMPL_HW_TBE, // Huawei built-in tbe op + EN_IMPL_RL, // RL op + EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op + EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op + EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op + EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op + EN_RESERVED // reserved value +}; + +const uint32_t SHAPE_NUMBER_16 = 16; +const uint32_t SHAPE_NUMBER_32 = 32; +const uint32_t SHAPE_DIM_VALUE_C04 = 4; +const uint32_t NI = 16; +const uint32_t MINUS_VALUE_ONE = 1; +const uint32_t MINUS_VALUE_TWO = 2; +const uint32_t SIZE_OF_CN = 2; +const uint32_t MINIMUM_NZ_SHAPE_DIM_NUM = 2; + +/* The first parameter is axis value, second is new shape and third is + * op implementation type. */ +using GetNewShapeByAxisValueAndFormat = + std::function&, const int64_t&, vector&, vector&)>; + +using GetNewShapeByAxisValueAndFormatPtr = std::shared_ptr; + +struct ShapeAndFormatInfo { + const std::vector &oldShape; + std::vector &newShape; + const ge::Format &oldFormat; + const ge::Format &newFormat; + const ge::DataType ¤tDataType; + const int64_t &opImplType; +}; + +using ShapeAndFormat = struct ShapeAndFormatInfo; + +class ShapeTransferAccordingToFormat { + public: + ShapeTransferAccordingToFormat(); + + ~ShapeTransferAccordingToFormat(){}; + + ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat&) = delete; + + ShapeTransferAccordingToFormat &operator=(const ShapeTransferAccordingToFormat&) = delete; + + bool GetShapeAccordingToFormat(ShapeAndFormat &inputAndOutputInfo, int64_t* c = nullptr); + + /* ----------Below is the function of getting new shape---------------------- */ + static bool GetNDC1HWC0ShapeByAxisValue(vector &new_shape, const int64_t &impl_type, + const std::vector &axis_value, const vector &nd_value); + + static bool GetNCHWShapeByAxisValue(vector &newShape, const int64_t &implType, + const vector &axisValue, const vector &ndValue); + + static bool GetNHWCShapeByAxisValue(vector &newShape, const int64_t &implType, + const vector &axisValue, const vector &ndValue); + + static bool GetNC1HWC0ShapeByAxisValue(vector &newShape, const int64_t &implType, + const vector &axisValue, const vector &ndValue); + + static bool GetFzShapeByAxisValue(vector &newShape, const int64_t &implType, + const vector &axisValue, const vector &ndValue); + + static bool GetHWCNShapeByAxisValue(vector &newShape, const int64_t &implType, + const vector &axisValue, const vector &ndValue); + + static bool GetC1HWNCoC0ShapeByAxisValue(vector &newShape, const int64_t &implType, + const vector &axisValue, const vector &ndValue); + + static bool GetNzShapeByAxisValue(vector &newShape, const int64_t &implType, + const vector &axisValue, const vector &ndValue); + + private: + /* map of GetAxisValueInfoByFormat, get axis value by different original + * formats. */ + std::map getNewShapeFuncMap; + std::map mapOfDtypeAndC0; +}; +} // namespace transformer +} // namespace common + +#endif // COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ diff --git a/metadef/third_party/transformer/src/axis_util.cpp b/metadef/third_party/transformer/src/axis_util.cpp new file mode 100644 index 00000000..706b12e2 --- /dev/null +++ b/metadef/third_party/transformer/src/axis_util.cpp @@ -0,0 +1,289 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file axis_util.cpp + * \brief get the axis value + */ +#include "transformer/inc/axis_util.h" +#include "graph/types.h" + +namespace common { +namespace transformer { +using namespace ge; +using namespace std; + +AxisUtil::AxisUtil() { + getAxisValueFuncMap = {{FORMAT_NCHW, std::make_shared(GetAxisValueByNCHW)}, + {FORMAT_NHWC, std::make_shared(GetAxisValueByNHWC)}, + {FORMAT_NC1HWC0, std::make_shared(GetAxisValueByNC1HWC0)}, + {FORMAT_HWCN, std::make_shared(GetAxisValueByHWCN)}, + {FORMAT_ND, std::make_shared(GetAxisValueByND)}, + {FORMAT_C1HWNCoC0, std::make_shared(GetAxisValueByC1HWNCoC0)}, + {FORMAT_NDHWC, std::make_shared(GetAxisValueByNDHWC)}, + {FORMAT_NCDHW, std::make_shared(GetAxisValueByNCDHW)}, + {FORMAT_DHWCN, std::make_shared(GetAxisValueByDHWCN)}, + {FORMAT_DHWNC, std::make_shared(GetAxisValueByDHWNC)}}; +} + +int64_t DivisionCeiling(int64_t dividend, int64_t divisor) { + if (divisor == 0) { + return 0; + } else { + return (dividend + divisor - 1) / divisor; + } +} + +bool AxisUtil::GetAxisValueByOriginFormat(const Format &format, const vector &dimVec, const uint32_t &c0, + vector &axisValue, vector &ndValue) { + auto iterGetAxisFunc = getAxisValueFuncMap.find(format); + if (iterGetAxisFunc == getAxisValueFuncMap.end()) { + GELOGI("Can not get axis value of old format %u!", format); + return false; + } + GetAxisValueInfoByFormatPtr getAxisFunc = iterGetAxisFunc->second; + CHECK_NOTNULL(getAxisFunc); + return (*getAxisFunc)(dimVec, c0, axisValue, ndValue); +} + +bool AxisUtil::HasAxisValueFunc(const Format &format) { + auto iterGetAxisFunc = getAxisValueFuncMap.find(format); + if (iterGetAxisFunc == getAxisValueFuncMap.end()) { + GELOGI("Can not get axis value of format %u!", format); + return false; + } + return true; +} + +bool AxisUtil::CheckParams(const vector &originalDimVec, const uint32_t &c0, vector &axisValue, + vector &ndValue) { + ndValue = originalDimVec; + auto dimSize = originalDimVec.size(); + if (dimSize < DIM_DEFAULT_SIZE) { + /* Before this funcion, we should call function PadDimensionTo4. */ + GELOGI("Dimension size %zu is invalid.", dimSize); + return false; + } + if (c0 == 0) { + GELOGE(GRAPH_FAILED, "[ERROR]c0 is zero!"); + return false; + } + + return true; +} + +bool AxisUtil::GetAxisValueByND(const vector &originalDimVec, const uint32_t &c0, vector &axisValue, + vector &ndValue) { + CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); + ndValue = originalDimVec; + /* To differentiate the input datatype of int8 and others */ + axisValue[AXIS_C0] = c0; + if (originalDimVec.size() == NCHW_DIMENSION_NUM) { + axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; + axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; + axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; + axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); + axisValue[AXIS_Co] = c0; + } + return true; +} + +bool AxisUtil::GetAxisValueByNCHW(const vector &originalDimVec, const uint32_t &c0, vector &axisValue, + vector &ndValue) { + CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); + /* C0 Must be set for case ND or 2D-NCHW to NZ */ + axisValue[AXIS_C0] = c0; + // TODO: temporarily modified to warning level.If modified normally, it needs complementary dimension for origin shape + CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGW("[WARNING]Parameter is invalid!"), + return false); + + axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; + axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; + axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; + axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); + axisValue[AXIS_Co] = c0; + return true; +} + +bool AxisUtil::GetAxisValueByNHWC(const vector &originalDimVec, const uint32_t &c0, vector &axisValue, + vector &ndValue) { + CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); + /* C0 Must be set for case ND or 2D-NHWC to NZ */ + axisValue[AXIS_C0] = c0; + // TODO: temporarily modified to warning level.If modified normally, it needs complementary dimension for origin shape + CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGW("[WARNING]Parameter is invalid!"), + return false); + + axisValue[AXIS_N] = originalDimVec[AXIS_NHWC_DIM_N]; + axisValue[AXIS_C] = originalDimVec[AXIS_NHWC_DIM_C]; + axisValue[AXIS_H] = originalDimVec[AXIS_NHWC_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_NHWC_DIM_W]; + axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NHWC_DIM_C], (int64_t)c0); + axisValue[AXIS_Co] = c0; + return true; +} + +bool AxisUtil::GetAxisValueByNC1HWC0(const vector &originalDimVec, const uint32_t &c0, + vector &axisValue, vector &ndValue) { + CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); + CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED,"[ERROR]Parameter is invalid!"), + return false); + + auto dimSize = originalDimVec.size(); + if (dimSize == DIM_DEFAULT_SIZE + 1) { + axisValue[AXIS_C1] = originalDimVec[AXIS_NC1HWC0_DIM_C1]; + axisValue[AXIS_C0] = originalDimVec[AXIS_NC1HWC0_DIM_C0]; + axisValue[AXIS_C] = axisValue[AXIS_C1] * axisValue[AXIS_C0]; + } else { + axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); + axisValue[AXIS_C0] = c0; + axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; + } + + axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; + axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; + return true; +} + +bool AxisUtil::GetAxisValueByHWCN(const vector &originalDimVec, const uint32_t &c0, vector &axisValue, + vector &ndValue) { + CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); + /* C0 Must be set for case ND or 2D-NHWC to NZ */ + axisValue[AXIS_C0] = c0; + // TODO: temporarily modified to warning level. If modified normally, it needs complementary dimension for origin shape + CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGW("[WARNING]Parameter is invalid!"), + return false); + + axisValue[AXIS_N] = originalDimVec[AXIS_HWCN_DIM_N]; + axisValue[AXIS_C] = originalDimVec[AXIS_HWCN_DIM_C]; + axisValue[AXIS_H] = originalDimVec[AXIS_HWCN_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_HWCN_DIM_W]; + axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_HWCN_DIM_C], (int64_t)c0); + axisValue[AXIS_Co] = c0; + return true; +} + +bool AxisUtil::GetAxisValueByC1HWNCoC0(const vector &originalDimVec, const uint32_t &c0, + vector &axisValue, vector &ndValue) { + CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); + /* C0 Must be set for case ND or 2D-NHWC to NZ */ + axisValue[AXIS_C0] = c0; + CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED, "[ERROR]Parameter is invalid!"), + return false); + + axisValue[AXIS_N] = originalDimVec[AXIS_C1HWNCoC0_DIM_N]; + axisValue[AXIS_C] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1] * c0; + axisValue[AXIS_H] = originalDimVec[AXIS_C1HWNCoC0_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_C1HWNCoC0_DIM_W]; + axisValue[AXIS_C1] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1]; + axisValue[AXIS_Co] = originalDimVec[AXIS_C1HWNCoC0_DIM_Co]; + return true; +} + +bool AxisUtil::GetAxisValueByNDHWC(const std::vector& original_dim_vec, const uint32_t& c0, + std::vector& axis_value, std::vector& nd_value) { + CHECK(axis_value.empty(), GELOGI("AxisValue is empty!"), return true); + CHECK(original_dim_vec.empty(), GELOGI("Original dim vector is empty!"), return true); + + axis_value[AXIS_C0] = c0; + nd_value = original_dim_vec; + + axis_value[AXIS_N] = original_dim_vec[NDHWC_DIM_N]; + int64_t axis_c_val = original_dim_vec[NDHWC_DIM_C]; + + axis_value[AXIS_C] = axis_c_val; + axis_value[AXIS_H] = original_dim_vec[NDHWC_DIM_H]; + axis_value[AXIS_W] = original_dim_vec[NDHWC_DIM_W]; + axis_value[AXIS_C1] = DivisionCeiling(axis_c_val, c0); + axis_value[AXIS_C0] = c0; + axis_value[AXIS_Co] = c0; + axis_value[AXIS_D] = original_dim_vec[NDHWC_DIM_D]; + return true; +} + +bool AxisUtil::GetAxisValueByNCDHW(const std::vector& original_dim_vec, const uint32_t& c0, + std::vector& axis_value, std::vector& nd_value) { + CHECK(axis_value.empty(), GELOGI("AxisValue is empty!"), return true); + CHECK(original_dim_vec.empty(), GELOGI("Original dim vector is empty!"), return true); + + axis_value[AXIS_C0] = c0; + nd_value = original_dim_vec; + + axis_value[AXIS_N] = original_dim_vec[NCDHW_DIM_N]; + int64_t axis_c_val = original_dim_vec[NCDHW_DIM_C]; + + axis_value[AXIS_C] = axis_c_val; + axis_value[AXIS_H] = original_dim_vec[NCDHW_DIM_H]; + axis_value[AXIS_W] = original_dim_vec[NCDHW_DIM_W]; + axis_value[AXIS_C1] = DivisionCeiling(axis_c_val, c0); + axis_value[AXIS_C0] = c0; + axis_value[AXIS_Co] = c0; + axis_value[AXIS_D] = original_dim_vec[NCDHW_DIM_D]; + return true; +} + +bool AxisUtil::GetAxisValueByDHWCN(const std::vector& original_dim_vec, const uint32_t& c0, + std::vector& axis_value, std::vector& nd_value) { + CHECK(axis_value.empty(), GELOGI("AxisValue is empty!"), return true); + CHECK(original_dim_vec.empty(), GELOGI("Original dim vector is empty!"), return true); + + axis_value[AXIS_C0] = c0; + nd_value = original_dim_vec; + + axis_value[AXIS_N] = original_dim_vec[DHWCN_DIM_N]; + int64_t axis_c_val = original_dim_vec[DHWCN_DIM_C]; + + axis_value[AXIS_C] = axis_c_val; + axis_value[AXIS_H] = original_dim_vec[DHWCN_DIM_H]; + axis_value[AXIS_W] = original_dim_vec[DHWCN_DIM_W]; + axis_value[AXIS_C1] = DivisionCeiling(axis_c_val, c0); + axis_value[AXIS_C0] = c0; + axis_value[AXIS_Co] = c0; + axis_value[AXIS_D] = original_dim_vec[DHWCN_DIM_D]; + return true; +} + +bool AxisUtil::GetAxisValueByDHWNC(const std::vector& original_dim_vec, const uint32_t& c0, + std::vector& axis_value, std::vector& nd_value) { + CHECK(axis_value.empty(), GELOGI("AxisValue is empty!"), return true); + CHECK(original_dim_vec.empty(), GELOGI("Original dim vector is empty!"), return true); + + axis_value[AXIS_C0] = c0; + nd_value = original_dim_vec; + + axis_value[AXIS_N] = original_dim_vec[DHWNC_DIM_N]; + int64_t axis_c_val = original_dim_vec[DHWNC_DIM_C]; + + axis_value[AXIS_C] = axis_c_val; + axis_value[AXIS_H] = original_dim_vec[DHWNC_DIM_H]; + axis_value[AXIS_W] = original_dim_vec[DHWNC_DIM_W]; + axis_value[AXIS_C1] = DivisionCeiling(axis_c_val, c0); + axis_value[AXIS_C0] = c0; + axis_value[AXIS_Co] = c0; + axis_value[AXIS_D] = original_dim_vec[DHWNC_DIM_D]; + return true; +} +} // namespace transformer +} // namespace common diff --git a/metadef/third_party/transformer/src/transfer_shape_according_to_format.cpp b/metadef/third_party/transformer/src/transfer_shape_according_to_format.cpp new file mode 100644 index 00000000..0c3b8cee --- /dev/null +++ b/metadef/third_party/transformer/src/transfer_shape_according_to_format.cpp @@ -0,0 +1,255 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file transfer_shape_according_to_format.cpp + * \brief set shape according to original format and current format + */ +#include "transformer/inc/transfer_shape_according_to_format.h" + +namespace common { +namespace transformer { +using namespace ge; +using namespace std; + +ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat(void) { + getNewShapeFuncMap = { + {ge::FORMAT_NCHW, std::make_shared(GetNCHWShapeByAxisValue)}, + {ge::FORMAT_NHWC, std::make_shared(GetNHWCShapeByAxisValue)}, + {ge::FORMAT_NC1HWC0, std::make_shared(GetNC1HWC0ShapeByAxisValue)}, + {ge::FORMAT_NDC1HWC0, std::make_shared(GetNDC1HWC0ShapeByAxisValue)}, + {ge::FORMAT_FRACTAL_Z, std::make_shared(GetFzShapeByAxisValue)}, + {ge::FORMAT_HWCN, std::make_shared(GetHWCNShapeByAxisValue)}, + {ge::FORMAT_C1HWNCoC0, std::make_shared(GetC1HWNCoC0ShapeByAxisValue)}, + {ge::FORMAT_FRACTAL_NZ, std::make_shared(GetNzShapeByAxisValue)}}; + + mapOfDtypeAndC0 = { + {ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32}, + {ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16}, + {ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16}, + {ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}}; +} + +bool ShapeTransferAccordingToFormat::GetNDC1HWC0ShapeByAxisValue(vector &new_shape, const int64_t &impl_type, + const std::vector &axis_value, const vector &nd_value) { + CHECK(axis_value.empty(), GELOGD("AxisValue is empty!"), return true); + new_shape.push_back(axis_value[AXIS_N]); + new_shape.push_back(axis_value[AXIS_D]); + new_shape.push_back(axis_value[AXIS_C1]); + new_shape.push_back(axis_value[AXIS_H]); + new_shape.push_back(axis_value[AXIS_W]); + new_shape.push_back(axis_value[AXIS_C0]); + return true; +} + +bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(vector& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); + /* axisValue is initialized as a size 6 vector. */ + newShape.push_back(axisValue[AXIS_N]); + newShape.push_back(axisValue[AXIS_C]); + newShape.push_back(axisValue[AXIS_H]); + newShape.push_back(axisValue[AXIS_W]); + return true; +} + +bool ShapeTransferAccordingToFormat::GetNHWCShapeByAxisValue(vector& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); + /* axisValue is initialized as a size 6 vector. */ + newShape.push_back(axisValue[AXIS_N]); + newShape.push_back(axisValue[AXIS_H]); + newShape.push_back(axisValue[AXIS_W]); + newShape.push_back(axisValue[AXIS_C]); + return true; +} + +bool ShapeTransferAccordingToFormat::GetNC1HWC0ShapeByAxisValue(vector& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); + /* axisValue is initialized as a size 6 vector. */ + if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) { + newShape.push_back(axisValue[AXIS_N]); + newShape.push_back(axisValue[AXIS_C1]); + newShape.push_back(axisValue[AXIS_H]); + newShape.push_back(axisValue[AXIS_W]); + newShape.push_back(axisValue[AXIS_C0]); + } else { + newShape.push_back(axisValue[AXIS_N]); + newShape.push_back(axisValue[AXIS_C]); + newShape.push_back(axisValue[AXIS_H]); + newShape.push_back(axisValue[AXIS_W]); + } + return true; +} + +bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(vector& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); + /* axisValue is initialized as a size 6 vector. */ + if (ndValue.size() == SIZE_OF_CN) { + auto sizeOfOriginalVec = ndValue.size(); + newShape = ndValue; + /* sizeOfOriginalVec - 1 mean the last value of original vec + * sizeOfOriginalVec - 2 mean the second last value of original vec */ + newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] = + DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16); + newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] = + DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]); + newShape.push_back(SHAPE_NUMBER_16); + newShape.push_back(axisValue[AXIS_C0]); + } else { + if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) { + int64_t hwc1 = axisValue[AXIS_C1] * axisValue[AXIS_H] * axisValue[AXIS_W]; + newShape.push_back(hwc1); + newShape.push_back(DivisionCeiling(axisValue[AXIS_N], NI)); + newShape.push_back(NI); + newShape.push_back(axisValue[AXIS_C0]); + } else { + newShape.push_back(axisValue[AXIS_N]); + newShape.push_back(axisValue[AXIS_C]); + newShape.push_back(axisValue[AXIS_H]); + newShape.push_back(axisValue[AXIS_W]); + } + } + + return true; +} + +bool ShapeTransferAccordingToFormat::GetHWCNShapeByAxisValue(vector& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); + /* axisValue is initialized as a size 6 vector. */ + newShape.push_back(axisValue[AXIS_H]); + newShape.push_back(axisValue[AXIS_W]); + newShape.push_back(axisValue[AXIS_C]); + newShape.push_back(axisValue[AXIS_N]); + return true; +} + +bool ShapeTransferAccordingToFormat::GetC1HWNCoC0ShapeByAxisValue(vector& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); + /* axisValue is initialized as a size 6 vector. */ + newShape.push_back(axisValue[AXIS_C1]); + newShape.push_back(axisValue[AXIS_H]); + newShape.push_back(axisValue[AXIS_W]); + newShape.push_back(axisValue[AXIS_N]); + newShape.push_back(axisValue[AXIS_Co]); + newShape.push_back(axisValue[AXIS_C0]); + return true; +} + +bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(vector& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(ndValue.empty(), GELOGD("ndValue is empty!"), return true); + CHECK(axisValue.empty() || axisValue.size() <= AXIS_C0, + GELOGD("AxisValue is empty or its size %zu <= AXIS_C0[%u]", axisValue.size(), AXIS_C0), return true); + uint32_t sizeOfOriginalVec = ndValue.size(); + if (sizeOfOriginalVec < MINIMUM_NZ_SHAPE_DIM_NUM) { + GELOGD("ndValue's dim num is less than 2!"); + return true; + } + /* axisValue is initialized as a size 6 vector. */ + newShape = ndValue; + + /* sizeOfOriginalVec - 1 mean the last value of original vec + * sizeOfOriginalVec - 2 mean the second last value of original vec */ + newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] = + DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16); + + newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] = + DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]); + newShape.push_back(SHAPE_NUMBER_16); + newShape.push_back(axisValue[AXIS_C0]); + return true; +} + +bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& shapeAndFormatInfo, int64_t* c) { + /* The default new shape is old shape */ + shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape; + if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) { + GELOGE(GRAPH_FAILED, "Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat, + shapeAndFormatInfo.newFormat); + return false; + } + + if (shapeAndFormatInfo.currentDataType >= ge::DT_UNDEFINED) { + GELOGE(GRAPH_FAILED, "currentDataType %u is invalid!", shapeAndFormatInfo.currentDataType); + return false; + } + AxisUtil* axisutil_object = new AxisUtil(); + if (!axisutil_object->HasAxisValueFunc(shapeAndFormatInfo.oldFormat)) { + delete axisutil_object; + return true; + } + + auto iterGetNewShapeFunc = getNewShapeFuncMap.find(shapeAndFormatInfo.newFormat); + if (iterGetNewShapeFunc == getNewShapeFuncMap.end()) { + GELOGD("Can not get new shape of new format %u!", shapeAndFormatInfo.newFormat); + delete axisutil_object; + return true; + } + GELOGD("Original format %u, new format %u", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat); + GetNewShapeByAxisValueAndFormatPtr getNewShapeFunc = iterGetNewShapeFunc->second; + CHECK_NOTNULL(getNewShapeFunc); + std::vector axisValue; + for (uint32_t i = 0; i < AXIS_BOTTOM; i++) { + axisValue.push_back(1); + } + std::vector ndValue; + uint32_t c0; + if (mapOfDtypeAndC0.empty()) { + c0 = SHAPE_NUMBER_16; + } else { + auto iterGetC0 = mapOfDtypeAndC0.find(shapeAndFormatInfo.currentDataType); + if (iterGetC0 == mapOfDtypeAndC0.end()) { + GELOGE(GRAPH_FAILED, "Dtype is not support."); + delete axisutil_object; + return true; + } + c0 = iterGetC0->second; + } + + // The value of C0 should be 4 while format is 5HD-4 or FRAZ-4 + if (shapeAndFormatInfo.newFormat == ge::FORMAT_NC1HWC0_C04) { + c0 = SHAPE_DIM_VALUE_C04; + } + + bool status = axisutil_object->GetAxisValueByOriginFormat( + shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape, c0, axisValue, ndValue); + if (status != true && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) { + delete axisutil_object; + return true; + } + delete axisutil_object; + + shapeAndFormatInfo.newShape.clear(); + (*getNewShapeFunc)(shapeAndFormatInfo.newShape, shapeAndFormatInfo.opImplType, axisValue, ndValue); + if (c != nullptr) { + *c = axisValue[AXIS_C]; + } + return true; +} +} // namespace transformer +} // namespace common diff --git a/parser b/parser deleted file mode 160000 index 9e392045..00000000 --- a/parser +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9e392045c26a57913b512d0686e1285650b62abe diff --git a/tests/depends/error_manager/src/error_manager_stub.cc b/tests/depends/error_manager/src/error_manager_stub.cc index 4f6b6b3d..edf5a487 100644 --- a/tests/depends/error_manager/src/error_manager_stub.cc +++ b/tests/depends/error_manager/src/error_manager_stub.cc @@ -58,7 +58,7 @@ /// @param [in] value: vector parameter value /// void ErrorManager::ATCReportErrMessage(std::string error_code, const std::vector &key, - const std::vector &value) { + const std::vector &value) { } /// diff --git a/tests/depends/hccl/src/hccl_stub.cc b/tests/depends/hccl/src/hccl_stub.cc index 1cc8fdb3..b9b9d4f6 100644 --- a/tests/depends/hccl/src/hccl_stub.cc +++ b/tests/depends/hccl/src/hccl_stub.cc @@ -19,26 +19,26 @@ #include "hccl/hcom.h" HcclResult hcom_all_gather(const char *tag, void *input_count_ptr, void *output_ptr, u64 input_count, - HcclDataType data_type, const char *group, rtStream_t stream) { + HcclDataType data_type, const char *group, rtStream_t stream) { return HCCL_SUCCESS; } HcclResult hcom_broadcast(const char *tag, void *ptr, u64 count, HcclDataType data_type, u32 root, - const char *group, rtStream_t stream) { + const char *group, rtStream_t stream) { return HCCL_SUCCESS; } HcclResult hcom_all_reduce(const char *tag, void *input_ptr, void *output_ptr, u64 count, HcclDataType data_type, - HcclReduceOp op, const char *group, rtStream_t stream) { + HcclReduceOp op, const char *group, rtStream_t stream) { return HCCL_SUCCESS; } HcclResult hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 max_segment_num, - u32 *segment_num, u32 *segment_idx) { + u32 *segment_num, u32 *segment_idx) { return HCCL_SUCCESS; } HcclResult hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_ptr, u64 count, - HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) { + HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) { return HCCL_SUCCESS; } diff --git a/tests/depends/runtime/src/runtime_stub.cc b/tests/depends/runtime/src/runtime_stub.cc index 2ab6684d..75eefdd1 100644 --- a/tests/depends/runtime/src/runtime_stub.cc +++ b/tests/depends/runtime/src/runtime_stub.cc @@ -325,7 +325,7 @@ rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) } rtError_t rtMallocHostSharedMemory(rtMallocHostSharedMemoryIn *in, - rtMallocHostSharedMemoryOut *out) + rtMallocHostSharedMemoryOut *out) { out->ptr = new uint8_t[in->size]; out->devPtr = new uint8_t[in->size]; diff --git a/tests/st/CMakeLists.txt b/tests/st/CMakeLists.txt deleted file mode 100644 index 56babec1..00000000 --- a/tests/st/CMakeLists.txt +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -cmake_minimum_required(VERSION 3.0) -set(CMAKE_CXX_STANDARD 11) -project(ge_st CXX C) - -set(CMAKE_CXX_FLAGS "-O1 -fPIC -Wl,-unresolved-symbols=ignore-in-shared-libs") - - -file(GLOB_RECURSE RES50_TRAIN_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "resnet50/resnet50_train.cc" - "resnet50/common.cc" -) - -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/graph) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/inc/external/ge) -include_directories(${GE_SOURCE_DIR}/inc/external/graph) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) -include_directories(/usr/local/HiAI/opp/op_proto/built-in/inc) - -add_executable(st_resnet50_train ${RES50_TRAIN_SRCS}) -target_link_libraries(st_resnet50_train - ${PROTOBUF_LIBRARY} - ge_client_train ge_memory -) \ No newline at end of file diff --git a/tests/st/resnet50/common.cc b/tests/st/resnet50/common.cc deleted file mode 100644 index 674ef926..00000000 --- a/tests/st/resnet50/common.cc +++ /dev/null @@ -1,768 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include - -#include "common.h" -#include "model.h" - -#define MAX_HEAD_SIZE 50 - -using namespace std; -using namespace ge; - -void update_op_format(Operator ops, Format format) { - printf("set format begin.........\n"); - ge::TensorDesc tensor_desc_x = ops.GetInputDesc("x"); - ge::TensorDesc tensor_desc_y = ops.GetOutputDesc("y"); - Format f_x0 = tensor_desc_x.GetFormat(); - Format f_y0 = tensor_desc_x.GetFormat(); - printf("before set x format:%d \n", f_x0); - printf("before set y format:%d \n", f_y0); - printf("format to be set is :%d \n", format); - tensor_desc_x.SetFormat(format); - tensor_desc_y.SetFormat(format); - ops.UpdateInputDesc("x", tensor_desc_x); - ops.UpdateOutputDesc("y", tensor_desc_y); - Format f_x = tensor_desc_x.GetFormat(); - Format f_y = tensor_desc_y.GetFormat(); - printf("after set x format:%d \n", f_x); - printf("after set y format:%d \n", f_y); -} - -/// getDimInfo: get dim info from data file -/// param: -/// fp: the testing datafile object -/// -/// return : -/// dim_info: array to store the info of the dim in datafile, like [4,3,3,6,3,162(3*3*6*3)],4 is dim size,3,3,6,3 is the -/// dim shape data_size: the size of the testing data including the data file -void getDimInfo(FILE *fp, std::vector &dim_info) { - // get dim info from hisi testing data file - uint32_t *dim_buffer = (uint32_t *)malloc(MAX_HEAD_SIZE * sizeof(uint32_t)); - fread(dim_buffer, sizeof(uint32_t), MAX_HEAD_SIZE, fp); - dim_info.push_back(*dim_buffer); // get dim size - - // get data shape to compute the datasize - uint64_t data_size = 1; - uint32_t i = 1; - for (; i <= dim_info[0]; i++) { - dim_info.push_back(*(dim_buffer + i)); - data_size *= *(dim_buffer + i); - } - dim_info.push_back(data_size); - - free(dim_buffer); -} - -/// readTestDataFile: read test date from hisi .t datafile -/// param: -/// infile: the path of hisi .t datafile -/// return: -/// dim_info: array to store the info of the dim in datafile, like [4,3,3,6,3],4 is dim size,3,3,6,3 is the dim shape -void *readTestDataFile(std::string infile, std::vector &dim_info) { - FILE *fp; - fp = fopen(infile.c_str(), "r"); - - if (fp == NULL) { - printf("ERROR: cant't open file %s\n", infile.c_str()); - return NULL; - } else { - getDimInfo(fp, dim_info); - uint64_t data_size = dim_info[dim_info.size() - 1]; - - fclose(fp); - - fp = fopen(infile.c_str(), "r"); - if (fp == NULL) { - printf("ERROR: cant't open file %s\n", infile.c_str()); - return NULL; - } - uint32_t *memory = (uint32_t *)malloc((dim_info[0] + 1 + data_size) * sizeof(uint32_t)); - fread(memory, sizeof(uint32_t), (dim_info[0] + 1 + data_size), fp); - fclose(fp); - return memory + (dim_info[0] + 1); - } -} - -void *readUint8TestDataFile(std::string infile, int size) { - FILE *fp; - fp = fopen(infile.c_str(), "r"); - - if (fp == NULL) { - printf("ERROR: cant't open file %s\n", infile.c_str()); - return NULL; - } - uint8_t *memory = (uint8_t *)malloc((size) * sizeof(uint8_t)); - fread(memory, sizeof(uint8_t), (size), fp); - fclose(fp); - return memory; -} - -/// allclose -/// param: -/// a:compared file a -/// b:compared file b -/// count: the count size which will compare -/// rtol: -/// atol: -/// return: -/// true or false -bool allclose(float *a, float *b, uint64_t count, float rtol = 1e-05, float atol = 1e-08) { - uint32_t i = 0; - - for (; i < count; ++i) { - if (fabs(a[i] - b[i]) > (atol + rtol * fabs(b[i]))) { - printf("compara failed: i= %d, a[i]=%f, b[i]=%f,atol=%f,rtol=%f\n", i, a[i], b[i], atol, rtol); - return false; - } - } - - return true; -} - -/// compFp32WithTData: compare the data with the data in hisi .t file -/// param: -/// actual_output_data: the result of ge -/// expected_data_file: the path of hisi .t result file -/// rtol: -/// atol: -/// return: -/// true of false -bool compFp32WithTData(float *actual_output_data, std::string expected_data_file, float rtol = 1e-05, float atol = 1e-08) { - std::vector dim_info; - float *expected_output_data = (float *)readTestDataFile(expected_data_file, dim_info); - - uint32_t i = 1; - uint64_t data_size = 1; - for (; i <= dim_info[0]; i++) { - data_size *= dim_info[i]; - } - return allclose(actual_output_data, expected_output_data, data_size, rtol, atol); -} - -int SwitchDatatype(DataType dt) { - int size = 1; - if (dt == ge::DT_FLOAT) size = 4; - if (dt == ge::DT_INT32) size = 4; - if (dt == ge::DT_FLOAT16) size = 2; - if (dt == ge::DT_INT64) size = 8; - return size; -} - -ge::Tensor genTensor(std::vector tensor_shape, Format format, DataType dt) { - int size = 1; - for (int i = 0; i < tensor_shape.size(); i++) { - size = size * tensor_shape[i]; - } - - int data_type_size = SwitchDatatype(dt); - - size = abs(size * data_type_size); - vector data_value; - - if (size == 0) { - TensorDesc input_tensor_desc = TensorDesc(ge::Shape(tensor_shape), format, dt); - input_tensor_desc.SetRealDimCnt(tensor_shape.size()); - Tensor gen_tensor = Tensor(input_tensor_desc, data_value); - return gen_tensor; - } - for (int i = 0; i < size; i++) { - data_value.push_back(1); - } - TensorDesc input_tensor_desc = TensorDesc(ge::Shape(tensor_shape), format, dt); - input_tensor_desc.SetRealDimCnt(tensor_shape.size()); - Tensor gen_tensor = Tensor(input_tensor_desc, data_value); - return gen_tensor; -} - -ge::Tensor genTensor_withVaule(std::vector tensor_shape, float value) { - int size = 1; - for (int i = 0; i < tensor_shape.size(); i++) { - size = size * tensor_shape[i]; - } - - float *data_value = new float[size]; - for (int i = 0; i < size; i++) { - *(data_value + i) = value; - } - Tensor gen_ge_tensor; - TensorDesc input_tensor_desc = TensorDesc(ge::Shape(tensor_shape), FORMAT_NCHW); - gen_ge_tensor.SetTensorDesc(input_tensor_desc); - gen_ge_tensor.SetData((uint8_t *)data_value, size * 4); - - return gen_ge_tensor; -} - -Tensor genTesnor_Shape_as_data(std::vector tensor_shape) { - Format format = FORMAT_NCHW; - DataType dt = DT_INT32; - int size = tensor_shape.size(); - int32_t *tensor_data = new int32_t[size]; - std::cout << "shape tensor size:" << size << endl; - for (int i = 0; i < size; i++) { - *(tensor_data + i) = tensor_shape[i]; - } - - Tensor gen_tensor; - TensorDesc input_tensor_desc = TensorDesc(ge::Shape({size}), FORMAT_NCHW, DT_INT32); - gen_tensor.SetData((uint8_t *)tensor_data, size * GetDatTypeSize(dt)); - gen_tensor.SetTensorDesc(input_tensor_desc); - - return gen_tensor; -} - -/// train_flag is 0 when infer; train_flag is 1 when train; train_flag is 0 default -/// run_mode_path is not 0,1,2 when TBE; run_mode_path is 1 when FE; run_mode_path is 0 default -/// run_mode_path is 2 now when AICPU, ge.enabledlocalFmkop is 1 -ge::Status GEInitialize_api(string train_flag, string run_mode_path) { - ge::Status ret; - if (run_mode_path == "0") { - const std::map config = { - {"device_id", "0,2,4,6"}, - {"rank_table_file", "hccl from csa/paas"}, - {"ge.graphRunMode", train_flag}, - {"ge.aicpuFlag", "1"}, - {"ge.feFlag", "1"}, - {DDK_VERSION_FLAG, "1.60.T17.B830"}, - {"ge.soLoadPath", - "/usr/local/HiAI/runtime/lib64/plugin/opskernel/libfe.so:/usr/local/HiAI/runtime/lib64/plugin/opskernel/" - "libaicpu_plugin.so"}}; - ret = ge::GEInitialize(config); - } else if (run_mode_path == "1") { - const std::map config = { - {"device_id", "0,2,4,6"}, - {"rank_table_file", "hccl from csa/paas"}, - {"ge.graphRunMode", train_flag}, - {"ge.feFlag", "1"}, - {DDK_VERSION_FLAG, "1.60.T17.B830"}, - {TBE_PLUGIN_PATH_FLAG, "/usr/local/HiAI/runtime/lib64/tbe_plugin/bert"}, - {"ge.soLoadPath", "/usr/local/HiAI/runtime/lib64/plugin/opskernel/libfe.so"}}; - ret = ge::GEInitialize(config); - } else if (run_mode_path == "2") { - const std::map config = {{"device_id", "0,2,4,6"}, - {"rank_table_file", "hccl from csa/paas"}, - {"ge.graphRunMode", train_flag}, - {LOCAL_FMKOP_FLAG, "1"}}; - ret = ge::GEInitialize(config); - } else { - const std::map config = { - {"device_id", "0,2,4,6"}, - {"rank_table_file", "hccl from csa/paas"}, - {"ge.graphRunMode", train_flag}, - {DDK_VERSION_FLAG, "1.60.T17.B830"}, - {TBE_PLUGIN_PATH_FLAG, "/usr/local/HiAI/runtime/lib64/tbe_plugin/" + run_mode_path}}; - ret = ge::GEInitialize(config); - } - std::cout << "GEInitialize_ret is " << ret << std::endl; - - return ret; -} - -/// train_flag is infer default -/// run_mode: is multi group of [fe,aicpu,bert,deeplabv3,mobilenetv2,single_path_nas,ssd] -/// but bert,deeplabv3,mobilenetv2,single_path_nas,ssd can only set one value from array -/// eg:"fe,aicpu,bert" or "fe", default is “fe” -/// "fe,aicpu,bert" remain open fe aicpu and bert -ge::Status GEInitialize_api_new(string train_flag, string run_mode) { - ge::Status ret; - vector modes; - - char *strs = new char[run_mode.length() + 1]; - strcpy(strs, run_mode.c_str()); - const char *delim = ","; - char *p = strtok(strs, delim); - while (p) { - string s = p; // transform substr to string - modes.push_back(s); // save to result array - p = strtok(NULL, delim); - } - - std::map config = { - {"device_id", "0,2,4,6"}, - {"rank_table_file", "hccl from csa/paas"}, - {DDK_VERSION_FLAG, "1.60.T17.B830"}, - {"ge.opsProtoLibPath", "/usr/local/HiAI/runtime/ops/op_proto/built-in/libopsproto.so"}}; - if (train_flag == "infer") - config.insert(pair("ge.graphRunMode", "0")); - else if (train_flag == "train") - config.insert(pair("ge.graphRunMode", "1")); - else - std::cout << "GeInitialize give the error param" << std::endl; - - for (int i = 0; i < modes.size(); i++) { - if (modes[i] == "fe") { - config.insert(pair("ge.feFlag", "1")); - if (config.find("ge.soLoadPath") != config.end()) { - config["ge.soLoadPath"] = - "/usr/local/HiAI/runtime/lib64/plugin/opskernel/libfe.so:/usr/local/HiAI/runtime/lib64/plugin/opskernel/" - "libaicpu_plugin.so:/usr/local/HiAI/runtime/lib64/plugin/opskernel/libge_local_engine.so:/usr/local/HiAI/" - "runtime/lib64/plugin/opskernel/librts_engine.so"; - } else { - config.insert(pair( - "ge.soLoadPath", - "/usr/local/HiAI/runtime/lib64/plugin/opskernel/libfe.so:/usr/local/HiAI/runtime/lib64/plugin/opskernel/" - "libge_local_engine.so:/usr/local/HiAI/runtime/lib64/plugin/opskernel/librts_engine.so")); - } - } else if (modes[i] == "aicpu") { - config.insert(pair("ge.aicpuFlag", "1")); - if (config.find("ge.soLoadPath") != config.end()) { - config["ge.soLoadPath"] = - "/usr/local/HiAI/runtime/lib64/plugin/opskernel/libfe.so:/usr/local/HiAI/runtime/lib64/plugin/opskernel/" - "libaicpu_plugin.so:/usr/local/HiAI/runtime/lib64/plugin/opskernel/libge_local_engine.so:/usr/local/HiAI/" - "runtime/lib64/plugin/opskernel/librts_engine.so"; - } else { - config.insert(pair( - "ge.soLoadPath", - "/usr/local/HiAI/runtime/lib64/plugin/opskernel/libaicpu_plugin.so:/usr/local/HiAI/runtime/lib64/plugin/" - "opskernel/libge_local_engine.so:/usr/local/HiAI/runtime/lib64/plugin/opskernel/librts_engine.so")); - } - } else if (modes[i] == "bert" || modes[i] == "deeplabv3" || modes[i] == "mobilenetv2" || - modes[i] == "single_path_nas" || modes[i] == "ssd") { - config.insert(pair(TBE_PLUGIN_PATH_FLAG, "/usr/local/HiAI/runtime/lib64/tbe_plugin/" + modes[i])); - } else if (modes[i] == "plugin") { - - } else - std::cout << "GeInitialize give the error param" << std::endl; - } - ret = ge::GEInitialize(config); - - std::cout << "GEInitialize_ret is " << ret << std::endl; - - return ret; -} - -ge::Status GEFinalize_api() { - ge::Status ret = ge::GEFinalize(); - std::cout << "GEFinalize ret is " << ret << std::endl; - - return ret; -} - -/// set train_flag -/// if run_mode_path is "fe" remain FE process; "fe,plugin" is FE and TBE plugin process -/// "aicpu" is open aicpu plugin -int RunGraph_initData(Graph &graph, string op_name, map> attr_test, string train_flag, - string run_mode_path) { - std::map options = {{RUN_FLAG, "1"}}; - uint32_t graph_id = 0; - - ge::Status ret = GEInitialize_api_new(train_flag, run_mode_path); - EXPECT_EQ(ret, ge::SUCCESS); - - ge::Session *session = new Session(options); - ASSERT_TRUE(session != NULL); - - std::vector input; - if (attr_test.find("input1") != attr_test.end()) { - Tensor input_tensor = genTensor(attr_test["input1"]); - input.push_back(input_tensor); - } - if (attr_test.find("input2") != attr_test.end()) { - Tensor input_tensor = genTensor(attr_test["input2"]); - input.push_back(input_tensor); - } - if (attr_test.find("input3") != attr_test.end()) { - Tensor input_tensor = genTensor(attr_test["input3"]); - input.push_back(input_tensor); - } - std::vector output; - - ret = session->AddGraph(graph_id, graph); - EXPECT_EQ(ret, ge::SUCCESS); - if (train_flag == "1") { - setenv("GE_TRAIN", "1", true); - ret = session->RunGraph(graph_id, input, output); - setenv("GE_TRAIN", "0", true); - } else { - ret = session->RunGraph(graph_id, input, output); - } - delete session; - GEFinalize_api(); - - if (ret != ge::SUCCESS) { - std::cout << " run graph failed" << std::endl; - return -1; - } else { - return 0; - } -} - -ge::Status session_add_and_run_graph(ge::Session *session, uint32_t graph_id, Graph &graph, std::vector inputs, - std::vector &outputs) { - ge::Status ret = session->AddGraph(graph_id, graph); - EXPECT_EQ(ret, ge::SUCCESS); - ret = session->RunGraph(graph_id, inputs, outputs); - - return ret; -} - -ge::Session *create_session() { - // Init session - std::map options = {{"a", "b"}, {TRAIN_FLAG, "1"}}; - ge::Session *session = new Session(options); - ASSERT_TRUE(session != NULL); - - return session; -} - -ge::Session *create_aipp_session() { - // Init session - std::map options = {{"a", "b"}, {TRAIN_FLAG, "1"}, {"ge.insertOpFile", "/root/host/ge/aipp.cfg"}}; - ge::Session *session = new Session(options); - ASSERT_TRUE(session != NULL); - - return session; -} - -int buildCheckPointGraph(Graph &graph, map variables) { - std::vector inputs{}; - std::vector outputs{}; - - for (map::iterator it = variables.begin(); it != variables.end(); ++it) { - auto var = op::Variable(string(it->first)); - var.update_output_desc_y(it->second); - inputs.push_back(var); - graph.AddOp(var); - } - - auto save = op::Save().create_dynamic_input_tensors(inputs.size()); - for (int i = 0; i < inputs.size(); i++) { - save.set_dynamic_input_tensors(i, inputs[i]); - } - - graph.SetInputs(inputs).SetOutputs(outputs); - return 0; -} - -int buildInitGraph(Graph &graph, std::vector desc_var, std::vector name_var, - std::vector values_var) { - std::vector inputs{}; - std::vector outputs{}; - - for (int i = 0; i < desc_var.size(); i++) { - desc_var[i].SetRealDimCnt(desc_var[i].GetShape().GetDimNum()); - auto tensor_data = genTensor_withVaule(desc_var[i].GetShape().GetDims(), values_var[i]); - auto var_constant = op::Constant().set_attr_value(tensor_data); - var_constant.update_output_desc_y(desc_var[i]); - - auto var_init = op::Variable(string(name_var[i])); - var_init.update_output_desc_y(desc_var[i]); - auto var_assign = op::Assign().set_input_ref(var_init).set_input_value(var_constant); - inputs.push_back(var_init); - } - graph.SetInputs(inputs).SetOutputs(outputs); - return 0; -} - -int buildInitGraph_other_dataType(Graph &graph, std::vector desc_var, std::vector name_var) { - std::vector inputs{}; - std::vector outputs{}; - - for (int i = 0; i < desc_var.size(); i++) { - desc_var[i].SetRealDimCnt(desc_var[i].GetShape().GetDimNum()); - auto tensor_data = genTensor(desc_var[i].GetShape().GetDims(), desc_var[i].GetFormat(), desc_var[i].GetDataType()); - auto var_constant = op::Constant().set_attr_value(tensor_data); - var_constant.update_output_desc_y(desc_var[i]); - - auto var_init = op::Variable(string(name_var[i])); - var_init.update_output_desc_y(desc_var[i]); - auto var_assign = op::Assign().set_input_ref(var_init).set_input_value(var_constant); - inputs.push_back(var_init); - - graph.AddOp(var_constant); - graph.AddOp(var_init); - graph.AddOp(var_assign); - } - graph.SetInputs(inputs).SetOutputs(outputs); - return 0; -} - -bool build_multi_input_multi_output_graph(Graph &graph) { - auto data1 = op::Data("Data1").set_attr_index(0); - auto data2 = op::Data("Data2").set_attr_index(1); - - vector dim_info; - - auto relu1 = op::Relu("Relu1").set_input_x(data1); - auto relu2 = op::Relu("Relu2").set_input_x(data2); - - auto eltwise = op::Eltwise("Eltwise") - .create_dynamic_input_x(2) - .set_dynamic_input_x(0, relu1) - .set_dynamic_input_x(1, relu2) - .set_attr_N(2) - .set_attr_mode(1) - .set_attr_coeff({1, 1}); - - auto eltwise1 = op::Eltwise("Eltwise1") - .create_dynamic_input_x(2) - .set_dynamic_input_x(0, eltwise) - .set_dynamic_input_x(1, eltwise) - .set_attr_N(2) - .set_attr_mode(1) - .set_attr_coeff({1, 1}); - - auto eltwise2 = op::Eltwise("Eltwise2") - .create_dynamic_input_x(2) - .set_dynamic_input_x(0, eltwise) - .set_dynamic_input_x(1, eltwise) - .set_attr_N(2) - .set_attr_mode(1) - .set_attr_coeff({1, 1}); - - std::vector inputs{data1, data2}; - std::vector outputs{eltwise1, eltwise2}; - graph.SetInputs(inputs).SetOutputs(outputs); - return true; -} - -void build_big_graph(Graph &graph, map> attr) { - auto data = op::Data("Data").set_attr_index(0); - auto weight = op::Const("weight1").set_attr_value(genTensor(attr["weight"])); - vector weight_shape(attr["weight"].begin(), attr["weight"].end()); - TensorDesc weight_desc(ge::Shape(weight_shape), FORMAT_NCHW, DT_FLOAT); - weight.update_output_desc_y(weight_desc); - auto conv_1 = op::Conv2D("conv1").set_input_x(data).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - - auto conv_2 = op::Conv2D("conv2").set_input_x(conv_1).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_3 = op::Conv2D("conv3").set_input_x(conv_2).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_4 = op::Conv2D("conv4").set_input_x(conv_3).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_5 = op::Conv2D("conv5").set_input_x(conv_4).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_6 = op::Conv2D("conv6").set_input_x(conv_5).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_7 = op::Conv2D("conv7").set_input_x(conv_6).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_8 = op::Conv2D("conv8").set_input_x(conv_7).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_9 = op::Conv2D("conv9").set_input_x(conv_8).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_10 = op::Conv2D("conv10").set_input_x(conv_9).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_11 = op::Conv2D("conv11").set_input_x(conv_10).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_12 = op::Conv2D("conv12").set_input_x(conv_11).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_13 = op::Conv2D("conv13").set_input_x(conv_12).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_14 = op::Conv2D("conv14").set_input_x(conv_13).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_15 = op::Conv2D("conv15").set_input_x(conv_14).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_16 = op::Conv2D("conv16").set_input_x(conv_15).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_17 = op::Conv2D("conv17").set_input_x(conv_16).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_18 = op::Conv2D("conv18").set_input_x(conv_17).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_19 = op::Conv2D("conv19").set_input_x(conv_18).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_20 = op::Conv2D("conv20").set_input_x(conv_19).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_21 = op::Conv2D("conv21").set_input_x(conv_20).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_22 = op::Conv2D("conv22").set_input_x(conv_21).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_23 = op::Conv2D("conv23").set_input_x(conv_22).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_24 = op::Conv2D("conv24").set_input_x(conv_23).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_25 = op::Conv2D("conv25").set_input_x(conv_24).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_26 = op::Conv2D("conv26").set_input_x(conv_25).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_27 = op::Conv2D("conv27").set_input_x(conv_26).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_28 = op::Conv2D("conv28").set_input_x(conv_27).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_29 = op::Conv2D("conv29").set_input_x(conv_28).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_30 = op::Conv2D("conv30").set_input_x(conv_29).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_31 = op::Conv2D("conv31").set_input_x(conv_30).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_32 = op::Conv2D("conv32").set_input_x(conv_31).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_33 = op::Conv2D("conv33").set_input_x(conv_32).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_34 = op::Conv2D("conv34").set_input_x(conv_33).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_35 = op::Conv2D("conv35").set_input_x(conv_34).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_36 = op::Conv2D("conv36").set_input_x(conv_35).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_37 = op::Conv2D("conv37").set_input_x(conv_36).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_38 = op::Conv2D("conv38").set_input_x(conv_37).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_39 = op::Conv2D("conv39").set_input_x(conv_38).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_40 = op::Conv2D("conv40").set_input_x(conv_39).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_41 = op::Conv2D("conv41").set_input_x(conv_40).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_42 = op::Conv2D("conv42").set_input_x(conv_41).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_43 = op::Conv2D("conv43").set_input_x(conv_42).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_44 = op::Conv2D("conv44").set_input_x(conv_43).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_45 = op::Conv2D("conv45").set_input_x(conv_44).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_46 = op::Conv2D("conv46").set_input_x(conv_45).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_47 = op::Conv2D("conv47").set_input_x(conv_46).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_48 = op::Conv2D("conv48").set_input_x(conv_47).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_49 = op::Conv2D("conv49").set_input_x(conv_48).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_50 = op::Conv2D("conv50").set_input_x(conv_49).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_51 = op::Conv2D("conv51").set_input_x(conv_50).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_52 = op::Conv2D("conv52").set_input_x(conv_51).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_53 = op::Conv2D("conv53").set_input_x(conv_52).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_54 = op::Conv2D("conv54").set_input_x(conv_53).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_55 = op::Conv2D("conv55").set_input_x(conv_54).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_56 = op::Conv2D("conv56").set_input_x(conv_55).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_57 = op::Conv2D("conv57").set_input_x(conv_56).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_58 = op::Conv2D("conv58").set_input_x(conv_57).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_59 = op::Conv2D("conv59").set_input_x(conv_58).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_60 = op::Conv2D("conv60").set_input_x(conv_59).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_61 = op::Conv2D("conv61").set_input_x(conv_60).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_62 = op::Conv2D("conv62").set_input_x(conv_61).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_63 = op::Conv2D("conv63").set_input_x(conv_62).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_64 = op::Conv2D("conv64").set_input_x(conv_63).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_65 = op::Conv2D("conv65").set_input_x(conv_64).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_66 = op::Conv2D("conv66").set_input_x(conv_65).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_67 = op::Conv2D("conv67").set_input_x(conv_66).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_68 = op::Conv2D("conv68").set_input_x(conv_67).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_69 = op::Conv2D("conv69").set_input_x(conv_68).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_70 = op::Conv2D("conv70").set_input_x(conv_69).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_71 = op::Conv2D("conv71").set_input_x(conv_70).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_72 = op::Conv2D("conv72").set_input_x(conv_71).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_73 = op::Conv2D("conv73").set_input_x(conv_72).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_74 = op::Conv2D("conv74").set_input_x(conv_73).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_75 = op::Conv2D("conv75").set_input_x(conv_74).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_76 = op::Conv2D("conv76").set_input_x(conv_75).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_77 = op::Conv2D("conv77").set_input_x(conv_76).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_78 = op::Conv2D("conv78").set_input_x(conv_77).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_79 = op::Conv2D("conv79").set_input_x(conv_78).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_80 = op::Conv2D("conv80").set_input_x(conv_79).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_81 = op::Conv2D("conv81").set_input_x(conv_80).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_82 = op::Conv2D("conv82").set_input_x(conv_81).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_83 = op::Conv2D("conv83").set_input_x(conv_82).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_84 = op::Conv2D("conv84").set_input_x(conv_83).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_85 = op::Conv2D("conv85").set_input_x(conv_84).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_86 = op::Conv2D("conv86").set_input_x(conv_85).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_87 = op::Conv2D("conv87").set_input_x(conv_86).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_88 = op::Conv2D("conv88").set_input_x(conv_87).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_89 = op::Conv2D("conv89").set_input_x(conv_88).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_90 = op::Conv2D("conv90").set_input_x(conv_89).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_91 = op::Conv2D("conv91").set_input_x(conv_80).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_92 = op::Conv2D("conv92").set_input_x(conv_91).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_93 = op::Conv2D("conv93").set_input_x(conv_92).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_94 = op::Conv2D("conv94").set_input_x(conv_93).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_95 = op::Conv2D("conv95").set_input_x(conv_94).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_96 = op::Conv2D("conv96").set_input_x(conv_95).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_97 = op::Conv2D("conv97").set_input_x(conv_96).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_98 = op::Conv2D("conv98").set_input_x(conv_97).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_99 = op::Conv2D("conv99").set_input_x(conv_98).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_100 = op::Conv2D("conv100").set_input_x(conv_99).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_101 = op::Conv2D("conv101").set_input_x(conv_100).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_102 = op::Conv2D("conv102").set_input_x(conv_101).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_103 = op::Conv2D("conv103").set_input_x(conv_102).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_104 = op::Conv2D("conv104").set_input_x(conv_103).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_105 = op::Conv2D("conv105").set_input_x(conv_104).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_106 = op::Conv2D("conv106").set_input_x(conv_105).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_107 = op::Conv2D("conv107").set_input_x(conv_106).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_108 = op::Conv2D("conv108").set_input_x(conv_107).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_109 = op::Conv2D("conv109").set_input_x(conv_108).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_110 = op::Conv2D("conv110").set_input_x(conv_109).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_111 = op::Conv2D("conv111").set_input_x(conv_110).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_112 = op::Conv2D("conv112").set_input_x(conv_111).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_113 = op::Conv2D("conv113").set_input_x(conv_112).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_114 = op::Conv2D("conv114").set_input_x(conv_113).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_115 = op::Conv2D("conv115").set_input_x(conv_114).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_116 = op::Conv2D("conv116").set_input_x(conv_115).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_117 = op::Conv2D("conv117").set_input_x(conv_116).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_118 = op::Conv2D("conv118").set_input_x(conv_117).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_119 = op::Conv2D("conv119").set_input_x(conv_118).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_120 = op::Conv2D("conv120").set_input_x(conv_119).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_121 = op::Conv2D("conv121").set_input_x(conv_120).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_122 = op::Conv2D("conv122").set_input_x(conv_121).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_123 = op::Conv2D("conv123").set_input_x(conv_122).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_124 = op::Conv2D("conv124").set_input_x(conv_123).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_125 = op::Conv2D("conv125").set_input_x(conv_124).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_126 = op::Conv2D("conv126").set_input_x(conv_125).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_127 = op::Conv2D("conv127").set_input_x(conv_126).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_128 = op::Conv2D("conv128").set_input_x(conv_127).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_129 = op::Conv2D("conv129").set_input_x(conv_128).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - auto conv_130 = op::Conv2D("conv130").set_input_x(conv_129).set_input_filter(weight).set_attr_pads({0,0,0,0}).set_attr_strides({1,1,1,1}); - - std::vector inputs{data}; - std::vector outputs{conv_130}; - graph.SetInputs(inputs).SetOutputs(outputs); -} - -int GetDatTypeSize(DataType dt) { - int dailation = 1; - if (dt == ge::DT_FLOAT) - dailation = 4; - else if (dt == ge::DT_FLOAT16) - dailation = 2; - else if (dt == ge::DT_INT16) - dailation = 2; - else if (dt == ge::DT_UINT16) - dailation = 2; - else if (dt == ge::DT_INT32) - dailation = 4; - else if (dt == ge::DT_UINT32) - dailation = 4; - else if (dt == ge::DT_INT64) - dailation = 8; - else if (dt == ge::DT_UINT64) - dailation = 8; - else if (dt == ge::DT_INT8) - dailation = 1; - - return dailation; -} - -int buildConvGraph_new(Graph &graph, std::vector desc_var, std::vector name_var, int flag, - Format format) { - auto data_x_shape = op::Data("xShape").set_attr_index(0); - auto var = op::Variable(name_var[0]); - auto var1 = op::Variable(name_var[1]); //add one seat of ApplyMomentum() - auto label1 = op::Variable(name_var[2]); //add one seat of ApplyMomentum() - auto conv2dgrad = op::Conv2DBackpropFilterD("output_1"); - auto test2 = op::ApplyMomentum(); - - var.update_output_desc_y(desc_var[0]); - var1.update_output_desc_y(desc_var[1]); - label1.update_output_desc_y(desc_var[2]); - - graph.AddOp(var); - graph.AddOp(var1); - graph.AddOp(label1); - - auto conv2d = op::Conv2D().set_input_x(data_x_shape).set_input_filter(var).set_attr_strides({1, 1, 1, 1}).set_attr_pads({0,0,0,0}); - update_op_format(conv2d, format); - ge::TensorDesc tensor_desc_w = conv2d.GetInputDesc("filter"); - tensor_desc_w.SetFormat(format); - conv2d.UpdateInputDesc("filter", tensor_desc_w); - - if (flag >= 1) { - conv2dgrad.set_input_x(data_x_shape) - .set_attr_filter_size(desc_var[0].GetShape().GetDims()) - .set_input_out_backprop(conv2d) - .set_attr_strides({1, 1, 1, 1}) - .set_attr_pads({0, 0, 0, 0}); - update_op_format(conv2dgrad, format); - graph.AddOp(conv2dgrad); - } - if (flag >= 2) { - // set conv2dgrad var - test2.set_input_accum(var1) - .set_input_grad(conv2dgrad) - .set_input_lr(label1) - .set_input_momentum(label1) - .set_input_var(var); - graph.AddOp(test2); - } - - std::vector inputs{data_x_shape}; // set all val - std::vector outputs{conv2d}; - graph.SetInputs(inputs).SetOutputs(outputs); - graph.AddOp(conv2d); - - return 0; -} - -/// load bin data_fail -/// input_path: path of bin data_file -/// shapes: the shape of Tensor -/// ft: the format of Tensor -/// dt: the dataType of Tensor -Tensor load_variable_input_data(string input_path, std::vector shapes, Format ft, DataType dt) { - vector dim_info1; - - uint8_t *input_data = (uint8_t *)readTestDataFile(input_path, dim_info1); // common.h - TensorDesc input_tensor_desc = TensorDesc(ge::Shape(shapes), ft, dt); - input_tensor_desc.SetRealDimCnt(shapes.size()); - Tensor input_tensor = Tensor(input_tensor_desc, input_data, GetDatTypeSize(dt) * dim_info1[dim_info1[0] + 1]); - return input_tensor; -} diff --git a/tests/st/resnet50/common.h b/tests/st/resnet50/common.h deleted file mode 100644 index 75805db7..00000000 --- a/tests/st/resnet50/common.h +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef ST_RESNET50_GE_COMMON_H_ -#define ST_RESNET50_GE_COMMON_H_ -#include "common/ge_inner_error_codes.h" -#include "utils/tensor_utils.h" - -#define MY_USER_GE_LOGI(...) GE_LOG_INFO(1, __VA_ARGS__) -#define MY_USER_GE_LOGW(...) GE_LOG_WARN(1, __VA_ARGS__) -#define MY_USER_GE_LOGE(...) GE_LOG_ERROR(1, 3, __VA_ARGS__) - -#ifndef USER_GE_LOGI -#define USER_GE_LOGI MY_USER_GE_LOGI -#endif // USER_GE_LOGI - -#ifndef USER_GE_LOGW -#define USER_GE_LOGW MY_USER_GE_LOGW -#endif // USER_GE_LOGW - -#ifndef USER_GE_LOGE -#define USER_GE_LOGE MY_USER_GE_LOGE -#endif // USER_GE_LOGE - -/// train_flag is 0 when infer, train_flag is 1 when train.this param is set for RunGranph_readData() and -/// RunGraph_initData() -#define TRAIN_FLAG_INFER "infer" -#define TRAIN_FLAG_TRAIN "train" - -#include -#include -#include -#include -#include -#include -#include - -#include "ge_api.h" -#include "graph.h" -#include "ptest.h" -#include "ops/all_ops.h" -using namespace std; -using namespace ge; - -// read bin file and compile result -void update_op_format(Operator ops, Format format = ge::FORMAT_NCHW); -void getDimInfo(FILE *fp, std::vector &dim_info); -void *readTestDataFile(std::string infile, std::vector &dim_info); -void *readUint8TestDataFile(std::string infile, int size); -bool allclose(float *a, float *b, uint64_t count, float rtol, float atol); -bool compFp32WithTData(float *actual_output_data, std::string expected_data_file, float rtol, float atol); -Tensor load_variable_input_data(string input_path, std::vector shapes, Format ft = ge::FORMAT_NCHW, - DataType dt = ge::DT_FLOAT); -// constructor Tensor -int GetDatTypeSize(DataType dt); -ge::Tensor genTensor(std::vector tensor_shape, Format format = ge::FORMAT_NCHW, DataType dt = ge::DT_FLOAT); -ge::Tensor genTensor_withVaule(std::vector tensor_shape, float value = 1); -Tensor genTesnor_Shape_as_data(std::vector tensor_shape); -// Init GE -ge::Status GEInitialize_api(string train_flag = "0", string run_mode_path = "0"); -ge::Status GEInitialize_api_new(string train_flag = "infer", string run_mode = "fe"); -ge::Status GEFinalize_api(); -// constructor session and build graph -ge::Session *create_aipp_session(); -ge::Session *create_session(); -ge::Status session_add_and_run_graph(ge::Session *session, uint32_t graphId, Graph &graph, std::vector inputs, - std::vector &outputs); - -// common interface for infer -int RunGraph_initData(Graph &graph, string op_name, map> attr_test, - string train_flag = "infer", string run_mode_path = "fe"); -void Inputs_load_Data(string op_name, std::vector &input, map> attr_test, - Format format = ge::FORMAT_NCHW, DataType dt = ge::DT_FLOAT); -bool comparaData(std::vector &output, string op_name, map> attr_test); -int RunGraph_readData(Graph &graph, string op_name, map> attr_test, - string train_flag = "infer", string run_mode_path = "fe", Format format = ge::FORMAT_NCHW, - DataType dt = ge::DT_FLOAT); - -// common interface for train -int buildCheckPointGraph(Graph &graph, map variables); -int buildInitGraph(Graph &graph, std::vector desc_var, std::vector name_var, - std::vector values_var); -int buildInitGraph_other_dataType(Graph &graph, std::vector desc_var, std::vector name_var); - -bool build_multi_input_multi_output_graph(Graph &graph); -void build_big_graph(Graph &graph, map> attr); -int buildConvGraph_new(Graph &graph, std::vector desc_var, std::vector name_var, int flag = 2); - -#endif // ST_RESNET50_GE_COMMON_H_ diff --git a/tests/st/resnet50/ptest.h b/tests/st/resnet50/ptest.h deleted file mode 100644 index 568969f8..00000000 --- a/tests/st/resnet50/ptest.h +++ /dev/null @@ -1,225 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef ST_RESNET50_PTEST_H_ -#define ST_RESNET50_PTEST_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace ptest { -class assertion_error : public std::exception { - public: - const char *what() const throw() { return "Assertion Exception"; } -}; - -class TestFixture { - public: - virtual void SetUp() {} - virtual void TearDown() {} - void Run() { _func(); } - void BindFunction(std::function function) { _func = function; } - void SetName(const std::string &name) { _name = name; } - std::string Name() const { return _name; } - virtual ~TestFixture() {} - - private: - std::function _func; - std::string _name; -}; - -enum TestResult { SUCCESS, FAILED, UNAVAILABLE, UNKNOWN, NOCASEFOUND }; - -class TestManager { - public: - static TestManager &GetSingleton() { - static TestManager instance; - return instance; - } - void RegisterTest(const std::string &name, TestFixture *fixture) { _testfixtures[name] = fixture; } - - const std::string GetRunningTestcaseName() const { return _running_testcase_name; } - - const std::list GetAllTestNames() const { - std::list result; - for (auto &t : _testfixtures) { - result.push_back(t.first); - } - return result; - } - - TestResult RunTest(const std::string &name) { - if (_testfixtures.find(name) == _testfixtures.end()) { - return NOCASEFOUND; - } - - _running_testcase_name = name; - - do { - SetTestResult(name, UNKNOWN); - _testfixtures[name]->SetUp(); - if (_testresults[name] == FAILED) { - _testresults[name] = UNAVAILABLE; - break; - } - SetTestResult(name, SUCCESS); - try { - _testfixtures[name]->Run(); - } catch (assertion_error &e) { - // Do nothing as the error has been handled by the TestManager. - } - _testfixtures[name]->TearDown(); - } while (0); - - return _testresults[name]; - } - void SetTestResult(const std::string &name, TestResult result) { _testresults[name] = result; } - TestResult GetTestResult(const std::string &name) { return _testresults[name]; } - - private: - std::map _testfixtures; - std::map _testresults; - std::string _running_testcase_name; -}; - -class TestFixtureRegister { - public: - TestFixtureRegister(const std::string &name, TestFixture *fixture, std::function function) { - fixture->BindFunction(function); - fixture->SetName(name); - TestManager::GetSingleton().RegisterTest(name, fixture); - } -}; -} // namespace ptest - -#define _STR(x) #x -#define _EMPTY_NAMESPACE - -#define _TEST(NAMESPACE, FIXTURECLASS, TESTNAME, CASENAME) \ - void g_func_##TESTNAME##_##CASENAME(void); \ - NAMESPACE::FIXTURECLASS g_fixture_##TESTNAME##_##CASENAME; \ - ptest::TestFixtureRegister g_register_##TESTNAME##_##CASENAME( \ - _STR(TESTNAME##_##CASENAME), &g_fixture_##TESTNAME##_##CASENAME, g_func_##TESTNAME##_##CASENAME); \ - void g_func_##TESTNAME##_##CASENAME(void) - -#define TEST(TESTNAME, CASENAME) _TEST(ptest, TestFixture, TESTNAME, CASENAME) - -#define TEST_F(TESTFIXTURE, CASENAME) _TEST(_EMPTY_NAMESPACE, TESTFIXTURE, TESTFIXTURE, CASENAME) - -#define EXPECT_TRUE(X) \ - do { \ - if (!(X)) { \ - std::string test_name = ptest::TestManager::GetSingleton().GetRunningTestcaseName(); \ - ptest::TestManager::GetSingleton().SetTestResult(test_name, ptest::FAILED); \ - std::cerr << #X << "Expectation Failed\n" \ - << "Testcase Name: " << test_name << "\n" \ - << "File: " __FILE__ << "\tLine:" << __LINE__ << std::endl; \ - } \ - } while (0); - -// With the macro definition ensures that the compiler can detect compiler warning. -#define Max_Log_Len 1024 -#define PRINT_ERR(lpszFormat, ...) \ - do { \ - char szTmpBuf[Max_Log_Len + 1] = {0}; \ - snprintf(szTmpBuf, Max_Log_Len, lpszFormat, ##__VA_ARGS__); \ - std::cerr << szTmpBuf << std::endl; \ - } while (0) - -// Increase the content of print error messages and error to facilitate rapid analysis -#define EXPECT_TRUE_C(X, ERR_TYPE, format, ...) \ - do { \ - if (!(X)) { \ - std::string test_name = ptest::TestManager::GetSingleton().GetRunningTestcaseName(); \ - ptest::TestManager::GetSingleton().SetTestResult(test_name, ptest::FAILED); \ - std::cerr << #X << " Expectation Failed." \ - << "Testcase Name: " << test_name << " File:" __FILE__ << " Line:" << __LINE__ << std::endl; \ - PRINT_ERR("[" ERR_TYPE "]" format, ##__VA_ARGS__); \ - } \ - } while (0) - -#define ASSERT_TRUE(X) \ - do { \ - if (!(X)) { \ - std::string test_name = ptest::TestManager::GetSingleton().GetRunningTestcaseName(); \ - ptest::TestManager::GetSingleton().SetTestResult(test_name, ptest::FAILED); \ - std::cerr << #X << "Assertion Failed\n" \ - << "Testcase Name: " << test_name << "\n" \ - << "File: " __FILE__ << "\tLine:" << __LINE__ << std::endl; \ - throw ptest::assertion_error(); \ - } \ - } while (0); - -// Add printing error information and error line content for quick analysis -#define ASSERT_TRUE_C(X, ERR_TYPE, format, ...) \ - do { \ - if (!(X)) { \ - std::string test_name = ptest::TestManager::GetSingleton().GetRunningTestcaseName(); \ - ptest::TestManager::GetSingleton().SetTestResult(test_name, ptest::FAILED); \ - std::cerr << #X << " Assertion Failed." \ - << "Testcase Name: " << test_name << " File:" __FILE__ << " Line:" << __LINE__ << std::endl; \ - PRINT_ERR("[" ERR_TYPE "]" format, ##__VA_ARGS__); \ - throw ptest::assertion_error(); \ - } \ - } while (0); - -#define CONFIG_ERR "CONFIG_ERR" -#define LOAD_MODEL_ERR "LOAD_MODEL_ERR" -#define FILE_READ_ERR "FILE_READ_ERR" -#define RUN_ERROR "RUN_ERROR" -#define MEM_ERROR "MEM_ERROR" -#define RESULT_ERR "RESULT_ERR" - -#define EXPECT_FALSE(X) EXPECT_TRUE(!(X)) -#define EXPECT_EQ(X, Y) EXPECT_TRUE(((X) == (Y))) -#define EXPECT_NE(X, Y) EXPECT_TRUE(((X) != (Y))) -#define EXPECT_GT(X, Y) EXPECT_TRUE(((X) > (Y))) -#define EXPECT_GE(X, Y) EXPECT_TRUE(((X) >= (Y))) -#define EXPECT_LT(X, Y) EXPECT_TRUE(((X) < (Y))) -#define EXPECT_LE(X, Y) EXPECT_TRUE(((X) <= (Y))) - -#define EXPECT_FALSE_C(X, ERR_TYPE, format, ...) EXPECT_TRUE_C(!(X), ERR_TYPE, format, ##__VA_ARGS__) -#define EXPECT_EQ_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) == (Y)), ERR_TYPE, format, ##__VA_ARGS__) -#define EXPECT_NE_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) != (Y)), ERR_TYPE, format, ##__VA_ARGS__) -#define EXPECT_GT_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) > (Y)), ERR_TYPE, format, ##__VA_ARGS__) -#define EXPECT_GE_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) >= (Y)), ERR_TYPE, format, ##__VA_ARGS__) -#define EXPECT_LT_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) < (Y)), ERR_TYPE, format, ##__VA_ARGS__) -#define EXPECT_LE_C(X, Y, ERR_TYPE, format, ...) EXPECT_TRUE_C(((X) <= (Y)), ERR_TYPE, format, ##__VA_ARGS__) - -#define ASSERT_FALSE(X) ASSERT_TRUE(!(X)) -#define ASSERT_EQ(X, Y) ASSERT_TRUE(((X) == (Y))) -#define ASSERT_NE(X, Y) ASSERT_TRUE(((X) != (Y))) -#define ASSERT_GT(X, Y) ASSERT_TRUE(((X) > (Y))) -#define ASSERT_GE(X, Y) ASSERT_TRUE(((X) >= (Y))) -#define ASSERT_LT(X, Y) ASSERT_TRUE(((X) < (Y))) -#define ASSERT_LE(X, Y) ASSERT_TRUE(((X) <= (Y))) - -#define ASSERT_FALSE_C(X, ERR_TYPE, format, ...) ASSERT_TRUE_C(!(X), ERR_TYPE, format, ##__VA_ARGS__) -#define ASSERT_EQ_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) == (Y)), ERR_TYPE, format, ##__VA_ARGS__) -#define ASSERT_NE_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) != (Y)), ERR_TYPE, format, ##__VA_ARGS__) -#define ASSERT_GT_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) > (Y)), ERR_TYPE, format, ##__VA_ARGS__) -#define ASSERT_GE_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) >= (Y)), ERR_TYPE, format, ##__VA_ARGS__) -#define ASSERT_LT_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) < (Y)), ERR_TYPE, format, ##__VA_ARGS__) -#define ASSERT_LE_C(X, Y, ERR_TYPE, format, ...) ASSERT_TRUE_C(((X) <= (Y)), ERR_TYPE, format, ##__VA_ARGS__) - -#endif // ST_RESNET50_PTEST_H_ diff --git a/tests/st/resnet50/resnet50_train.cc b/tests/st/resnet50/resnet50_train.cc deleted file mode 100644 index f1d1e58d..00000000 --- a/tests/st/resnet50/resnet50_train.cc +++ /dev/null @@ -1,852 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include - -#include "common.h" -#include "ge_api.h" -#include "graph.h" -#include "ops/all_ops.h" -#include "types.h" -#include "utils/tensor_utils.h" - -using namespace std; -using namespace ge; -using namespace op; - -typedef bool (*Func)(Graph &graph); - -#define PADDING_MODE 6 -#define GRAD_PADDING_MODE 3 -vector pad_1{1, 1, 1, 1}; -vector pad_0{0, 0, 0, 0}; -vector stride_1{1, 1}; -vector stride_2{2, 2}; - -// (int out_channels, int h, int w, vector stride{1,1}, vector pad{1,1,1,1}, op::Data() input) -#define GENERATE_CONV_VAR(LAYER, BLK, OPNUM, in_channels, out_channels, h, w, stride, pad, input) \ - auto &LAYER##_##BLK##_##OPNUM##_input = input; \ - \ - TensorDesc LAYER##_##BLK##_##OPNUM##_desc(ge::Shape({out_channels, in_channels, h, w}), FORMAT_NCHW, DT_FLOAT); \ - auto LAYER##_##BLK##_##OPNUM##_weight = op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_weight"); \ - LAYER##_##BLK##_##OPNUM##_weight.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mom_weight = \ - op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_mom_weight"); \ - LAYER##_##BLK##_##OPNUM##_mom_weight.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - LAYER##_##BLK##_##OPNUM##_mom_weight.update_input_desc_x(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - cout << string(#LAYER) + string(#BLK) + string(#OPNUM) << "'s weight shape is:" << in_channels << out_channels << h \ - << w << endl; \ - cout << string(#LAYER) + string(#BLK) + string(#OPNUM) \ - << "'s input_x op's shape is:" << input.GetOutputDesc("y").GetShape().GetDim(2) << endl; \ - auto LAYER##_##BLK##_##OPNUM##_tmp_dims = input.GetOutputDesc("y").GetShape().GetDims(); \ - for (auto LAYER##_##BLK##_##OPNUM##_tmp_it = LAYER##_##BLK##_##OPNUM##_tmp_dims.begin(); \ - LAYER##_##BLK##_##OPNUM##_tmp_it != LAYER##_##BLK##_##OPNUM##_tmp_dims.end(); \ - LAYER##_##BLK##_##OPNUM##_tmp_it++) { \ - cout << *LAYER##_##BLK##_##OPNUM##_tmp_it; \ - } \ - cout << endl; \ - \ - auto LAYER##_##BLK##_##OPNUM = op::Conv2D(string(#LAYER) + string(#BLK) + string(#OPNUM)) \ - .set_input_x(input, "y") \ - .set_input_filter(LAYER##_##BLK##_##OPNUM##_weight) \ - .set_attr_strides({1, 1, stride[0], stride[1]}) \ - .set_attr_pads(pad) \ - .set_attr_data_format("NCHW"); \ - update_op_format(LAYER##_##BLK##_##OPNUM); - -#define GENERATE_CONSTANT(LAYER, BLK, OPNUM, CONSTNAME) \ - Tensor LAYER##_##BLK##_##OPNUM##_##CONSTNAME##_tensor; \ - float *LAYER##_##BLK##_##OPNUM##_##CONSTNAME##_data = new float[LAYER##_##BLK##_##OPNUM##_size]; \ - for (int i = 0; i < (int)LAYER##_##BLK##_##OPNUM##_size; i++) { \ - *(LAYER##_##BLK##_##OPNUM##_##CONSTNAME##_data + i) = 0.01; \ - } \ - LAYER##_##BLK##_##OPNUM##_##CONSTNAME##_tensor.SetData((uint8_t *)LAYER##_##BLK##_##OPNUM##_##CONSTNAME##_data, \ - LAYER##_##BLK##_##OPNUM##_size * sizeof(float)); \ - LAYER##_##BLK##_##OPNUM##_##CONSTNAME##_tensor.SetTensorDesc(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_##CONSTNAME##_constant = \ - op::Constant().set_attr_value(LAYER##_##BLK##_##OPNUM##_##CONSTNAME##_tensor); \ - LAYER##_##BLK##_##OPNUM##_##CONSTNAME##_constant.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - delete[] LAYER##_##BLK##_##OPNUM##_##CONSTNAME##_data; - -#define GENERATE_CONV_VAR_VAR(LAYER, BLK, OPNUM, in_channels, out_channels, h, w, stride, pad, input) \ - TensorDesc LAYER##_##BLK##_##OPNUM##_desc(ge::Shape({out_channels, in_channels, h, w}), FORMAT_NCHW, DT_FLOAT); \ - uint32_t LAYER##_##BLK##_##OPNUM##_size = LAYER##_##BLK##_##OPNUM##_desc.GetShape().GetShapeSize(); \ - auto LAYER##_##BLK##_##OPNUM##_weight = op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_weight"); \ - LAYER##_##BLK##_##OPNUM##_weight.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mom_weight = \ - op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_mom_weight"); \ - LAYER##_##BLK##_##OPNUM##_mom_weight.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - GENERATE_CONSTANT(LAYER, BLK, OPNUM, weight); \ - auto LAYER##_##BLK##_##OPNUM##_weight_assign = op::Assign() \ - .set_input_ref(LAYER##_##BLK##_##OPNUM##_weight) \ - .set_input_value(LAYER##_##BLK##_##OPNUM##_weight_constant); \ - \ - GENERATE_CONSTANT(LAYER, BLK, OPNUM, mom_weight); \ - auto LAYER##_##BLK##_##OPNUM##_mom_weight_assign = \ - op::Assign() \ - .set_input_ref(LAYER##_##BLK##_##OPNUM##_mom_weight) \ - .set_input_value(LAYER##_##BLK##_##OPNUM##_mom_weight_constant); \ - \ - input.push_back(LAYER##_##BLK##_##OPNUM##_weight); \ - input.push_back(LAYER##_##BLK##_##OPNUM##_mom_weight); - -// (int out_channels, Operator& input) -#define GENERATE_BN_VAR(LAYER, BLK, OPNUM, out_channels, input) \ - auto &LAYER##_##BLK##_##OPNUM##_input = input; \ - \ - TensorDesc LAYER##_##BLK##_##OPNUM##_desc(ge::Shape({1, out_channels, 1, 1}), FORMAT_NCHW, DT_FLOAT); \ - auto LAYER##_##BLK##_##OPNUM##_scale = op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_scale"); \ - LAYER##_##BLK##_##OPNUM##_scale.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mom_scale = \ - op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_mom_scale"); \ - LAYER##_##BLK##_##OPNUM##_mom_scale.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_b = op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_b"); \ - LAYER##_##BLK##_##OPNUM##_b.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mom_b = op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_mom_b"); \ - LAYER##_##BLK##_##OPNUM##_mom_b.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mean = op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_mean"); \ - LAYER##_##BLK##_##OPNUM##_mean.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - auto LAYER##_##BLK##_##OPNUM##_variance = \ - op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_variance"); \ - LAYER##_##BLK##_##OPNUM##_variance.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM = op::FusedBatchNorm(string(#LAYER) + string(#BLK) + string(#OPNUM)) \ - .set_input_x(input, "y") \ - .set_input_scale(LAYER##_##BLK##_##OPNUM##_scale) \ - .set_input_b(LAYER##_##BLK##_##OPNUM##_b) \ - .set_input_mean(LAYER##_##BLK##_##OPNUM##_mean) \ - .set_input_variance(LAYER##_##BLK##_##OPNUM##_variance) \ - .set_attr_mode(1) \ - .set_attr_epsilon(1e-5) \ - .set_attr_is_training(true); - -#define GENERATE_BN_VAR_VAR(LAYER, BLK, OPNUM, out_channels, input) \ - TensorDesc LAYER##_##BLK##_##OPNUM##_desc(ge::Shape({1, out_channels, 1, 1}), FORMAT_NCHW, DT_FLOAT); \ - uint32_t LAYER##_##BLK##_##OPNUM##_size = LAYER##_##BLK##_##OPNUM##_desc.GetShape().GetShapeSize(); \ - auto LAYER##_##BLK##_##OPNUM##_scale = op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_scale"); \ - LAYER##_##BLK##_##OPNUM##_scale.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mom_scale = \ - op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_mom_scale"); \ - LAYER##_##BLK##_##OPNUM##_mom_scale.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_b = op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_b"); \ - LAYER##_##BLK##_##OPNUM##_b.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mom_b = op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_mom_b"); \ - LAYER##_##BLK##_##OPNUM##_mom_b.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mean = op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_mean"); \ - LAYER##_##BLK##_##OPNUM##_mean.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - auto LAYER##_##BLK##_##OPNUM##_variance = \ - op::Variable(string(#LAYER) + string(#BLK) + string(#OPNUM) + "_variance"); \ - LAYER##_##BLK##_##OPNUM##_variance.update_output_desc_y(LAYER##_##BLK##_##OPNUM##_desc); \ - \ - GENERATE_CONSTANT(LAYER, BLK, OPNUM, scale); \ - \ - auto LAYER##_##BLK##_##OPNUM##_scale_assign = op::Assign() \ - .set_input_ref(LAYER##_##BLK##_##OPNUM##_scale) \ - .set_input_value(LAYER##_##BLK##_##OPNUM##_scale_constant); \ - GENERATE_CONSTANT(LAYER, BLK, OPNUM, mom_scale); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mom_scale_assign = \ - op::Assign() \ - .set_input_ref(LAYER##_##BLK##_##OPNUM##_mom_scale) \ - .set_input_value(LAYER##_##BLK##_##OPNUM##_mom_scale_constant); \ - \ - GENERATE_CONSTANT(LAYER, BLK, OPNUM, b); \ - \ - auto LAYER##_##BLK##_##OPNUM##_b_assign = \ - op::Assign().set_input_ref(LAYER##_##BLK##_##OPNUM##_b).set_input_value(LAYER##_##BLK##_##OPNUM##_b_constant); \ - \ - GENERATE_CONSTANT(LAYER, BLK, OPNUM, mom_b); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mom_b_assign = op::Assign() \ - .set_input_ref(LAYER##_##BLK##_##OPNUM##_mom_b) \ - .set_input_value(LAYER##_##BLK##_##OPNUM##_mom_b_constant); \ - GENERATE_CONSTANT(LAYER, BLK, OPNUM, mean); \ - \ - auto LAYER##_##BLK##_##OPNUM##_mean_assign = op::Assign() \ - .set_input_ref(LAYER##_##BLK##_##OPNUM##_mean) \ - .set_input_value(LAYER##_##BLK##_##OPNUM##_mean_constant); \ - \ - GENERATE_CONSTANT(LAYER, BLK, OPNUM, variance); \ - \ - auto LAYER##_##BLK##_##OPNUM##_variance_assign = op::Assign() \ - .set_input_ref(LAYER##_##BLK##_##OPNUM##_variance) \ - .set_input_value(LAYER##_##BLK##_##OPNUM##_variance_constant); \ - \ - input.push_back(LAYER##_##BLK##_##OPNUM##_scale); \ - input.push_back(LAYER##_##BLK##_##OPNUM##_mom_scale); \ - input.push_back(LAYER##_##BLK##_##OPNUM##_b); \ - input.push_back(LAYER##_##BLK##_##OPNUM##_mom_b); \ - input.push_back(LAYER##_##BLK##_##OPNUM##_mean); \ - input.push_back(LAYER##_##BLK##_##OPNUM##_variance); - -// (int out_channels, Operator& input) -#define GENERATE_RELU_VAR(LAYER, BLK, OPNUM, input) \ - auto &LAYER##_##BLK##_##OPNUM##_input = input; \ - auto LAYER##_##BLK##_##OPNUM = op::Relu(string(#LAYER) + string(#BLK) + string(#OPNUM)).set_input_x(input, "y"); - -// (int out_channels, Operator& input) -#define GENERATE_MAXPOOL_VAR(LAYER, BLK, OPNUM, input) \ - auto &LAYER##_##BLK##_##OPNUM##_input = input; \ - \ - auto LAYER##_##BLK##_##OPNUM = op::MaxPoolWithArgmax(string(#LAYER) + string(#BLK) + string(#OPNUM)) \ - .set_input_x(input, "y") \ - .set_attr_ksize({1, 3, 3, 1}) \ - .set_attr_padding("SAME") \ - .set_attr_strides({1, 2, 2, 1}); - -// (int out_channels, Operator& input) -#define GENERATE_ADD_VAR(LAYER, BLK, OPNUM, input_x1, input_x2) \ - auto LAYER##_##BLK##_##OPNUM = \ - op::Add(string(#LAYER) + string(#BLK) + string(#OPNUM)).set_input_x1(input_x1, "y").set_input_x2(input_x2, "y"); - -// (int in_channels, int out_channels,vector stride{1,1}, Operator& input) -#define MAKE_RESIDUAL_BLOCK(LAYER, BLK, in_channels, out_channels, stride, input) \ - auto &LAYER##_##BLK##_input = input; \ - auto &LAYER##_##BLK##_stride = stride; \ - int LAYER##_##BLK##_out_chls = out_channels / 4; \ - \ - GENERATE_CONV_VAR(LAYER, BLK, conv1, in_channels, LAYER##_##BLK##_out_chls, 1, 1, stride, pad_0, input); \ - GENERATE_BN_VAR(LAYER, BLK, bn1, LAYER##_##BLK##_out_chls, LAYER##_##BLK##_conv1); \ - GENERATE_RELU_VAR(LAYER, BLK, relu1, LAYER##_##BLK##_bn1); \ - \ - GENERATE_CONV_VAR(LAYER, BLK, conv2, LAYER##_##BLK##_out_chls, LAYER##_##BLK##_out_chls, 3, 3, stride_1, pad_1, \ - LAYER##_##BLK##_relu1); \ - GENERATE_BN_VAR(LAYER, BLK, bn2, LAYER##_##BLK##_out_chls, LAYER##_##BLK##_conv2); \ - GENERATE_RELU_VAR(LAYER, BLK, relu2, LAYER##_##BLK##_bn2); \ - \ - GENERATE_CONV_VAR(LAYER, BLK, conv3, LAYER##_##BLK##_out_chls, out_channels, 1, 1, stride_1, pad_0, \ - LAYER##_##BLK##_relu2); \ - GENERATE_BN_VAR(LAYER, BLK, bn3, out_channels, LAYER##_##BLK##_conv3); \ - \ - GENERATE_CONV_VAR(LAYER, BLK, conv4, in_channels, out_channels, 1, 1, stride, pad_0, input); \ - GENERATE_BN_VAR(LAYER, BLK, bn4, out_channels, LAYER##_##BLK##_conv4); \ - \ - GENERATE_ADD_VAR(LAYER, BLK, add5, LAYER##_##BLK##_bn3, LAYER##_##BLK##_bn4); \ - GENERATE_RELU_VAR(LAYER, BLK, relu5, LAYER##_##BLK##_add5); \ - \ - auto &LAYER##_##BLK##_output = LAYER##_##BLK##_relu5; \ - auto &LAYER##_##BLK##_output_label = "y"; - -#define MAKE_RESIDUAL_BLOCK_VAR(LAYER, BLK, in_channels, out_channels, stride, input) \ - int LAYER##_##BLK##_out_chls = out_channels / 4; \ - GENERATE_CONV_VAR_VAR(LAYER, BLK, conv1, in_channels, LAYER##_##BLK##_out_chls, 1, 1, stride, pad_0, input); \ - GENERATE_BN_VAR_VAR(LAYER, BLK, bn1, LAYER##_##BLK##_out_chls, input); \ - \ - GENERATE_CONV_VAR_VAR(LAYER, BLK, conv2, LAYER##_##BLK##_out_chls, LAYER##_##BLK##_out_chls, 3, 3, stride_1, pad_1, \ - input); \ - GENERATE_BN_VAR_VAR(LAYER, BLK, bn2, LAYER##_##BLK##_out_chls, input); \ - \ - GENERATE_CONV_VAR_VAR(LAYER, BLK, conv3, LAYER##_##BLK##_out_chls, out_channels, 1, 1, stride_1, pad_0, input); \ - GENERATE_BN_VAR_VAR(LAYER, BLK, bn3, out_channels, input); \ - \ - GENERATE_CONV_VAR_VAR(LAYER, BLK, conv4, in_channels, out_channels, 1, 1, stride, pad_0, input); \ - GENERATE_BN_VAR_VAR(LAYER, BLK, bn4, out_channels, input); - -// (int in_channels, int out_channels,vector stride{1,1}, Operator& input) -#define MAKE_NORMAL_BLOCK(LAYER, BLK, in_channels, out_channels, stride, input) \ - auto &LAYER##_##BLK##_input = input; \ - auto &LAYER##_##BLK##_stride = stride; \ - int LAYER##_##BLK##_out_chls = out_channels / 4; \ - \ - GENERATE_CONV_VAR(LAYER, BLK, conv1, in_channels, LAYER##_##BLK##_out_chls, 1, 1, stride, pad_0, input); \ - GENERATE_BN_VAR(LAYER, BLK, bn1, LAYER##_##BLK##_out_chls, LAYER##_##BLK##_conv1); \ - GENERATE_RELU_VAR(LAYER, BLK, relu1, LAYER##_##BLK##_bn1); \ - \ - GENERATE_CONV_VAR(LAYER, BLK, conv2, LAYER##_##BLK##_out_chls, LAYER##_##BLK##_out_chls, 3, 3, stride_1, pad_1, \ - LAYER##_##BLK##_relu1); \ - GENERATE_BN_VAR(LAYER, BLK, bn2, LAYER##_##BLK##_out_chls, LAYER##_##BLK##_conv2); \ - GENERATE_RELU_VAR(LAYER, BLK, relu2, LAYER##_##BLK##_bn2); \ - \ - GENERATE_CONV_VAR(LAYER, BLK, conv3, LAYER##_##BLK##_out_chls, out_channels, 1, 1, stride_1, pad_0, \ - LAYER##_##BLK##_relu2); \ - GENERATE_BN_VAR(LAYER, BLK, bn3, out_channels, LAYER##_##BLK##_conv3); \ - \ - GENERATE_ADD_VAR(LAYER, BLK, add5, LAYER##_##BLK##_bn3, input); \ - GENERATE_RELU_VAR(LAYER, BLK, relu5, LAYER##_##BLK##_add5); \ - \ - auto &LAYER##_##BLK##_output = LAYER##_##BLK##_relu5; \ - auto &LAYER##_##BLK##_output_label = "y"; - -#define MAKE_NORMAL_BLOCK_VAR(LAYER, BLK, in_channels, out_channels, stride, input) \ - int LAYER##_##BLK##_out_chls = out_channels / 4; \ - GENERATE_CONV_VAR_VAR(LAYER, BLK, conv1, in_channels, LAYER##_##BLK##_out_chls, 1, 1, stride, pad_0, input); \ - GENERATE_BN_VAR_VAR(LAYER, BLK, bn1, LAYER##_##BLK##_out_chls, input); \ - \ - GENERATE_CONV_VAR_VAR(LAYER, BLK, conv2, LAYER##_##BLK##_out_chls, LAYER##_##BLK##_out_chls, 3, 3, stride_1, pad_1, \ - input); \ - GENERATE_BN_VAR_VAR(LAYER, BLK, bn2, LAYER##_##BLK##_out_chls, input); \ - \ - GENERATE_CONV_VAR_VAR(LAYER, BLK, conv3, LAYER##_##BLK##_out_chls, out_channels, 1, 1, stride_1, pad_0, input); \ - GENERATE_BN_VAR_VAR(LAYER, BLK, bn3, out_channels, input); - -// (int in_channels, int out_channels,vector stride{1,1}, Operator& input) -#define MAKE_RESIDUAL_LAYER(LAYER, in_channels, out_channels, stride, input) \ - MAKE_RESIDUAL_BLOCK(LAYER, blk1, in_channels, out_channels, stride, input); \ - \ - auto &LAYER##_output = LAYER##_blk1_output; \ - auto &LAYER##_output_label = LAYER##_blk1_output_label; - -#define MAKE_RESIDUAL_LAYER_VAR(LAYER, in_channels, out_channels, stride, input) \ - MAKE_RESIDUAL_BLOCK_VAR(LAYER, blk1, in_channels, out_channels, stride, input); - -// (int in_channels, int out_channels,vector stride{1,1}, Operator& input) -#define MAKE_NORMAL_LAYER(LAYER, in_channels, out_channels, stride, input) \ - MAKE_NORMAL_BLOCK(LAYER, blk1, in_channels, out_channels, stride, input); \ - \ - auto &LAYER##_output = LAYER##_blk1_output; \ - auto &LAYER##_output_label = LAYER##_blk1_output_label; - -#define MAKE_NORMAL_LAYER_VAR(LAYER, in_channels, out_channels, stride, input) \ - MAKE_NORMAL_BLOCK_VAR(LAYER, blk1, in_channels, out_channels, stride, input); - -#define MAKE_RESNET50(input) \ - MAKE_RESIDUAL_LAYER(layer1, 64, 256, stride_1, input) \ - MAKE_NORMAL_LAYER(layer2, 256, 256, stride_1, layer1_output) \ - MAKE_NORMAL_LAYER(layer3, 256, 256, stride_1, layer2_output) \ - MAKE_RESIDUAL_LAYER(layer4, 256, 512, stride_2, layer3_output) \ - MAKE_NORMAL_LAYER(layer5, 512, 512, stride_1, layer4_output) \ - MAKE_NORMAL_LAYER(layer6, 512, 512, stride_1, layer5_output) \ - MAKE_NORMAL_LAYER(layer7, 512, 512, stride_1, layer6_output) \ - MAKE_RESIDUAL_LAYER(layer8, 512, 1024, stride_2, layer7_output) \ - MAKE_NORMAL_LAYER(layer9, 1024, 1024, stride_1, layer8_output) \ - MAKE_NORMAL_LAYER(layer10, 1024, 1024, stride_1, layer9_output) \ - MAKE_NORMAL_LAYER(layer11, 1024, 1024, stride_1, layer10_output) \ - MAKE_NORMAL_LAYER(layer12, 1024, 1024, stride_1, layer11_output) \ - MAKE_NORMAL_LAYER(layer13, 1024, 1024, stride_1, layer12_output) \ - MAKE_RESIDUAL_LAYER(layer14, 1024, 2048, stride_2, layer13_output) \ - MAKE_NORMAL_LAYER(layer15, 2048, 2048, stride_1, layer14_output) \ - MAKE_NORMAL_LAYER(layer16, 2048, 2048, stride_1, layer15_output) \ - \ - auto &resnet50_output = layer16_output; \ - auto &resnet50_output_label = layer16_output_label; - -#define MAKE_RESNET50_VAR(inputs) \ - MAKE_RESIDUAL_LAYER_VAR(layer1, 64, 256, stride_1, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer2, 256, 256, stride_1, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer3, 256, 256, stride_1, inputs) \ - MAKE_RESIDUAL_LAYER_VAR(layer4, 256, 512, stride_2, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer5, 512, 512, stride_1, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer6, 512, 512, stride_1, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer7, 512, 512, stride_1, inputs) \ - MAKE_RESIDUAL_LAYER_VAR(layer8, 512, 1024, stride_2, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer9, 1024, 1024, stride_1, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer10, 1024, 1024, stride_1, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer11, 1024, 1024, stride_1, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer12, 1024, 1024, stride_1, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer13, 1024, 1024, stride_1, inputs) \ - MAKE_RESIDUAL_LAYER_VAR(layer14, 1024, 2048, stride_2, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer15, 2048, 2048, stride_1, inputs) \ - MAKE_NORMAL_LAYER_VAR(layer16, 2048, 2048, stride_1, inputs) \ -//--------------------------------------------------------------------------------------------- - -// (Operator& input) -#define GENERATE_BIASADD_GRAD(LAYER, BLK, OPNUM, input) \ - auto LAYER##_##BLK##_##OPNUM##_grad = \ - op::BiasAddGrad(string(#LAYER) + string(#BLK) + string(#OPNUM) + string("grad")) \ - .set_input_x(input, input.name_out_dx()); - -// (Operator& input) -#define GENERATE_MATMUL_GRAD(LAYER, BLK, OPNUM, input) \ - auto LAYER##_##BLK##_##OPNUM##_grad = \ - op::MatMul(string(#LAYER) + string(#BLK) + string(#OPNUM) + string("grad")).set_input_x1(input); - -// (Operator& input) -#define GENERATE_RESHAPE_GRAD(LAYER, BLK, OPNUM, input) \ - auto LAYER##_##BLK##_##OPNUM##_grad = \ - op::Reshape(string(#LAYER) + string(#BLK) + string(#OPNUM) + string("grad")).set_input_tensor(input); - -// (Operator& input_grad, Operator& input_maxpool) -#define GENERATE_MAXPOOL_GRAD(LAYER, BLK, OPNUM, input_grad, input_maxpool) \ - auto LAYER##_##BLK##_##OPNUM##_grad = \ - op::MaxPoolGradWithArgmax(string(#LAYER) + string(#BLK) + string(#OPNUM) + string("grad")) \ - .set_input_x(LAYER##_##BLK##_##OPNUM##_input, "y") \ - .set_input_grad(input_grad) \ - .set_input_argmax(input_maxpool, input_maxpool.name_out_argmax()) \ - .set_attr_ksize({1, 1, 3, 3}) \ - .set_attr_strides({1, 1, 2, 2}) \ - .set_attr_padding("SAME"); - -// (Operator& input_dy) -#define GENERATE_RELU_GRAD(LAYER, BLK, OPNUM, input_dy, dy_label) \ - auto LAYER##_##BLK##_##OPNUM##_grad = op::ReluGrad(string(#LAYER) + string(#BLK) + string(#OPNUM) + string("grad")) \ - .set_input_gradients(input_dy, dy_label) \ - .set_input_features(LAYER##_##BLK##_##OPNUM, "y"); - -// (Operator& input_dy) -#define GENERATE_BN_GRAD(LAYER, BLK, OPNUM, input_dy) \ - auto LAYER##_##BLK##_##OPNUM##_grad = \ - op::FusedBatchNormGrad(string(#LAYER) + string(#BLK) + string(#OPNUM) + string("grad")) \ - .set_input_dy(input_dy, "backprops") \ - .set_input_x(LAYER##_##BLK##_##OPNUM##_input, "y") \ - .set_input_scale(LAYER##_##BLK##_##OPNUM##_scale) \ - .set_input_save_mean(LAYER##_##BLK##_##OPNUM, "save_mean") \ - .set_input_save_inv_variance(LAYER##_##BLK##_##OPNUM, "save_inv_variance") \ - .set_attr_epsilon(0.0001); \ - \ - auto LAYER##_##BLK##_##OPNUM##_momentum_scale = \ - op::ApplyMomentum() \ - .set_input_accum(LAYER##_##BLK##_##OPNUM##_mom_scale) \ - .set_input_grad(LAYER##_##BLK##_##OPNUM##_grad, LAYER##_##BLK##_##OPNUM##_grad.name_out_bn_scale()) \ - .set_input_lr(label1) \ - .set_input_momentum(label1) \ - .set_input_var(LAYER##_##BLK##_##OPNUM##_scale); \ - \ - auto LAYER##_##BLK##_##OPNUM##_momentum_b = \ - op::ApplyMomentum() \ - .set_input_accum(LAYER##_##BLK##_##OPNUM##_mom_b) \ - .set_input_grad(LAYER##_##BLK##_##OPNUM##_grad, LAYER##_##BLK##_##OPNUM##_grad.name_out_bn_bias()) \ - .set_input_lr(label1) \ - .set_input_momentum(label1) \ - .set_input_var(LAYER##_##BLK##_##OPNUM##_b); - -// (Operator& input) -#define GENERATE_CONV_PROP_FILTER(LAYER, BLK, OPNUM, input_bngrad, stride) \ - auto LAYER##_##BLK##_##OPNUM##_propfilter = \ - op::Conv2DBackpropFilterD(string(#LAYER) + string(#BLK) + string(#OPNUM) + string("_propfilter")) \ - .set_input_x(LAYER##_##BLK##_##OPNUM##_input, "y") \ - .set_attr_filter_size(LAYER##_##BLK##_##OPNUM##_desc.GetShape().GetDims()) \ - .set_input_out_backprop(input_bngrad, input_bngrad.name_out_dx()) \ - .set_attr_strides(stride) \ - .set_attr_pads({1, 1, 1, 1}); \ - \ - update_op_format(LAYER##_##BLK##_##OPNUM##_propfilter); \ - auto LAYER##_##BLK##_##OPNUM##_momentum_weight = op::ApplyMomentum() \ - .set_input_accum(LAYER##_##BLK##_##OPNUM##_mom_weight) \ - .set_input_grad(LAYER##_##BLK##_##OPNUM##_propfilter) \ - .set_input_lr(label1) \ - .set_input_momentum(label1) \ - .set_input_var(LAYER##_##BLK##_##OPNUM##_weight); - -///.set_attr_input_size({input_bngrad.name_out_dx().GetOutputDesc().GetShape().GetDim(0),LAYER##_##BLK##_##OPNUM##_weight.GetOutputDesc().GetShape().GetDim(1), -///input_bngrad.name_out_dx().GetOutputDesc().GetShape().GetDim(2)*stride[2], -///input_bngrad.name_out_dx().GetOutputDesc().GetShape().GetDim(3)*stride[3]}) -#define GENERATE_CONV_PROP_INPUT(LAYER, BLK, OPNUM, input_bngrad, stride) \ - auto LAYER##_##BLK##_##OPNUM##_propinput = \ - op::Conv2DBackpropInputD(string(#LAYER) + string(#BLK) + string(#OPNUM) + string("_propinput")) \ - .set_attr_input_size(LAYER##_##BLK##_##OPNUM##_input.GetOutputDesc("y").GetShape().GetDims()) \ - .set_input_filter(LAYER##_##BLK##_##OPNUM##_weight) \ - .set_input_out_backprop(input_bngrad, input_bngrad.name_out_dx()) \ - .set_attr_strides(stride) \ - .set_attr_pads({1, 1, 1, 1}); \ - cout << string(#LAYER) + string(#BLK) + string(#OPNUM) + "_propinput" \ - << "'s input_x op's shape is:" << input_bngrad.GetOutputDesc("dx").GetShape().GetDim(3) * stride[3] << endl; \ - cout << string(#LAYER) + string(#BLK) + string(#OPNUM) + "_propinput" \ - << "'s input_x op's shape is:" << input_bngrad.GetOutputDesc("dx").GetShape().GetDim(2) * stride[2] << endl; \ - \ - update_op_format(LAYER##_##BLK##_##OPNUM##_propinput); \ - auto &LAYER##_##BLK##_##OPNUM##_propinput_label = "y" - -// (int out_channels, Operator& input) -#define GENERATE_ADD_GRAD(LAYER, BLK, OPNUM, input_x1, input_x1_label, input_x2, input_x2_label) \ - auto LAYER##_##BLK##_##OPNUM##_grad = op::Add(string(#LAYER) + string(#BLK) + string(#OPNUM) + string("grad")) \ - .set_input_x1(input_x1, input_x1_label) \ - .set_input_x2(input_x2, input_x2_label); - -// (Operator& input) -#define MAKE_RESIDUAL_BLOCK_GRAD(LAYER, BLK, input_dy, dy_label) \ - GENERATE_RELU_GRAD(LAYER, BLK, relu5, input_dy, dy_label); \ - \ - GENERATE_BN_GRAD(LAYER, BLK, bn4, LAYER##_##BLK##_relu5_grad); \ - GENERATE_CONV_PROP_FILTER(LAYER, BLK, conv4, LAYER##_##BLK##_bn4_grad, LAYER##_##BLK##_stride); \ - GENERATE_CONV_PROP_INPUT(LAYER, BLK, conv4, LAYER##_##BLK##_bn4_grad, LAYER##_##BLK##_stride); \ - \ - GENERATE_BN_GRAD(LAYER, BLK, bn3, LAYER##_##BLK##_relu5_grad); \ - GENERATE_CONV_PROP_FILTER(LAYER, BLK, conv3, LAYER##_##BLK##_bn3_grad, stride_1); \ - GENERATE_CONV_PROP_INPUT(LAYER, BLK, conv3, LAYER##_##BLK##_bn3_grad, stride_1); \ - \ - GENERATE_RELU_GRAD(LAYER, BLK, relu2, LAYER##_##BLK##_conv3_propinput, "y"); \ - GENERATE_BN_GRAD(LAYER, BLK, bn2, LAYER##_##BLK##_relu2_grad); \ - GENERATE_CONV_PROP_FILTER(LAYER, BLK, conv2, LAYER##_##BLK##_bn2_grad, stride_1); \ - GENERATE_CONV_PROP_INPUT(LAYER, BLK, conv2, LAYER##_##BLK##_bn2_grad, stride_1); \ - \ - GENERATE_RELU_GRAD(LAYER, BLK, relu1, LAYER##_##BLK##_conv2_propinput, "y"); \ - GENERATE_BN_GRAD(LAYER, BLK, bn1, LAYER##_##BLK##_relu1_grad); \ - GENERATE_CONV_PROP_FILTER(LAYER, BLK, conv1, LAYER##_##BLK##_bn1_grad, LAYER##_##BLK##_stride); \ - GENERATE_CONV_PROP_INPUT(LAYER, BLK, conv1, LAYER##_##BLK##_bn1_grad, LAYER##_##BLK##_stride); \ - \ - GENERATE_ADD_GRAD(LAYER, BLK, add5, LAYER##_##BLK##_conv1_propinput, LAYER##_##BLK##_conv1_propinput_label, \ - LAYER##_##BLK##_conv4_propinput, LAYER##_##BLK##_conv4_propinput_label); \ - \ - auto &LAYER##_##BLK##_grad_output = LAYER##_##BLK##_add5_grad; \ - auto &LAYER##_##BLK##_grad_output_label = "y" - -// (Operator& input) -#define MAKE_NORMAL_BLOCK_GRAD(LAYER, BLK, input_dy, dy_label) \ - GENERATE_RELU_GRAD(LAYER, BLK, relu5, input_dy, dy_label); \ - \ - GENERATE_BN_GRAD(LAYER, BLK, bn3, LAYER##_##BLK##_relu5_grad); \ - GENERATE_CONV_PROP_FILTER(LAYER, BLK, conv3, LAYER##_##BLK##_bn3_grad, stride_1); \ - GENERATE_CONV_PROP_INPUT(LAYER, BLK, conv3, LAYER##_##BLK##_bn3_grad, stride_1); \ - \ - GENERATE_RELU_GRAD(LAYER, BLK, relu2, LAYER##_##BLK##_conv3_propinput, "y"); \ - GENERATE_BN_GRAD(LAYER, BLK, bn2, LAYER##_##BLK##_relu2_grad); \ - GENERATE_CONV_PROP_FILTER(LAYER, BLK, conv2, LAYER##_##BLK##_bn2_grad, stride_1); \ - GENERATE_CONV_PROP_INPUT(LAYER, BLK, conv2, LAYER##_##BLK##_bn2_grad, stride_1); \ - \ - GENERATE_RELU_GRAD(LAYER, BLK, relu1, LAYER##_##BLK##_conv2_propinput, "y"); \ - GENERATE_BN_GRAD(LAYER, BLK, bn1, LAYER##_##BLK##_relu1_grad); \ - GENERATE_CONV_PROP_FILTER(LAYER, BLK, conv1, LAYER##_##BLK##_bn1_grad, LAYER##_##BLK##_stride); \ - GENERATE_CONV_PROP_INPUT(LAYER, BLK, conv1, LAYER##_##BLK##_bn1_grad, LAYER##_##BLK##_stride); \ - \ - GENERATE_ADD_GRAD(LAYER, BLK, add5, LAYER##_##BLK##_conv1_propinput, LAYER##_##BLK##_conv1_propinput_label, \ - input_dy, dy_label); \ - \ - auto &LAYER##_##BLK##_grad_output = LAYER##_##BLK##_add5_grad; \ - auto &LAYER##_##BLK##_grad_output_label = "y" - -// (Operator& input_dy) -#define MAKE_RESIDUAL_LAYER_GRAD(LAYER, input_dy, dy_label) \ - MAKE_RESIDUAL_BLOCK_GRAD(LAYER, blk1, input_dy, dy_label); \ - \ - auto &LAYER##_grad_output = LAYER##_blk1_grad_output; \ - auto &LAYER##_grad_output_label = LAYER##_blk1_grad_output_label; - -// (Operator& input_dy) -#define MAKE_NORMAL_LAYER_GRAD(LAYER, input_dy, dy_label) \ - MAKE_NORMAL_BLOCK_GRAD(LAYER, blk1, input_dy, dy_label); \ - \ - auto &LAYER##_grad_output = LAYER##_blk1_grad_output; \ - auto &LAYER##_grad_output_label = LAYER##_blk1_grad_output_label; - -#define MAKE_RESNET50_GRAD(input_dy, dy_label) \ - MAKE_NORMAL_LAYER_GRAD(layer16, input_dy, dy_label) \ - MAKE_NORMAL_LAYER_GRAD(layer15, layer16_grad_output, layer16_grad_output_label) \ - MAKE_RESIDUAL_LAYER_GRAD(layer14, layer15_grad_output, layer15_grad_output_label) \ - MAKE_NORMAL_LAYER_GRAD(layer13, layer14_grad_output, layer14_grad_output_label) \ - MAKE_NORMAL_LAYER_GRAD(layer12, layer13_grad_output, layer13_grad_output_label) \ - MAKE_NORMAL_LAYER_GRAD(layer11, layer12_grad_output, layer12_grad_output_label) \ - MAKE_NORMAL_LAYER_GRAD(layer10, layer11_grad_output, layer11_grad_output_label) \ - MAKE_NORMAL_LAYER_GRAD(layer9, layer10_grad_output, layer10_grad_output_label) \ - MAKE_RESIDUAL_LAYER_GRAD(layer8, layer9_grad_output, layer9_grad_output_label) \ - MAKE_NORMAL_LAYER_GRAD(layer7, layer8_grad_output, layer8_grad_output_label) \ - MAKE_NORMAL_LAYER_GRAD(layer6, layer7_grad_output, layer7_grad_output_label) \ - MAKE_NORMAL_LAYER_GRAD(layer5, layer6_grad_output, layer6_grad_output_label) \ - MAKE_RESIDUAL_LAYER_GRAD(layer4, layer5_grad_output, layer5_grad_output_label) \ - MAKE_NORMAL_LAYER_GRAD(layer3, layer4_grad_output, layer4_grad_output_label) \ - MAKE_NORMAL_LAYER_GRAD(layer2, layer3_grad_output, layer3_grad_output_label) \ - MAKE_RESIDUAL_LAYER_GRAD(layer1, layer2_grad_output, layer2_grad_output_label) \ - \ - auto &resnet50_grad_output = layer1_grad_output; \ - auto &resnet50_grad_output_label = layer1_grad_output_label; - -bool resnet50(Graph &graph) { - auto data = op::Data().set_attr_index(0); - auto data1 = op::Data().set_attr_index(1); - TensorDesc shape_desc(ge::Shape({32, 3, 224, 224}), FORMAT_NCHW, DT_FLOAT); - data.update_output_desc_y(shape_desc); - - TensorDesc desc(ge::Shape({64, 3, 7, 7}), FORMAT_NCHW, DT_FLOAT); - - auto var = op::Variable("conv2d_var"); - var.update_output_desc_y(desc); - var.update_input_desc_x(desc); - - auto varw1 = op::Variable("conv2d_varw1"); - varw1.update_output_desc_y(desc); - - auto conv2d = op::Conv2D("Translate") - .set_input_x(data) - .set_input_filter(var) - .set_attr_strides({1, 1, 2, 2}) - .set_attr_pads({2, 3, 2, 3}) - .set_attr_data_format("NCHW"); - TensorDesc desc_y; - desc_y.SetFormat(FORMAT_NCHW); // shape: 32 64 112 112 - conv2d.update_output_desc_y(desc_y); - - TensorDesc desc1(ge::Shape({1, 64, 1, 1}), FORMAT_NCHW, DT_FLOAT); - auto var1 = op::Variable("bn_var1"); - var1.update_output_desc_y(desc1); - - auto var2 = op::Variable("bn_var2"); - var2.update_output_desc_y(desc1); - - auto var3 = op::Variable("bn_var3"); - var3.update_output_desc_y(desc1); - - auto var4 = op::Variable("bn_var4"); - var4.update_output_desc_y(desc1); - - TensorDesc desc2(ge::Shape({2048, 1001}), FORMAT_NCHW, DT_FLOAT); - - auto var5 = op::Variable("var5"); - var5.update_output_desc_y(desc2); - - auto var6 = op::Variable("var6"); - var6.update_output_desc_y(desc2); - - TensorDesc desclabel(ge::Shape({1, 1001, 1, 1}), FORMAT_NCHW, DT_FLOAT); - - auto label1 = op::Variable("label1"); - label1.update_output_desc_y(desclabel); - - TensorDesc descmatlabel(ge::Shape({1, 1001, 1, 1}), FORMAT_NCHW, DT_FLOAT); - auto matvar = op::Variable("matvar"); - matvar.update_output_desc_y(descmatlabel); - - auto matvar1 = op::Variable("matvar1"); - matvar1.update_output_desc_y(descmatlabel); - - auto bn = op::FusedBatchNorm() - .set_input_x(conv2d, "y") - .set_input_scale(var1) - .set_input_b(var2) - .set_input_mean(var3) - .set_input_variance(var4) - .set_attr_mode(1) - .set_attr_epsilon(1e-5) - .set_attr_is_training(true) - .set_attr_is_training_fusion(true) - .set_attr_moving_average_fraction(994352128); - - auto relu = op::Relu().set_input_x(bn, "y"); - - auto maxpool = op::MaxPoolWithArgmax() - .set_input_x(relu, "y") - .set_attr_ksize({1, 3, 3, 1}) - .set_attr_padding("SAME") - .set_attr_strides({1, 2, 2, 1}); - - MAKE_RESNET50(maxpool); - std::vector inputs{data}; //,var,var1,layer1_blk1_bn1_b,var3,var4}; - std::vector outputs{}; - - graph.SetInputs(inputs).SetOutputs(outputs); - return true; -} - -#define GENERATE_CONSTANT_USE_DESC(OPNUM, desc, val) \ - uint32_t OPNUM##_size = desc.GetShape().GetShapeSize(); \ - Tensor OPNUM##_tensor; \ - OPNUM##_tensor.SetTensorDesc(desc); \ - if (desc.GetDataType() == DT_FLOAT) { \ - float *OPNUM##_data = new float[OPNUM##_size]; \ - for (int i = 0; i < (int)OPNUM##_size; i++) { \ - *(OPNUM##_data + i) = val; \ - } \ - OPNUM##_tensor.SetData((uint8_t *)OPNUM##_data, OPNUM##_size * sizeof(float)); \ - delete[] OPNUM##_data; \ - } \ - if (desc.GetDataType() == DT_INT64) { \ - int64_t *OPNUM##_data = new int64_t[OPNUM##_size]; \ - for (int i = 0; i < (int)OPNUM##_size; i++) { \ - *(OPNUM##_data + i) = val; \ - } \ - OPNUM##_tensor.SetData((uint8_t *)OPNUM##_data, OPNUM##_size * sizeof(int64_t)); \ - delete[] OPNUM##_data; \ - } \ - auto OPNUM##_constant = op::Constant().set_attr_value(OPNUM##_tensor); \ - OPNUM##_constant.update_output_desc_y(desc); - -#define GENERATE_VAR_LAYER(OPNUM, desc, input) \ - auto OPNUM##_weight = op::Variable(string(#OPNUM)); \ - OPNUM##_weight.update_output_desc_y(desc); \ - auto OPNUM##_assign = op::Assign().set_input_ref(OPNUM##_weight).set_input_value(OPNUM##_constant); \ - \ - input.push_back(OPNUM##_weight); - -#define GENERATE_VAR_LAYER_1(OPNUM, desc, var_format, input, name) \ - auto OPNUM##_weight = op::Variable(string(name)); \ - OPNUM##_weight.update_output_desc_y(desc); \ - auto OPNUM##_assign = op::Assign().set_input_ref(OPNUM##_weight).set_input_value(OPNUM##_constant); \ - \ - input.push_back(OPNUM##_weight); - -int BuildInitVarGraph(Graph &graph) { - std::vector inputs{}; - std::vector outputs{}; - - TensorDesc desc(ge::Shape({64, 3, 7, 7}), FORMAT_NCHW, DT_FLOAT); - GENERATE_CONSTANT_USE_DESC(conv2d_var, desc, 0.01); - GENERATE_VAR_LAYER(conv2d_var, desc, inputs); - - GENERATE_CONSTANT_USE_DESC(conv2d_varw1, desc, 0.01); - GENERATE_VAR_LAYER(conv2d_varw1, desc, inputs); - - TensorDesc desc1(ge::Shape({1, 64, 1, 1}), FORMAT_NCHW, DT_FLOAT); - GENERATE_CONSTANT_USE_DESC(bn_var1, desc1, 0.01); - GENERATE_VAR_LAYER(bn_var1, desc1, inputs); - GENERATE_CONSTANT_USE_DESC(bn_var2, desc1, 0.01); - GENERATE_VAR_LAYER(bn_var2, desc1, inputs); - GENERATE_CONSTANT_USE_DESC(bn_var3, desc1, 0.01); - GENERATE_VAR_LAYER(bn_var3, desc1, inputs); - GENERATE_CONSTANT_USE_DESC(bn_var4, desc1, 0.01); - GENERATE_VAR_LAYER(bn_var4, desc1, inputs); - - TensorDesc desc2(ge::Shape({2048, 1001}), FORMAT_NCHW, DT_FLOAT); - GENERATE_CONSTANT_USE_DESC(var5, desc2, 0.01); - GENERATE_VAR_LAYER(var5, desc2, inputs); - GENERATE_CONSTANT_USE_DESC(var6, desc2, 0.01); - GENERATE_VAR_LAYER(var6, desc2, inputs); - - TensorDesc desclabel(ge::Shape({1, 1001, 1, 1}), FORMAT_NCHW, DT_FLOAT); - GENERATE_CONSTANT_USE_DESC(label1, desclabel, 0.1); - GENERATE_VAR_LAYER(label1, desclabel, inputs); - - TensorDesc descmatlabel(ge::Shape({1, 1001, 1, 1}), FORMAT_NCHW, DT_FLOAT); - GENERATE_CONSTANT_USE_DESC(matvar, descmatlabel, 0.01); - GENERATE_VAR_LAYER(matvar, descmatlabel, inputs); - GENERATE_CONSTANT_USE_DESC(matvar1, descmatlabel, 0.01); - GENERATE_VAR_LAYER(matvar1, descmatlabel, inputs); - - MAKE_RESNET50_VAR(inputs); - - TensorDesc ctrl(ge::Shape({1, 1, 1, 1}), FORMAT_NCHW, DT_INT64); - - GENERATE_CONSTANT_USE_DESC(iterations_per_loop, ctrl, 100); - GENERATE_VAR_LAYER_1(iterations_per_loop, ctrl, "4D", inputs, "npu_runconfig/iterations_per_loop"); - GENERATE_CONSTANT_USE_DESC(loop_cond, ctrl, 0); - GENERATE_VAR_LAYER_1(loop_cond, ctrl, "4D", inputs, "npu_runconfig/loop_cond"); - GENERATE_CONSTANT_USE_DESC(one, ctrl, 1); - GENERATE_VAR_LAYER_1(one, ctrl, "4D", inputs, "npu_runconfig/one"); - GENERATE_CONSTANT_USE_DESC(zero, ctrl, 0); - GENERATE_VAR_LAYER_1(zero, ctrl, "4D", inputs, "npu_runconfig/zero"); - - graph.SetInputs(inputs).SetOutputs(outputs); - return 0; -} -int TestBuildGraphTest(Func fun, Graph &graph, vector &inputs, vector &outputs) { - bool graph_ret = fun(graph); - ge::Tensor shapeTensor; - TensorDesc shape_desc(ge::Shape({32, 3, 224, 224}), FORMAT_NCHW, DT_FLOAT); - uint32_t sizeshape = shape_desc.GetShape().GetShapeSize(); - printf("[test] desc size filter shape:%u\n", sizeshape); - shapeTensor.SetTensorDesc(shape_desc); - vector dataValuec; - for (int i = 0; i < sizeshape; i++) { - dataValuec.push_back(1); - } - - shapeTensor.SetData((uint8_t *)dataValuec.data(), 4 * sizeshape); - inputs.push_back(shapeTensor); - - ge::Tensor shapeTensor1; - TensorDesc shape_desc1(ge::Shape({1, 32, 1, 1}), FORMAT_NCHW, DT_FLOAT); - uint32_t sizeshape1 = shape_desc1.GetShape().GetShapeSize(); - printf("[test] desc size filter shape:%u\n", sizeshape1); - shapeTensor1.SetTensorDesc(shape_desc1); - vector dataValuec1; - for (int i = 0; i < sizeshape1; i++) { - dataValuec1.push_back(1); - } - - shapeTensor1.SetData((uint8_t *)dataValuec1.data(), 4 * sizeshape1); - - return 0; -} -int runTrainGraph(Func fun, int loopCount) { - printf("GE BBIT begin...\n"); - std::chrono::system_clock::time_point start = std::chrono::system_clock::now(); - - std::map ge_options = { - {"device_id", "0"}, {"rank_table_file", ""}, {"graphType", "1"}, {"ge.graphRunMode", "2"}}; - - std::map session_options = {{"a", "b"}, {TRAIN_FLAG, "1"}}; - - ge::Status ret; - - // init ge - ret = GEInitialize_api_new("train", "fe,plugin"); - printf("ge::GEInitialize ret:%d\n", ret); - - // init session - ge::Session session(session_options); - - int graphId_initvar = 1; - ge::Graph graph_initvar("initVarGraph"); - bool graph_ret = BuildInitVarGraph(graph_initvar); - - // session addgraph - int graphId = 0; - - // build graph - ge::Graph graph("bigGraph"); - std::vector inputs; - ge::Tensor outputTensor; - std::vector outputs; - graph_ret = TestBuildGraphTest(fun, graph, inputs, outputs); - printf("TestReluGrad ret:%d\n", graph_ret); - - ret = session.AddGraph(graphId_initvar, graph_initvar); - printf("session.AddVarGraph ret:%d\n", ret); - if (ret) return ret; - - ret = session.AddGraph(graphId, graph); - printf("session.AddGraph ret:%d\n", ret); - if (ret) return ret; - - std::vector inputs1; - std::vector outputs1; - ret = session.RunGraph(graphId_initvar, inputs1, outputs1); - - if (ret != SUCCESS) { - return ret; - } - // add loop for test of stabilty: - for (int i = 0; i < loopCount; i++) { - // session rungraph - printf("loopCount:%d\n", loopCount); - ret = session.RunGraph(graphId, inputs, outputs); - printf("session.RunGraph ret:%d\n", ret); - if (ret) return ret; - - // define 99999 as loop forever - if (loopCount == 99999) i = 0; - } - std::chrono::system_clock::time_point end = std::chrono::system_clock::now(); - auto millisecondsduration = std::chrono::duration_cast(end - start); - auto ms = millisecondsduration.count(); - std::stringstream ss; - ss << ms << "ms"; - std::string run_time = ss.str(); - printf("run time is : %s \n", run_time.c_str()); - - return 0; -} - -int main(int argc, char *argv[]) { - // add loop for test of stabilty: - int loopCount = 1; - if (argc >= 2) loopCount = atoi(argv[1]); - - Status ret = SUCCESS; - ret = runTrainGraph(resnet50, loopCount); - if (ret == SUCCESS) { - std::cout << "[train resnet50 success]" << std::endl; - } else { - std::cout << "!!! train resnet50 fail !!!" << std::endl; - } - return ret; -} diff --git a/tests/st/test_ge_st.py b/tests/st/test_ge_st.py deleted file mode 100644 index b5479cfc..00000000 --- a/tests/st/test_ge_st.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -ge st test. -""" -import pytest -import subprocess -import os - -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.env_card -@pytest.mark.component_ge -def test_resnet50_train(): - ge_st_dir=os.environ.get('GE_ST_DIR', - '/home/jenkins/workspace/release_pkg/gate/graphengine_lib') - ge_lib_dir=os.environ.get('GRAPHENGINE_LIB', '/home/jenkins/workspace/release_pkg/gate/graphengine_lib') - - real_pythonpath=os.environ.get('REAL_PYTHONPATH') - pythonpath=os.environ.get('PYTHONPATH') - if real_pythonpath: - if pythonpath: - os.environ['PYTHONPATH']=real_pythonpath+':'+pythonpath - else: - os.environ['PYTHONPATH']=real_pythonpath - print('PYTHONPATH: '+os.environ.get('PYTHONPATH')) - - os.environ['ASCEND_OPP_PATH']='/usr/local/Ascend/opp' - os.environ['ASCEND_ENGINE_PATH']='/usr/local/Ascend/fwkacllib/lib64/plugin/opskernel/libaicpu_engine.so:' \ - '/usr/local/Ascend/fwkacllib/lib64/plugin/opskernel/libfe.so:' \ - '/usr/local/Ascend/fwkacllib/lib64/plugin/opskernel/librts_engine.so:'+ \ - ge_lib_dir + '/libge_local_engine.so' - print('ASCEND_OPP_PATH: '+os.environ.get('ASCEND_OPP_PATH')) - print('ASCEND_ENGINE_PATH: '+os.environ.get('ASCEND_ENGINE_PATH')) - print('LD_LIBRARY_PATH: '+os.environ.get('LD_LIBRARY_PATH')) - - cmd=ge_st_dir + '/st_resnet50_train' - print('cmd: '+cmd) - os.environ['SLOG_PRINT_TO_STDOUT']="1" - ret=subprocess.call([cmd], shell=True) - assert ret==0 - diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index e305d281..1dfd8bbc 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -182,6 +182,7 @@ set(COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/passes/atomic_addr_clean_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/mark_same_addr_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/mark_graph_unknown_status_pass.cc" + "${GE_CODE_DIR}/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/mark_agnostic_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/dimension_compute_pass.cc" "${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" diff --git a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc index 68416409..5b87939f 100644 --- a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc +++ b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc @@ -306,8 +306,8 @@ class UtestLogicalStreamAllocator : public testing::Test { max_parallel_num["aicpu"] = parallel_num; Status status = AssignLogicalStreams({const1, const2, get_next, genmask1, genmask2, domask, subgraph4, subgraph5, - subgraph6, allreduce1, allreduce2, apply1, apply2}, - confs, max_parallel_num); + subgraph6, allreduce1, allreduce2, apply1, apply2}, + confs, max_parallel_num); EXPECT_EQ(status, ge::SUCCESS); EXPECT_EQ(GetStream(get_next), 0); diff --git a/tests/ut/ge/graph/load/new_op_test_utils.h b/tests/ut/ge/graph/load/new_op_test_utils.h index 325a3f1f..4cbc78ac 100644 --- a/tests/ut/ge/graph/load/new_op_test_utils.h +++ b/tests/ut/ge/graph/load/new_op_test_utils.h @@ -154,7 +154,7 @@ class OmeTestOpUtils { if (model->HasAttr(MODEL_ATTR_TASKS)) { ge::Buffer task_buffer; GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetZeroCopyBytes(model, MODEL_ATTR_TASKS, task_buffer), FAILED, - "Get bytes failed."); + "Get bytes failed."); std::shared_ptr task = ge::MakeShared(); GE_CHECK_NOTNULL(task); GE_IF_BOOL_EXEC(task_buffer.GetData() == nullptr, GELOGE(FAILED, "Get data fail"); return FAILED); diff --git a/third_party/fwkacllib/inc/aicpu/aicpu_schedule/aicpu_op_type_list.h b/third_party/fwkacllib/inc/aicpu/aicpu_schedule/aicpu_op_type_list.h new file mode 100644 index 00000000..7e0f94a8 --- /dev/null +++ b/third_party/fwkacllib/inc/aicpu/aicpu_schedule/aicpu_op_type_list.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef AICPU_OP_TYPE_LIST_H_ +#define AICPU_OP_TYPE_LIST_H_ + +enum OpKernelType { + TF_KERNEL, + CPU_KERNEL +}; + +enum ReturnCode { + OP_TYPE_NOT_SUPPORT, + FORMAT_NOT_SUPPORT, + DTYPE_NOT_SUPPORT +}; + +#pragma pack(push, 1) +//One byte alignment +struct SysOpInfo { + uint64_t opLen; + uint64_t opType; + OpKernelType kernelsType; +}; + +struct OpParamInfo { + uint64_t num; + uint64_t dtypeList; + uint64_t formatList; +}; + +struct SysOpCheckInfo { + uint64_t opListNum; + uint64_t offSetLen; + uint64_t sysOpInfoList; + uint64_t opParamInfoList; +}; + +struct SysOpCheckResp { + uint64_t opListNum; + bool isWithoutJson; + uint64_t returnCodeList; + uint64_t sysOpInfoList; + uint64_t opParamInfoList; +}; +#pragma pack(pop) +#endif // AICPU_OP_TYPE_LIST_H_ diff --git a/third_party/fwkacllib/inc/cce/aicpu_engine_struct.h b/third_party/fwkacllib/inc/cce/aicpu_engine_struct.h index a5f43be9..8c0c1847 100644 --- a/third_party/fwkacllib/inc/cce/aicpu_engine_struct.h +++ b/third_party/fwkacllib/inc/cce/aicpu_engine_struct.h @@ -33,18 +33,22 @@ typedef enum { FMK_KERNEL_TYPE_RESERVED } FwkkernelType_t; +#pragma pack(push, 1) typedef struct { uint32_t fwkKernelType; // FwkkernelType_t union { ::aicpu::FWKAdapter::FWKOperateParam fwk_kernel; } fwkKernelBase; -} __attribute__((packed)) STR_FWK_OP_KERNEL; +} STR_FWK_OP_KERNEL; +#pragma pack(pop) +#pragma pack(push, 1) struct SessionInfo { uint64_t sessionId; uint64_t kernelId; bool sessFlag; -} __attribute__((packed)); +}; +#pragma pack(pop) #ifdef __cplusplus } diff --git a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h index 79d94023..50b39d91 100644 --- a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h +++ b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h @@ -70,6 +70,7 @@ enum FWKExtUpdateAddrType { FWK_ADPT_UPDATE_INPUT_OUTPUT }; +#pragma pack(push, 1) // API Parameter Structure struct StrFWKKernel { FWKOperateType opType; @@ -89,31 +90,39 @@ struct StrFWKKernel { uint64_t extInfoLen; // extend info total length uint64_t extInfoAddr; // extend info addr, ExtInfo structure -} __attribute__((packed)); +}; +#pragma pack(pop) typedef StrFWKKernel FWKOperateParam; // Extent info ShapeAndType const uint32_t kMaxShapeDims = 8; +#pragma pack(push, 1) struct ShapeAndType { int32_t type; int64_t dims[kMaxShapeDims]; -} __attribute__((packed)); +}; +#pragma pack(pop) // Extend info structure for extInfoAddr const uint32_t kExtInfoHeadSize = 8; + +#pragma pack(push, 1) struct ExtInfo { int32_t infoType; // extend type uint32_t infoLen; // length for infoMsg char infoMsg[0]; // extend value -} __attribute__((packed)); +}; +#pragma pack(pop) +#pragma pack(push, 1) struct ResultSummary { uint64_t shape_data_ptr; // shape data addr, need convert to void* uint64_t shape_data_size; // num of dims uint64_t raw_data_ptr; // raw data addr, need convert to void* uint64_t raw_data_size; // size of raw data -} __attribute__((packed)); +}; +#pragma pack(pop) } // end namespace FWKAdapter } // namespace aicpu diff --git a/third_party/fwkacllib/inc/hccl/base.h b/third_party/fwkacllib/inc/hccl/base.h index 8194097e..9facd20c 100644 --- a/third_party/fwkacllib/inc/hccl/base.h +++ b/third_party/fwkacllib/inc/hccl/base.h @@ -22,7 +22,8 @@ #ifndef HCCL_BASE_H_ #define HCCL_BASE_H_ - +#include +#include #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -95,6 +96,33 @@ typedef void *rtStream_t; */ typedef void *rtModel_t; +struct HcomOperation { + std::string hcclType; + void *inputPtr; + void *outputPtr; + u64 count; + HcclDataType dataType; + HcclReduceOp opType; + u32 root; + + HcomOperation() + { + inputPtr = nullptr; + outputPtr = nullptr; + count = 0; + dataType = HCCL_DATA_TYPE_RESERVED; + opType = HCCL_REDUCE_RESERVED; + root = 0; + } +}; + +struct HcomRemoteAccessAddrInfo { + u32 remotetRankID; + u64 remoteAddr; // host embedding table address + u64 localAddr; // device HBM address + u64 length; // Memory Length in Bytes +}; + #ifdef __cplusplus } #endif // __cplusplus diff --git a/third_party/fwkacllib/inc/hccl/hcom.h b/third_party/fwkacllib/inc/hccl/hcom.h index de140b4b..e491d43f 100644 --- a/third_party/fwkacllib/inc/hccl/hcom.h +++ b/third_party/fwkacllib/inc/hccl/hcom.h @@ -24,6 +24,8 @@ #include #include +#include +#include #ifdef __cplusplus extern "C" { @@ -41,6 +43,15 @@ extern "C" { HcclResult hcom_get_rank_size(const char *group, u32 *rankSize); /** + * @brief Get the rank number in the group. + * + * @param group A string identifying the group name. + * @param rankSize A pointer identifying the rank number. + * @return HcclResult + */ +HcclResult HcomGetRankSize(const char *group, u32 *rankSize); + +/** * @brief Get the rank number of this rank's server within the group. * * @param group A string identifying the group name. @@ -50,6 +61,15 @@ HcclResult hcom_get_rank_size(const char *group, u32 *rankSize); HcclResult hcom_get_local_rank_size(const char *group, u32 *localRankSize); /** + * @brief Get the rank number of this rank's server within the group. + * + * @param group A string identifying the group name. + * @param localRankSize A pointer identifying the rank number. + * @return HcclResult + */ +HcclResult HcomGetLocalRankSize(const char *group, u32 *localRankSize); + +/** * @brief Get the rank id of this rank. * * @param group A string identifying the group name. @@ -59,6 +79,15 @@ HcclResult hcom_get_local_rank_size(const char *group, u32 *localRankSize); HcclResult hcom_get_rank_id(const char *group, u32 *rankId); /** + * @brief Get the rank id of this rank. + * + * @param group A string identifying the group name. + * @param rankId A pointer identifying the rank id. + * @return HcclResult + */ +HcclResult HcomGetRankId(const char *group, u32 *rankId); + +/** * @brief Get the local rank id of this rank's server within the group. * * @param group A string identifying the group name. @@ -68,6 +97,15 @@ HcclResult hcom_get_rank_id(const char *group, u32 *rankId); HcclResult hcom_get_local_rank_id(const char *group, u32 *localRankId); /** + * @brief Get the local rank id of this rank's server within the group. + * + * @param group A string identifying the group name. + * @param localRankId A pointer identifying the local rank id. + * @return HcclResult + */ +HcclResult HcomGetLocalRankId(const char *group, u32 *localRankId); + +/** * @brief Get the world rank id according to the group rank id. * * @param group A string identifying the group name. @@ -78,6 +116,16 @@ HcclResult hcom_get_local_rank_id(const char *group, u32 *localRankId); HcclResult hcom_get_world_rank_from_group_rank(const char *group, u32 groupRank, u32 *worldRank); /** + * @brief Get the world rank id according to the group rank id. + * + * @param group A string identifying the group name. + * @param groupRank An integer(u32) identifying the group rank id. + * @param worldRank A pointer identifying the world rank id. + * @return HcclResult + */ +HcclResult HcomGetWorldRankFromGroupRank(const char *group, u32 groupRank, u32 *worldRank); + +/** * @brief Get the group rank id according to the world rank id. * * @param worldRank An integer(u32) identifying the world rank id. @@ -88,6 +136,16 @@ HcclResult hcom_get_world_rank_from_group_rank(const char *group, u32 groupRank, HcclResult hcom_get_group_rank_from_world_rank(u32 worldRank, const char *group, u32 *groupRank); /** + * @brief Get the group rank id according to the world rank id. + * + * @param worldRank An integer(u32) identifying the world rank id. + * @param group A string identifying the group name. + * @param groupRank A pointer identifying the group rank id. + * @return HcclResult + */ +HcclResult HcomGetGroupRankFromWorldRank(u32 worldRank, const char *group, u32 *groupRank); + +/** * @brief Create group. * * @param group A string identifying the group name. @@ -98,6 +156,16 @@ HcclResult hcom_get_group_rank_from_world_rank(u32 worldRank, const char *group, HcclResult hcom_create_group(const char *group, u32 rankNum, u32 *rankIds); /** + * @brief Create group. + * + * @param group A string identifying the group name. + * @param rankNum An integer(u32) identifying the number of ranks in the group. + * @param rankIds A list identifying the ranks in the group. + * @return HcclResult + */ +HcclResult HcomCreateGroup(const char *group, u32 rankNum, u32 *rankIds); + +/** * @brief Destroy group * * @param group A string identifying the group name. @@ -106,6 +174,14 @@ HcclResult hcom_create_group(const char *group, u32 rankNum, u32 *rankIds); HcclResult hcom_destroy_group(const char *group); /** + * @brief Destroy group + * + * @param group A string identifying the group name. + * @return HcclResult + */ +HcclResult HcomDestroyGroup(const char *group); + +/** * @brief Set the gradient split strategy with in the group, according to gradient index. * * @param group A string identifying the group name. @@ -116,6 +192,16 @@ HcclResult hcom_destroy_group(const char *group); extern HcclResult hcom_set_split_strategy_by_index(const char *group, u32 segmentNum, const u32 *IdxList); /** + * @brief Set the gradient split strategy with in the group, according to gradient index. + * + * @param group A string identifying the group name. + * @param segmentNum An integer(u32) identifying the segments number of gradients. + * @param IdxList A list identifying the index of end gradient in each segment. + * @return HcclResult + */ +extern HcclResult HcomSetGradFusionByIndex(const char *group, u32 segmentNum, const u32 *IdxList); + +/** * @brief Set the gradient split strategy with in the group, according to gradient data size. * * @param group A string identifying the group name. @@ -126,6 +212,16 @@ extern HcclResult hcom_set_split_strategy_by_index(const char *group, u32 segmen extern HcclResult hcom_set_split_strategy_by_size(const char *group, u32 segmentNum, const float *sizeList); /** + * @brief Set the gradient split strategy with in the group, according to gradient data size. + * + * @param group A string identifying the group name. + * @param segmentNum An integer(u32) identifying the segments number of gradients. + * @param sizeList A list identifying the percent of each segment. + * @return HcclResult + */ +extern HcclResult HcomSetGradFusionBySize(const char *group, u32 segmentNum, const float *sizeList); + +/** * @brief Register memories and init resources for remote access. * * @param addrList memory addresses for remote access. @@ -134,6 +230,25 @@ extern HcclResult hcom_set_split_strategy_by_size(const char *group, u32 segment */ extern HcclResult hcom_remote_access_mem_register(const MemRegisterAddr* addrList, u32 count); +/** + * @brief Register memories and init resources for remote access. + * + * @param addrList memory addresses for remote access. + * @param count number of remote memory addresses. + * @return HcclResult + */ +extern HcclResult HcomRegRemoteAccessMem(const MemRegisterAddr* addrList, u32 count); + +HcclResult HcomExecInitialize(); + +HcclResult HcomExecFinalize(); + +HcclResult HcomExecEnqueueOperation(HcomOperation opInfo, std::function callback); + +HcclResult HcomExecEnqueueRemoteAccess(const std::string& remoteAccessType, + const std::vector& addrInfos, + std::function callback); + #ifdef __cplusplus } #endif // __cplusplus diff --git a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h index c74f95ac..66638bbb 100644 --- a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h +++ b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h @@ -50,7 +50,7 @@ typedef int (*mmFilter)(const mmDirent *entry); typedef int (*mmFilter2)(const mmDirent2 *entry); typedef int (*mmSort)(const mmDirent **a, const mmDirent **b); typedef int (*mmSort2)(const mmDirent2 **a, const mmDirent2 **b); -typedef size_t mmSize_t; +typedef size_t mmSize_t; //lint !e410 !e1051 typedef off_t mmOfft_t; typedef pid_t mmPid_t; typedef long MM_LONG; @@ -215,6 +215,10 @@ typedef struct { #define S_IWRITE S_IWUSR #endif +#define mm_no_argument no_argument +#define mm_required_argument required_argument +#define mm_optional_argument optional_argument + #define M_FILE_RDONLY O_RDONLY #define M_FILE_WRONLY O_WRONLY #define M_FILE_RDWR O_RDWR @@ -412,8 +416,12 @@ MMPA_FUNC_VISIBILITY VOID mmClosePipe(mmPipeHandle pipe[], UINT32 pipeCount); // Poll related interface MMPA_FUNC_VISIBILITY mmCompletionHandle mmCreateCompletionPort(); MMPA_FUNC_VISIBILITY VOID mmCloseCompletionPort(mmCompletionHandle handle); -MMPA_FUNC_VISIBILITY INT32 mmPoll(mmPollfd *fds, INT32 fdCount, INT32 timeout, mmCompletionHandle handleIOCP, - pmmPollData polledData, mmPollBack pollBack); +MMPA_FUNC_VISIBILITY INT32 mmPoll(mmPollfd *fds, + INT32 fdCount, + INT32 timeout, + mmCompletionHandle handleIOCP, + pmmPollData polledData, + mmPollBack pollBack); MMPA_FUNC_VISIBILITY INT32 mmGetErrorCode(); MMPA_FUNC_VISIBILITY CHAR *mmGetErrorFormatMessage(mmErrorMsg errnum, CHAR *buf, mmSize size); MMPA_FUNC_VISIBILITY INT32 mmGetTimeOfDay(mmTimeval *timeVal, mmTimezone *timeZone); diff --git a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h index a5a22b4f..aa58e722 100644 --- a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h +++ b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h @@ -237,6 +237,11 @@ typedef struct { } mmThreadAttr; typedef VOID (*mmPf)(VOID); + +#define mm_no_argument 0 +#define mm_required_argument 1 +#define mm_optional_argument 2 + #define M_FILE_RDONLY GENERIC_READ #define M_FILE_WRONLY GENERIC_WRITE #define M_FILE_RDWR (GENERIC_READ | GENERIC_WRITE) diff --git a/third_party/fwkacllib/inc/runtime/base.h b/third_party/fwkacllib/inc/runtime/base.h index 85f16cc5..aa8263f9 100644 --- a/third_party/fwkacllib/inc/runtime/base.h +++ b/third_party/fwkacllib/inc/runtime/base.h @@ -18,6 +18,7 @@ #define __CCE_RUNTIME_BASE_H__ #include +#include "toolchain/prof_callback.h" #if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) extern "C" { @@ -86,10 +87,20 @@ typedef struct rtExceptionInfo { uint32_t deviceid; } rtExceptionInfo; +typedef struct rtTaskFailInfo { + uint32_t taskid; + uint32_t streamid; + uint32_t tid; + uint32_t deviceid; + uint32_t retcode; +} rtTaskFailInfo; + typedef void (*rtErrorCallback)(rtExceptionType); typedef void (*rtTaskFailCallback)(rtExceptionInfo *exceptionInfo); +typedef void (*rtTaskFailCallbackByModule)(rtTaskFailInfo *exceptionInfo); + typedef void (*rtDeviceStateCallback)(uint32_t devId, bool isOpen); /** @@ -147,6 +158,12 @@ RTS_API rtError_t rtProfilerStop(uint64_t profConfig, int32_t numsDev, uint32_t* RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stream); /** + * @ingroup profiling_base + * @brief ts set profiling reporter callback. + */ +RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback); + +/** * @ingroup dvrt_base * @brief Returns the last error from a runtime call. */ @@ -186,6 +203,16 @@ RTS_API rtError_t rtRegDeviceStateCallback(const char *regName, rtDeviceStateCal /** * @ingroup dvrt_base + * @brief register callback for fail task + * @param [in] uniName unique register name, can't be null + * @param [in] callback fail task callback function + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFailCallbackByModule callback); + +/** + * @ingroup dvrt_base * @brief notify handle. */ typedef void *rtNotify_t; diff --git a/third_party/fwkacllib/inc/runtime/config.h b/third_party/fwkacllib/inc/runtime/config.h index c471f128..c35a1278 100644 --- a/third_party/fwkacllib/inc/runtime/config.h +++ b/third_party/fwkacllib/inc/runtime/config.h @@ -123,14 +123,6 @@ typedef struct tagRtPlatformConfig { uint32_t platformConfig; } rtPlatformConfig /** * @ingroup - * @brief get platform - * @param [in] platForm - * @return platForm - */ -RTS_API rtError_t rtGetPlatformConfig(rtPlatformConfig_t *platForm); - -/** - * @ingroup * @brief get AI core count * @param [in] aiCoreCnt * @return aiCoreCnt @@ -169,13 +161,6 @@ RTS_API rtError_t rtGetAiCoreMemoryRates(rtAiCoreMemoryRates_t *aiCoreMemoryRate */ RTS_API rtError_t rtGetMemoryConfig(rtMemoryConfig_t *memoryConfig); -/** - * @ingroup - * @brief set platform in gen ctx - * @param [in] platForm - * @return RT_ERROR_NONE for ok, errno for failed - */ -RTS_API rtError_t rtSetPlatformType(rtPlatformType_t platformType); /** * @ingroup diff --git a/third_party/fwkacllib/inc/tdt/tsd_client.h b/third_party/fwkacllib/inc/tdt/tsd_client.h index 6066a12e..665c8b82 100644 --- a/third_party/fwkacllib/inc/tdt/tsd_client.h +++ b/third_party/fwkacllib/inc/tdt/tsd_client.h @@ -23,6 +23,7 @@ #include #include "tdt/status.h" #include "tdt/data_common.h" +#include "toolchain/prof_callback.h" #ifdef __cplusplus extern "C" { @@ -37,7 +38,7 @@ extern "C" { * Used for the Framework process to communicate with the TSDDaemon process, * and notify TSD to complete the initialization of other processes * -* @param phyDeviceId [IN] type #unsigned int. Physical device ID +* @param logicDeviceId [IN] type #unsigned int. Logic device ID * @param rankSize [IN] type #unsigned int. The rankSize of the training. * The default value is 1. When rankSize is greater than 1, * HCCP will be pulled to perform set communication related operations. @@ -49,7 +50,7 @@ extern "C" { * @li tsd_client.h: Header file where the interface declaration is located. * @li data_common.h: Header file where 'TDT_StatusT' defined */ -TDT_LIB_EXPORT TDT_StatusT TsdOpen(const uint32_t phyDeviceId, const uint32_t rankSize); +TDT_LIB_EXPORT TDT_StatusT TsdOpen(const uint32_t logicDeviceId, const uint32_t rankSize); /** * @ingroup Close @@ -67,7 +68,7 @@ TDT_LIB_EXPORT TDT_StatusT TsdOpen(const uint32_t phyDeviceId, const uint32_t ra * @li tsd_client.h: Header file where the interface declaration is located. * @li data_common.h: Header file where 'TDT_StatusT' defined */ -TDT_LIB_EXPORT TDT_StatusT TsdClose(const uint32_t phyDeviceId); +TDT_LIB_EXPORT TDT_StatusT TsdClose(const uint32_t logicDeviceId); /** * @ingroup UpdateProfilingMode @@ -85,7 +86,26 @@ TDT_LIB_EXPORT TDT_StatusT TsdClose(const uint32_t phyDeviceId); * @li tsd_client.h: Header file where the interface declaration is located. * @li data_common.h: Header file where 'TDT_StatusT' defined */ -TDT_LIB_EXPORT TDT_StatusT UpdateProfilingMode(const uint32_t phyDeviceId, const uint32_t flag); +TDT_LIB_EXPORT TDT_StatusT UpdateProfilingMode(const uint32_t logicDeviceId, const uint32_t flag); + +/** +* @ingroup TsdSetMsprofReporterCallback +* @brief 用于推理场景下设置aicpu的profilng的callback函数 +* +* @par Function +* 设置offline模式下aicpu_sd进程的profiling的callback函数 +* +* @param callback [IN] type #MsprofReporterCallback. 回调函数 +* @retval TDT_OK Success +* @retval OtherValues Failure +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tsd_client.h: Header file where the interface declaration is located. +* @li data_common.h: Header file where 'TDT_StatusT' defined +* @li prof_callback.h: Headerfile where 'MsprofReporterCallback' defined +*/ +TDT_LIB_EXPORT TDT_StatusT TsdSetMsprofReporterCallback(MsprofReporterCallback callback); /** * @ingroup CreateCmdParameterObj diff --git a/third_party/fwkacllib/inc/toolchain/prof_acl_api.h b/third_party/fwkacllib/inc/toolchain/prof_acl_api.h index 430ed14d..efb37cfb 100644 --- a/third_party/fwkacllib/inc/toolchain/prof_acl_api.h +++ b/third_party/fwkacllib/inc/toolchain/prof_acl_api.h @@ -17,380 +17,76 @@ #ifndef MSPROFILER_API_PROF_ACL_API_H_ #define MSPROFILER_API_PROF_ACL_API_H_ -#define MSVP_MAX_DEV_NUM 64 -#ifndef OS_TYPE -#define OS_TYPE 0 -#endif // OS_TYPE - - -#if (OS_TYPE != LINUX) -#define MSVP_PROF_API __declspec(dllexport) -#else -#define MSVP_PROF_API __attribute__((visibility("default"))) -#endif - // DataTypeConfig -#define PROF_ACL_API 0x0001 -#define PROF_TASK_TIME 0x0002 -#define PROF_AICORE_METRICS 0x0004 -#define PROF_AICPU_TRACE 0x0008 -#define PROF_MODEL_EXECUTE 0x0010 -#define PROF_RUNTIME_API 0x0020 -#define PROF_RUNTIME_TRACE 0x0040 -#define PROF_SCHEDULE_TIMELINE 0x0080 -#define PROF_SCHEDULE_TRACE 0x0100 -#define PROF_AIVECTORCORE_METRICS 0x0200 -#define PROF_SUBTASK_TIME 0x0400 - -#define PROF_TRAINING_TRACE 0x0800 -#define PROF_HCCL_TRACE 0x1000 -#define PROF_DATA_PROCESS 0x2000 -#define PROF_TASK_TRACE 0x3842 +#define PROF_ACL_API 0x00000001 +#define PROF_TASK_TIME 0x00000002 +#define PROF_AICORE_METRICS 0x00000004 +#define PROF_AICPU_TRACE 0x00000008 +#define PROF_MODEL_EXECUTE 0x00000010 +#define PROF_RUNTIME_API 0x00000020 +#define PROF_RUNTIME_TRACE 0x00000040 +#define PROF_SCHEDULE_TIMELINE 0x00000080 +#define PROF_SCHEDULE_TRACE 0x00000100 +#define PROF_AIVECTORCORE_METRICS 0x00000200 +#define PROF_SUBTASK_TIME 0x00000400 + +#define PROF_TRAINING_TRACE 0x00000800 +#define PROF_HCCL_TRACE 0x00001000 + +#define PROF_TASK_TRACE 0x00001852 + +// system profilinig switch +#define PROF_CPU 0x00010000 +#define PROF_HARDWARE_MEMORY 0x00020000 +#define PROF_IO 0x00040000 +#define PROF_INTER_CONNECTION 0x00080000 +#define PROF_DVPP 0x00100000 +#define PROF_SYS_AICORE_SAMPLE 0x00200000 +#define PROF_AIVECTORCORE_SAMPLE 0x00400000 #define PROF_MODEL_LOAD 0x8000000000000000 // DataTypeConfig MASK -#define PROF_ACL_API_MASK 0x0001 -#define PROF_TASK_TIME_MASK 0x0002 -#define PROF_AICORE_METRICS_MASK 0x0004 -#define PROF_AICPU_TRACE_MASK 0x0008 -#define PROF_MODEL_EXECUTE_MASK 0x0010 -#define PROF_RUNTIME_API_MASK 0x0020 -#define PROF_RUNTIME_TRACE_MASK 0x0040 -#define PROF_SCHEDULE_TIMELINE_MASK 0x0080 -#define PROF_SCHEDULE_TRACE_MASK 0x0100 -#define PROF_AIVECTORCORE_METRICS_MASK 0x0200 -#define PROF_SUBTASK_TIME_MASK 0x0400 - -#define PROF_TRAINING_TRACE_MASK 0x0800 -#define PROF_HCCL_TRACE_MASK 0x1000 -#define PROF_DATA_PROCESS_MASK 0x2000 +#define PROF_ACL_API_MASK 0x00000001 +#define PROF_TASK_TIME_MASK 0x00000002 +#define PROF_AICORE_METRICS_MASK 0x00000004 +#define PROF_AICPU_TRACE_MASK 0x00000008 +#define PROF_MODEL_EXECUTE_MASK 0x00000010 +#define PROF_RUNTIME_API_MASK 0x00000020 +#define PROF_RUNTIME_TRACE_MASK 0x00000040 +#define PROF_SCHEDULE_TIMELINE_MASK 0x00000080 +#define PROF_SCHEDULE_TRACE_MASK 0x00000100 +#define PROF_AIVECTORCORE_METRICS_MASK 0x00000200 +#define PROF_SUBTASK_TIME_MASK 0x00000400 + +#define PROF_TRAINING_TRACE_MASK 0x00000800 +#define PROF_HCCL_TRACE_MASK 0x00001000 + +// system profilinig mask +#define PROF_CPU_MASK 0x00010000 +#define PROF_HARDWARE_MEMORY_MASK 0x00020000 +#define PROF_IO_MASK 0x00040000 +#define PROF_INTER_CONNECTION_MASK 0x00080000 +#define PROF_DVPP_MASK 0x00100000 +#define PROF_SYS_AICORE_SAMPLE_MASK 0x00200000 +#define PROF_AIVECTORCORE_SAMPLE_MASK 0x00400000 #define PROF_MODEL_LOAD_MASK 0x8000000000000000 #include -#include - -/** - * @name ProrErrorCode - * @brief error code enum of prof_acl_apis - */ -enum ProfErrorCode { - PROF_ERROR_NONE = 0, // ok - PROF_ERROR_PARAM_INVALID, // param invalid, for example nullptr - PROF_ERROR_REPEAT_INIT, // profiling has already been inited - PROF_ERROR_CONFIG_INVALID, // config invalid, for example invalid json string - PROF_ERROR_DIR_NO_ACCESS, // dir is not accessable - PROF_ERROR_FAILURE, // failed to init or start profiling - PROF_ERROR_NOT_INITED, // profiling has not been inited - PROF_ERROR_DEVICE_INVALID, // device id invalid - PROF_ERROR_UNSUPPORTED, // unsupported data type or ai core metrics - PROF_ERROR_REPEAT_START, // profiilng has already been started - PROF_ERROR_NOT_STARTED, // profiling has not been started - PROF_ERROR_REPEAT_SUBSCRIBE, // same model id has already been subscribed - PROF_ERROR_MODEL_ID_INVALID, // model id does not exist or has not been subscribed - PROF_ERROR_API_CONFLICT, // prof ctrl api mode conflicts with subscribe mode -}; - -/** - * @brief transfer profiling config in acl.json to sample config - * @param aclCfg [IN] profiling json string from acl.json as {"switch":"on", "result_path":"/home",...} - * @param sampleCfg [OUT] json string for GE as {"startCfg":[{"deviceID":"all","jobID":"1234",...}]} - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfAclCfgToSampleCfg(const std::string &aclCfg, std::string &sampleCfg); - -/** - * @name ProfInit - * @brief init profiling - * @param profInitCfg [IN] config of init profiling of json format - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfInit(const std::string &profInitCfg); - -/** - * @name ProfAicoreMetrics - * @brief aicore metrics enum - */ -enum ProfAicoreMetrics { - PROF_AICORE_ARITHMATIC_THROUGHPUT = 0, - PROF_AICORE_PIPELINE = 1, - PROF_AICORE_SYNCHRONIZATION = 2, - PROF_AICORE_MEMORY = 3, - PROF_AICORE_INTERNAL_MEMORY = 4, - PROF_AICORE_STALL = 5, - PROF_AICORE_METRICS_COUNT, - PROF_AICORE_NONE = 0xff, -}; - -/** - * @name ProfConfig - * @brief struct of ProfStart - */ -struct ProfConfig { - uint32_t devNums; // length of device id list - uint32_t devIdList[MSVP_MAX_DEV_NUM]; // physical device id list - ProfAicoreMetrics aicoreMetrics; // aicore metric - uint64_t dataTypeConfig; // data type to start profiling -}; - -/** - * @name ProfStartProfiling - * @brief start profiling - * @param profStartCfg [IN] config to start profiling - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfStartProfiling(const ProfConfig *profStartCfg); - -/** - * @name ProfStopProfiling - * @brief stop profiling - * @param profStopCfg [IN] config to stop profiling - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfStopProfiling(const ProfConfig *profStopCfg); - -/** - * @name ProfFinalize - * @brief finalize profiling task - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfFinalize(); - -/** - * @name ProfGetDataTypeConfig - * @brief get dataTypeConfig started with of one device - * @param deviceId [IN] deviceId to get dataTypeConfig - * @param dataTypeConfig [OUT] result get - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfGetDataTypeConfig(uint32_t deviceId, uint64_t &dataTypeConfig); namespace Msprofiler { namespace Api { /** - * @brief transfer profiling config in acl.json to sample config - * @param aclCfg [IN] profiling json string from acl.json as {"switch":"on", "result_path":"/home",...} - * @param sampleCfg [OUT] json string for GE as {"startCfg":[{"deviceID":"all","jobID":"1234",...}]} - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfAclCfgToSampleCfg(const std::string &aclCfg, std::string &sampleCfg); - -/** - * @name ProfInit - * @brief init profiling - * @param profInitCfg [IN] config of init profiling of json format - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfInit(const std::string &profInitCfg); - -/** - * @name ProfStartProfiling - * @brief start profiling - * @param profStartCfg [IN] config to start profiling - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfStartProfiling(const ProfConfig *profStartCfg); - -/** - * @name ProfStopProfiling - * @brief stop profiling - * @param profStopCfg [IN] config to stop profiling - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfStopProfiling(const ProfConfig *profStopCfg); - -/** - * @name ProfFinalize - * @brief finalize profiling task - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfFinalize(); - -/** - * @name ProfGetDataTypeConfig - * @brief get dataTypeConfig started with of one device - * @param deviceId [IN] deviceId to get dataTypeConfig - * @param dataTypeConfig [OUT] result get - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfGetDataTypeConfig(uint32_t deviceId, uint64_t &dataTypeConfig); - -/** - * @name WorkMode - * @brief profiling api work mode - */ -enum WorkMode { - WORK_MODE_OFF, // profiling not at work - WORK_MODE_API_CTRL, // profiling work on api ctrl mode, (ProfInit) - WORK_MODE_SUBSCRIBE, // profiling work on subscribe mode -}; - -/** - * @name ProfGetApiWorkMode - * @brief get profiling api work mode - * @return WorkMode - */ -MSVP_PROF_API WorkMode ProfGetApiWorkMode(); - -/** - * @name ProfSubscribeConfig - * @brief config of subscribe api - */ -struct ProfSubscribeConfig { - bool timeInfo; // subscribe op time - ProfAicoreMetrics aicoreMetrics; // subscribe ai core metrics - void* fd; // pipe fd -}; - -/** - * @name ProfGetDataTypeConfig - * @brief get DataTypeConfig of subscribe - * @param profSubscribeConfig [IN] config to subscribe data - * @return DataTypeConfig - */ -MSVP_PROF_API uint64_t ProfGetDataTypeConfig(const ProfSubscribeConfig *profSubscribeConfig); - -/** - * @name ProfModelSubscribe - * @brief subscribe data of one model id - * @param modelId [IN] model id to subscribe data - * @param devId [IN] device id of model - * @param profSubscribeConfig [IN] config to subscribe data - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfModelSubscribe(uint32_t modelId, uint32_t devId, - const ProfSubscribeConfig *profSubscribeConfig); - -/** - * @name ProfIsModelSubscribed - * @brief check if a model id is subscribed - * @param modeiId [IN] modei id to check - * @return true: subscribed, false: not - */ -MSVP_PROF_API bool ProfIsModelSubscribed(uint32_t modelId); - -/** - * @name ProfModelUnSubscribe - * @brief unsubscribe a model id - * @param modeiId [IN] modei id to unsubscribe - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfModelUnSubscribe(uint32_t modelId); - -/** - * @name ProfGetOpDescSize - * @brief get profiling data struct size - * @param opDescSize [OUT] bytes of profiling subscribe data struct - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfGetOpDescSize(uint32_t *opDescSize); - -/** - * @name ProfGetOpNum - * @brief get how many op data there are in data - * @param data [IN] data read from pipe - * @param len [IN] data length - * @param opNum [OUT] number of op in data - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfGetOpNum(const void *data, uint32_t len, uint32_t *opNum); - -/** - * @name ProfGetModelId - * @brief get model id of specific part of data - * @param data [IN] data read from pipe - * @param len [IN] data length - * @param index [IN] index of part(op) - * @return model id - */ -MSVP_PROF_API uint32_t ProfGetModelId(const void *data, uint32_t len, uint32_t index); - -/** - * @name ProfGetOpType - * @brief get op type of specific part of data - * @param data [IN] data read from pipe - * @param len [IN] data length - * @param opType [OUT] op type buffer - * @param opTypeLen [IN] buffer size of param opType - * @param index [IN] index of part(op) - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfGetOpType(const void *data, uint32_t len, char *opType, uint32_t opTypeLen, uint32_t index); - -/** - * @name ProfGetOpName - * @brief get op name of specific part of data - * @param data [IN] data read from pipe - * @param len [IN] data length - * @param opType [OUT] op name buffer - * @param opTypeLen [IN] buffer size of param opName - * @param index [IN] index of part(op) - * @return ProfErrorCode - */ -MSVP_PROF_API int32_t ProfGetOpName(const void *data, uint32_t len, char *opName, uint32_t opNameLen, uint32_t index); - -/** - * @name ProfGetOpStart - * @brief get op start timestamp of specific part of data - * @param data [IN] data read from pipe - * @param len [IN] data length - * @param index [IN] index of part(op) - * @return op start timestamp (us) - */ -MSVP_PROF_API uint64_t ProfGetOpStart(const void *data, uint32_t len, uint32_t index); - -/** - * @name ProfGetOpEnd - * @brief get op end timestamp of specific part of data - * @param data [IN] data read from pipe - * @param len [IN] data length - * @param index [IN] index of part(op) - * @return op end timestamp (us) - */ -MSVP_PROF_API uint64_t ProfGetOpEnd(const void *data, uint32_t len, uint32_t index); - -/** - * @name ProfGetOpDuration - * @brief get op duration of specific part of data - * @param data [IN] data read from pipe - * @param len [IN] data length - * @param index [IN] index of part(op) - * @return op duration (us) - */ -MSVP_PROF_API uint64_t ProfGetOpDuration(const void *data, uint32_t len, uint32_t index); - -/** * @name ProfGetOpExecutionTime * @brief get op execution time of specific part of data - * @param data [IN] data read from pipe - * @param len [IN] data length + * @param data [IN] data read from pipe + * @param len [IN] data length * @param index [IN] index of part(op) * @return op execution time (us) */ -MSVP_PROF_API uint64_t ProfGetOpExecutionTime(const void *data, uint32_t len, uint32_t index); - -/** - * @name ProfGetOpCubeOps - * @brief get op cube fops of specific part of data - * @param data [IN] data read from pipe - * @param len [IN] data length - * @param index [IN] index of part(op) - * @return op cube fops - */ -MSVP_PROF_API uint64_t ProfGetOpCubeOps(const void *data, uint32_t len, uint32_t index); - -/** - * @name ProfGetOpVectorOps - * @brief get op vector fops of specific part of data - * @param data [IN] data read from pipe - * @param len [IN] data length - * @param index [IN] index of part(op) - * @return op vector fops - */ -MSVP_PROF_API uint64_t ProfGetOpVectorOps(const void *data, uint32_t len, uint32_t index); - -} // namespace Api -} // namespace Msprofiler +uint64_t ProfGetOpExecutionTime(const void *data, uint32_t len, uint32_t index); +} +} #endif // MSPROFILER_API_PROF_ACL_API_H_ diff --git a/third_party/fwkacllib/inc/toolchain/prof_callback.h b/third_party/fwkacllib/inc/toolchain/prof_callback.h new file mode 100644 index 00000000..1299ae59 --- /dev/null +++ b/third_party/fwkacllib/inc/toolchain/prof_callback.h @@ -0,0 +1,132 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MSPROFILER_PROF_CALLBACK_H_ +#define MSPROFILER_PROF_CALLBACK_H_ + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + + +#include "stddef.h" +#include "stdint.h" + +/** + * @name MsprofErrorCode + * @brief error code + */ +enum MsprofErrorCode { + MSPROF_ERROR_NONE = 0, + MSPROF_ERROR_MEM_NOT_ENOUGH, + MSPROF_ERROR_GET_ENV, + MSPROF_ERROR_CONFIG_INVALID, + MSPROF_ERROR_ACL_JSON_OFF, + MSPROF_ERROR, +}; + +#define MSPROF_ENGINE_MAX_TAG_LEN (31) + +/** + * @name ReporterData + * @brief struct of data to report + */ +struct ReporterData { + char tag[MSPROF_ENGINE_MAX_TAG_LEN + 1]; // the sub-type of the module, data with different tag will be writen + int deviceId; // the index of device + size_t dataLen; // the length of send data + unsigned char *data; // the data content +}; + +/** + * @name MsprofReporterModuleId + * @brief module id of data to report + */ +enum MsprofReporterModuleId { + MSPROF_MODULE_DATA_PREPROCESS = 0, // DATA_PREPROCESS + MSPROF_MODULE_HCCL, // HCCL + MSPROF_MODULE_ACL, // AclModule + MSPROF_MODULE_FRAMEWORK, // Framework + MSPROF_MODULE_RUNTIME // runtime +}; + +/** + * @name MsprofReporterCallbackType + * @brief reporter callback request type + */ +enum MsprofReporterCallbackType { + MSPROF_REPORTER_REPORT = 0, // report data + MSPROF_REPORTER_INIT, // init reporter + MSPROF_REPORTER_UNINIT, // uninit reporter +}; + +/** + * @name MsprofReporterCallback + * @brief callback to start reporter/stop reporter/report date + * @param moduleId [IN] enum MsprofReporterModuleId + * @param type [IN] enum MsprofReporterCallbackType + * @param data [IN] callback data (nullptr on INTI/UNINIT) + * @param len [IN] callback data size (0 on INIT/UNINIT) + * @return enum MsprofErrorCode + */ +typedef int32_t (*MsprofReporterCallback)(uint32_t moduleId, uint32_t type, void *data, uint32_t len); + + +#define MSPROF_OPTIONS_DEF_LEN_MAX (2048) + +/** + * @name MsprofGeOptions + * @brief struct of MSPROF_CTRL_INIT_GE_OPTIONS + */ +struct MsprofGeOptions { + char jobId[MSPROF_OPTIONS_DEF_LEN_MAX]; + char options[MSPROF_OPTIONS_DEF_LEN_MAX]; +}; + +/** + * @name MsprofCtrlCallbackType + * @brief ctrl callback request type + */ +enum MsprofCtrlCallbackType { + MSPROF_CTRL_INIT_ACL_ENV = 0, // start profiling with acl env + MSPROF_CTRL_INIT_ACL_JSON, // start profiling with acl.json + MSPROF_CTRL_INIT_GE_OPTIONS, // start profiling with ge env and options + MSPROF_CTRL_FINALIZE // stop profiling +}; + +/** + * @name MsprofCtrlCallback + * @brief callback to start/stop profiling + * @param type [IN] enum MsprofCtrlCallbackType + * @param data [IN] callback data + * @param len [IN] callback data size + * @return enum MsprofErrorCode + */ +typedef int32_t (*MsprofCtrlCallback)(uint32_t type, void *data, uint32_t len); + +/** + * @name MsprofSetDeviceCallback + * @brief callback to notify set/reset device + * @param devId [IN] device id + * @param isOpenDevice [IN] true: set device, false: reset device + */ +typedef void (*MsprofSetDeviceCallback)(uint32_t devId, bool isOpenDevice); + +#ifdef __cplusplus +} +#endif + +#endif // MSPROFILER_PROF_CALLBACK_H_ diff --git a/third_party/fwkacllib/inc/toolchain/prof_reporter.h b/third_party/fwkacllib/inc/toolchain/prof_reporter.h index 949011d3..ff91351b 100644 --- a/third_party/fwkacllib/inc/toolchain/prof_reporter.h +++ b/third_party/fwkacllib/inc/toolchain/prof_reporter.h @@ -26,6 +26,8 @@ #define MSVP_PROF_API __attribute__((visibility("default"))) #endif +#include "prof_callback.h" + /** * @file prof_reporter.h * @defgroup reporter the reporter group @@ -33,20 +35,6 @@ */ namespace Msprof { namespace Engine { -/// the max tag length -#define MSPROF_ENGINE_MAX_TAG_LEN (31) -/** - * @ingroup reporter - * @brief struct ReporterData - * the sturct of the data send to libmsprof - */ -struct ReporterData { - char tag[MSPROF_ENGINE_MAX_TAG_LEN + 1]; ///< the sub-type of the module, data with different tag will be writen - int deviceId; ///< the physical id of device - size_t dataLen; ///< the length of send data - unsigned char *data; ///< the data content -}; - /** * @ingroup reporter * @brief class Reporter diff --git a/third_party/fwkacllib/inc/toolchain/slog.h b/third_party/fwkacllib/inc/toolchain/slog.h index 5faca0ae..7c4f7be2 100644 --- a/third_party/fwkacllib/inc/toolchain/slog.h +++ b/third_party/fwkacllib/inc/toolchain/slog.h @@ -394,4 +394,117 @@ void DlogWithKVInner(int moduleId, int level, KeyValue *pstKVArray, int kvNum, c } #endif // LOG_CPP #endif // __cplusplus + +#ifdef LOG_CPP +#ifdef __cplusplus +extern "C" { +#endif +/** + * @ingroup slog + * @brief DlogGetlevelForC: get module loglevel and enableEvent + * + * @param [in]moduleId: moudule id(see slog.h, eg: CCE), others: invalid + * @param [out]enableEvent: 1: enable; 0: disable + * @return: module level(0: debug, 1: info, 2: warning, 3: error, 4: null output) + */ +DLL_EXPORT int DlogGetlevelForC(int moduleId, int *enableEvent); + +/** + * @ingroup slog + * @brief DlogSetlevelForC: set module loglevel and enableEvent + * + * @param [in]moduleId: moudule id(see slog.h, eg: CCE), -1: all modules, others: invalid + * @param [in]level: log level(0: debug, 1: info, 2: warning, 3: error, 4: null output) + * @param [in]enableEvent: 1: enable; 0: disable, others:invalid + * @return: 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int DlogSetlevelForC(int moduleId, int level, int enableEvent); + +/** + * @ingroup slog + * @brief CheckLogLevelForC: check module level enable or not + * users no need to call it because all dlog interface(include inner interface) has already called + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]logLevel: eg: DLOG_EVENT/DLOG_ERROR/DLOG_WARN/DLOG_INFO/DLOG_DEBUG + * @return: 1:enable, 0:disable + */ +DLL_EXPORT int CheckLogLevelForC(int moduleId, int logLevel); + +/** + * @ingroup slog + * @brief DlogSetAttrForC: set log attr, default pid is 0, default device id is 0, default process type is APPLICATION + * @param [in]logAttr: attr info, include pid(must be larger than 0), process type and device id(chip ID) + * @return: 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int DlogSetAttrForC(LogAttr logAttr); + +/** + * @ingroup slog + * @brief DlogForC: print log, need caller to specify level + * call CheckLogLevelForC in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 5: trace, 6: oplog, 16: event) + * @param [in]fmt: log content + */ +#define DlogForC(moduleId, level, fmt, ...) \ + do { \ + if(CheckLogLevelForC(moduleId, level) == 1) { \ + DlogInnerForC(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (0) + +/** + * @ingroup slog + * @brief DlogSubForC: print log, need caller to specify level and submodule + * call CheckLogLevelForC in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]submodule: eg: engine + * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 5: trace, 6: oplog, 16: event) + * @param [in]fmt: log content + */ +#define DlogSubForC(moduleId, submodule, level, fmt, ...) \ + do { \ + if(CheckLogLevelForC(moduleId, level) == 1) { \ + DlogInnerForC(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ + } \ + } while (0) + +/** + * @ingroup slog + * @brief DlogWithKVForC: print log, need caller to specify level and other paramters + * call CheckLogLevelForC in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 5: trace, 6: oplog, 16: event) + * @param [in]pstKVArray: key-value array + * @param [in]kvNum: key-value element num in array + * @param [in]fmt: log content + */ +#define DlogWithKVForC(moduleId, level, pstKVArray, kvNum, fmt, ...) \ + do { \ + if(CheckLogLevelForC(moduleId, level) == 1) { \ + DlogWithKVInnerForC(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (0) + +/** + * @ingroup slog + * @brief DlogFlushForC: flush log buffer to file + */ +DLL_EXPORT void DlogFlushForC(void); + +/** + * @ingroup slog + * @brief Internal log interface, other modules are not allowed to call this interface + */ +void DlogInnerForC(int moduleId, int level, const char *fmt, ...); +void DlogWithKVInnerForC(int moduleId, int level, KeyValue *pstKVArray, int kvNum, const char *fmt, ...); + +#ifdef __cplusplus +} +#endif +#endif // LOG_CPP #endif // D_SYSLOG_H_