diff --git a/CMakeLists.txt b/CMakeLists.txt index e67b5074..bed5b995 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,7 +39,7 @@ set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR} option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) -if (ENABLE_OPEN_SRC) +if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST) set(HI_PYTHON python3) include(cmake/external_libs/protobuf_shared.cmake) @@ -51,118 +51,132 @@ if (ENABLE_OPEN_SRC) include(cmake/external_libs/json.cmake) include(cmake/FindModule.cmake) include(cmake/intf_pub_linux.cmake) - - # if D_LINK_PATH is set in environment variables, search libraries in given path - if(DEFINED ENV{D_LINK_PATH}) - # D_LINK_PATH is set - set(GE_LIB_PATH $ENV{D_LINK_PATH}) - set(GE_SYS_ARCH "") - if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") - # x86 ubuntu - set(GE_SYS_ARCH "x86_64") - elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") - # arm euleros - set(GE_SYS_ARCH "aarch64") + add_subdirectory(tests) +else () + if (ENABLE_OPEN_SRC) + set(HI_PYTHON python3) + + include(cmake/external_libs/protobuf_shared.cmake) + include(cmake/external_libs/protobuf_static.cmake) + include(cmake/external_libs/protoc.cmake) + include(cmake/external_libs/gflags.cmake) + include(cmake/external_libs/gtest.cmake) + include(cmake/external_libs/securec.cmake) + include(cmake/external_libs/json.cmake) + include(cmake/FindModule.cmake) + include(cmake/intf_pub_linux.cmake) + + # if D_LINK_PATH is set in environment variables, search libraries in given path + if(DEFINED ENV{D_LINK_PATH}) + # D_LINK_PATH is set + set(GE_LIB_PATH $ENV{D_LINK_PATH}) + set(GE_SYS_ARCH "") + if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64") + # x86 ubuntu + set(GE_SYS_ARCH "x86_64") + elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64") + # arm euleros + set(GE_SYS_ARCH "aarch64") + else() + message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") + endif() + set(GE_LIB_PATH 
${GE_LIB_PATH}/${GE_SYS_ARCH}) + set(STATIC_ACL_LIB ${GE_LIB_PATH}) + find_module(slog libalog.so ${GE_LIB_PATH}) + find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) + find_module(msprofiler_ext libmsprofiler.a ${GE_LIB_PATH}) + find_module(hccl libhccl.so ${GE_LIB_PATH}) + find_module(adump_server libadump_server.a ${GE_LIB_PATH}) + find_module(runtime libruntime.so ${GE_LIB_PATH}) + find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH}) + find_module(resource libresource.so ${GE_LIB_PATH}) + find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) + find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH}) + #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) else() - message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") - endif() - set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) - set(STATIC_ACL_LIB ${GE_LIB_PATH}) - find_module(slog libalog.so ${GE_LIB_PATH}) - find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) - find_module(msprofiler_ext libmsprofiler.a ${GE_LIB_PATH}) - find_module(hccl libhccl.so ${GE_LIB_PATH}) - find_module(adump_server libadump_server.a ${GE_LIB_PATH}) - find_module(runtime libruntime.so ${GE_LIB_PATH}) - find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH}) - find_module(resource libresource.so ${GE_LIB_PATH}) - find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) - find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH}) - #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) - elseif(ENABLE_GE_COV OR ENABLE_GE_UT) - add_subdirectory(tests) - else() - find_module(slog libalog.so ${ASCEND_ATC_DIR}) - find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) - if(PLATFORM STREQUAL "train") + find_module(slog libalog.so ${ASCEND_ATC_DIR}) + find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) + if(PLATFORM STREQUAL "train") + find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) + find_module(runtime libruntime.so 
${ASCEND_RUNTIME_DIR}) + find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) + find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) + if(PRODUCT STREQUAL "flr3") + message(FATAL_ERROR "This platform is not supported in train mode, build terminated") + endif() + elseif(PLATFORM STREQUAL "inference") + find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) + find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) + find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) + find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) + if(PRODUCT STREQUAL "flr3") + elseif(PRODUCT STREQUAL "flr1") + find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) + elseif(PRODUCT STREQUAL "flr2") + # flr2 ascend_hal_stub limsprof ? + else() + find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) + endif() + elseif(PLATFORM STREQUAL "all") find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) - find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) - find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) - if(PRODUCT STREQUAL "flr3") - message(FATAL_ERROR "This platform is not supported in train mode, build terminated") - endif() - elseif(PLATFORM STREQUAL "inference") - find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) - find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) - find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) + find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) + find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) + find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) - if(PRODUCT STREQUAL "flr3") - elseif(PRODUCT STREQUAL "flr1") - find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) - elseif(PRODUCT 
STREQUAL "flr2") - # flr2 ascend_hal_stub limsprof ? else() - find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) + message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") endif() - elseif(PLATFORM STREQUAL "all") - find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) - find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) - find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) - find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) - find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) - find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) - else() - message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") endif() - endif() - - set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) - set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser) - set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..) - - add_subdirectory(metadef) - add_subdirectory(parser) - #add_subdirectory(metadef/graph) - #add_subdirectory(metadef/register) -elseif (ENABLE_D OR ENABLE_ACL) - # compiling with MindSpore - include(cmake/external_libs/protobuf_static.cmake) - include(cmake/external_libs/protoc.cmake) - include(cmake/external_libs/securec.cmake) - include(cmake/external_libs/json.cmake) - include(cmake/FindModule.cmake) - include(cmake/intf_pub_linux.cmake) - - # common libraries - find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) - find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) - if (ENABLE_D) - # training - find_module(runtime libruntime.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) - find_module(register libregister.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) - endif () - - set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) - add_subdirectory(metadef) -elseif(ENABLE_MS_TESTCASES) - include(cmake/external_libs/protobuf_static.cmake) - 
include(cmake/external_libs/protoc.cmake) - include(cmake/external_libs/securec.cmake) - include(cmake/FindModule.cmake) - include(cmake/intf_pub_linux.cmake) - - # common libraries - find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) - find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) + set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) + set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser) + set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..) + + add_subdirectory(metadef) + add_subdirectory(parser) + #add_subdirectory(metadef/graph) + #add_subdirectory(metadef/register) + elseif (ENABLE_D OR ENABLE_ACL) + # compiling with MindSpore + include(cmake/external_libs/protobuf_static.cmake) + include(cmake/external_libs/protoc.cmake) + include(cmake/external_libs/securec.cmake) + include(cmake/external_libs/json.cmake) + include(cmake/FindModule.cmake) + include(cmake/intf_pub_linux.cmake) + + # common libraries + find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) + find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) + + if (ENABLE_D) + # training + find_module(runtime libruntime.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) + find_module(register libregister.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) + endif () + + set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) + add_subdirectory(metadef) + elseif(ENABLE_MS_TESTCASES) + include(cmake/external_libs/protobuf_static.cmake) + include(cmake/external_libs/protoc.cmake) + include(cmake/external_libs/securec.cmake) + include(cmake/FindModule.cmake) + include(cmake/intf_pub_linux.cmake) + + # common libraries + find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) + find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH}) + + set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) + add_subdirectory(metadef) + else() + set(METADEF_DIR 
${CMAKE_CURRENT_LIST_DIR}/../metadef) + set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) + set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..) + endif() - set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) - add_subdirectory(metadef) -else() - set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef) - set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) - set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..) -endif() + add_subdirectory(ge) -add_subdirectory(ge) +endif () \ No newline at end of file diff --git a/build.sh b/build.sh index 5931bbaa..5fde7b8d 100755 --- a/build.sh +++ b/build.sh @@ -177,6 +177,9 @@ build_graphengine() elif [ "X$ENABLE_GE_UT" = "Xon" ] then TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest" + elif [ "X$ENABLE_GE_ST" = "Xon" ] + then + TARGET="graph_engine_test" elif [ "X$MINDSPORE_MODE" = "Xon" ] then TARGET="ge_common graph" @@ -234,6 +237,27 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then genhtml coverage.info fi +if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then + #prepare engine & opskernel so + mkdir -p ${OUTPUT_PATH}/plugin/nnengine + mkdir -p ${OUTPUT_PATH}/plugin/nnengine/ge_config + mkdir -p ${OUTPUT_PATH}/plugin/opskernel + cp ${BUILD_PATH}/tests/st/libnnengine.so ${OUTPUT_PATH}/plugin/nnengine + cp ${BUILD_PATH}/engine_conf.json ${OUTPUT_PATH}/plugin/nnengine/ge_config + cp ${BUILD_PATH}/tests/st/libhost_cpu_engine.so ${OUTPUT_PATH}/plugin/opskernel + #prepare st execution bin + cp ${BUILD_PATH}/tests/st/testcase/graph_engine_test ${OUTPUT_PATH} + #execute st testcase + RUN_TEST_CASE=${OUTPUT_PATH}/graph_engine_test && ${RUN_TEST_CASE} + if [[ "$?" -ne 0 ]]; then + echo "!!! ST FAILED, PLEASE CHECK YOUR CHANGES !!!" 
+ echo -e "\033[31m${RUN_TEST_CASE}\033[0m" + exit 1; + fi + # remove plugin + rm -rf ${OUTPUT_PATH}/plugin +fi + # generate output package in tar form, including ut/st libraries/executables generate_package() { @@ -337,7 +361,7 @@ generate_package() fi } -if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then +if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$ENABLE_GE_ST" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then generate_package elif [ "X$MINDSPORE_MODE" = "Xon" ] then diff --git a/ge/ge_runtime/runtime_model.cc b/ge/ge_runtime/runtime_model.cc index 71147a4b..efaad251 100644 --- a/ge/ge_runtime/runtime_model.cc +++ b/ge/ge_runtime/runtime_model.cc @@ -25,6 +25,7 @@ #include "framework/common/op/op_parser_util.h" #include "graph/types.h" #include "task/task_factory.h" +#include "ge/common/math/math_util.h" namespace ge { namespace model_runner { @@ -500,7 +501,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr &davinci_model } uint64_t *buff = reinterpret_cast(const_cast(constant->weight_data.data())); uint32_t head_len = kOffsetUnit * kStringHeadElems; - if (ge::CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) { + if (CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) { GELOGE(FAILED, "Shape size is invalid"); return false; } diff --git a/ge/ge_runtime/task/aicpu_task.cc b/ge/ge_runtime/task/aicpu_task.cc index cc07365d..ddd6557b 100644 --- a/ge/ge_runtime/task/aicpu_task.cc +++ b/ge/ge_runtime/task/aicpu_task.cc @@ -83,7 +83,7 @@ bool AicpuTask::Distribute() { return false; } - GELOGI("ext info size:", ext_size); + GELOGI("ext info size: %u", ext_size); aicpu_param_head.extInfoLength = ext_size; aicpu_param_head.extInfoAddr = reinterpret_cast(ext_info_); } diff --git a/ge/ge_runtime/task/hccl_task.cc b/ge/ge_runtime/task/hccl_task.cc index 06165053..2169f96a 100644 --- a/ge/ge_runtime/task/hccl_task.cc +++ b/ge/ge_runtime/task/hccl_task.cc @@ -130,7 +130,7 @@ bool HcclTask::SetSecondaryStream() { Status ret; 
std::lock_guard lock(model_stream_mapping_mutex_); if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) { - GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %ld.", rt_model_handle_, master_stream_id); + GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %u.", rt_model_handle_, master_stream_id); ret = CreateStream(hccl_secondary_stream_num, master_stream_id); if (!ret) { GELOGE(RT_FAILED, "Create hccl stream failed."); @@ -189,7 +189,7 @@ bool HcclTask::SetSecondaryStream() { } GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num); } else { - GELOGI("Need to create secondary stream for %s with new mainstream %ld.", task_info_->op_name().c_str(), + GELOGI("Need to create secondary stream for %s with new mainstream %u.", task_info_->op_name().c_str(), master_stream_id); ret = CreateStream(hccl_secondary_stream_num, master_stream_id); if (!ret) { diff --git a/ge/ge_runtime/task/label_goto_task.cc b/ge/ge_runtime/task/label_goto_task.cc index ad93a98f..4302bff3 100644 --- a/ge/ge_runtime/task/label_goto_task.cc +++ b/ge/ge_runtime/task/label_goto_task.cc @@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() { return false; } - rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); + rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret); return false; diff --git a/ge/ge_runtime/task/label_switch_task.cc b/ge/ge_runtime/task/label_switch_task.cc index a3c2d41a..8c795da9 100644 --- a/ge/ge_runtime/task/label_switch_task.cc +++ b/ge/ge_runtime/task/label_switch_task.cc @@ -69,7 +69,7 @@ bool LabelSwitchTask::Distribute() { return false; } label_list[i] = all_label_resource_[label_index]; - GELOGI("Case %zu: label id %zu.", i, label_index); + GELOGI("Case %zu: label id %zu.", i, 
(size_t)label_index); } uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); diff --git a/tests/st/CMakeLists.txt b/tests/st/CMakeLists.txt new file mode 100644 index 00000000..3b294681 --- /dev/null +++ b/tests/st/CMakeLists.txt @@ -0,0 +1,6 @@ +project(graphengine_st) + +include(cmake/graphengine.cmake) + +add_subdirectory(framework) +add_subdirectory(testcase) \ No newline at end of file diff --git a/tests/st/cmake/graphengine.cmake b/tests/st/cmake/graphengine.cmake new file mode 100644 index 00000000..aab49b70 --- /dev/null +++ b/tests/st/cmake/graphengine.cmake @@ -0,0 +1,249 @@ +# ---- Test coverage ---- + +if (ENABLE_GE_COV) + set(COVERAGE_COMPILER_FLAGS "-g --coverage -fprofile-arcs -fPIC -O0 -ftest-coverage") + set(CMAKE_CXX_FLAGS "${COVERAGE_COMPILER_FLAGS}") +endif() + +# ---- Proto generate ---- + +file(GLOB_RECURSE PROTO_FILES CONFIGURE_DEPENDS "${GE_CODE_DIR}/metadef/proto/*.proto") + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_FILES}) + +# ---- File glob by group ---- + +file(GLOB_RECURSE METADEF_SRCS CONFIGURE_DEPENDS + "${GE_CODE_DIR}/metadef/graph/*.cc" + "${GE_CODE_DIR}/metadef/register/*.cc" + "${GE_CODE_DIR}/metadef/register/*.cpp" + "${GE_CODE_DIR}/metadef/ops/*.cc" + "${GE_CODE_DIR}/metadef/third_party/transformer/src/*.cc" +) +file(GLOB_RECURSE METADEF_REGISTER_SRCS CONFIGURE_DEPENDS + "${GE_CODE_DIR}/metadef/register/*.cc" + "${GE_CODE_DIR}/metadef/register/*.cpp" +) + +file(GLOB_RECURSE PARSER_SRCS CONFIGURE_DEPENDS + "${GE_CODE_DIR}/parser/parser/common/*.cc" +) + +file(GLOB_RECURSE LOCAL_ENGINE_SRC CONFIGURE_DEPENDS + "${GE_CODE_DIR}/ge/ge_local_engine/*.cc" +) + +file(GLOB_RECURSE HOST_ENGINE_SRC CONFIGURE_DEPENDS + "${GE_CODE_DIR}/ge/host_cpu_engine/*.cc" +) + +file(GLOB_RECURSE NN_ENGINE_SRC CONFIGURE_DEPENDS + "${GE_CODE_DIR}/ge/plugin/*.cc" +) + +file(GLOB_RECURSE OFFLINE_SRC CONFIGURE_DEPENDS + "${GE_CODE_DIR}/ge/offline/*.cc" +) + +file(GLOB_RECURSE GE_SRCS CONFIGURE_DEPENDS + 
"${GE_CODE_DIR}/ge/*.cc" +) + +list(REMOVE_ITEM GE_SRCS ${LOCAL_ENGINE_SRC} ${HOST_ENGINE_SRC} ${NN_ENGINE_SRC} ${OFFLINE_SRC}) + +list(APPEND INCLUDE_DIRECTORIES + "${CMAKE_CURRENT_SOURCE_DIR}" + "${GE_CODE_DIR}" + "${GE_CODE_DIR}/inc" + "${GE_CODE_DIR}/metadef/inc" + "${GE_CODE_DIR}/ge" + "${GE_CODE_DIR}/ge/inc" + "${GE_CODE_DIR}/ge/ir_build" + "${GE_CODE_DIR}/metadef" + "${GE_CODE_DIR}/metadef/graph" + "${GE_CODE_DIR}/inc/external" + "${GE_CODE_DIR}/inc/framework/common" + "${GE_CODE_DIR}/metadef/inc/external" + "${GE_CODE_DIR}/metadef/inc/external/graph" + "${GE_CODE_DIR}/metadef/inc/graph" + "${GE_CODE_DIR}/inc/framework" + "${GE_CODE_DIR}/metadef/inc/common" + "${GE_CODE_DIR}/metadef/third_party" + "${GE_CODE_DIR}/metadef/third_party/transformer/inc" + "${GE_CODE_DIR}/parser" + "${GE_CODE_DIR}/parser/parser" + "${GE_CODE_DIR}/third_party/fwkacllib/inc" + "${GE_CODE_DIR}/third_party/fwkacllib/inc/cce" + "${GE_CODE_DIR}/third_party/fwkacllib/inc/ops" + "${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain" + "${GE_CODE_DIR}/tests/ut/ge" + "${GE_CODE_DIR}/tests/ut/common" + "${CMAKE_BINARY_DIR}" + "${CMAKE_BINARY_DIR}/proto/ge" + "${CMAKE_BINARY_DIR}/proto/ge/proto" +) + +list(APPEND STUB_LIBS + c_sec + slog_stub + cce_ge_stub + runtime_stub + profiler_stub + #mmpa_stub + hccl_stub + error_manager_stub + ascend_protobuf + json +) + +# ---- Target : Local engine ---- + +add_library(localengine STATIC ${LOCAL_ENGINE_SRC} ${METADEF_REGISTER_SRCS}) + +target_include_directories(localengine + PUBLIC + "${INCLUDE_DIRECTORIES}" + "${GE_CODE_DIR}/ge/ge_local_engine" +) + +target_compile_definitions(localengine PRIVATE + google=ascend_private +) + +target_compile_options(localengine PRIVATE + -g --coverage -fprofile-arcs -ftest-coverage + -Werror=format +) + +target_link_libraries(localengine PUBLIC + $ ${STUB_LIBS} -lrt -ldl -lpthread -lgcov +) + +set_target_properties(localengine PROPERTIES CXX_STANDARD 11) + +# ---- Target : metadef graph ---- + 
+add_library(metadef_graph STATIC ${METADEF_SRCS} ${PROTO_SRCS} ${PROTO_HDRS}) + +target_include_directories(metadef_graph + PUBLIC + "${INCLUDE_DIRECTORIES}" + ) + +target_compile_definitions(metadef_graph PRIVATE + google=ascend_private + FMK_SUPPORT_DUMP + ) + +target_compile_options(metadef_graph PRIVATE + -g --coverage -fprofile-arcs -ftest-coverage + -Werror=format + ) + +target_link_libraries(metadef_graph PUBLIC + $ ${STUB_LIBS} -lrt -ldl -lpthread -lgcov + ) + +set_target_properties(metadef_graph PROPERTIES CXX_STANDARD 11) +# ---- Target : Host engine ---- + +add_library(host_cpu_engine SHARED ${HOST_ENGINE_SRC} ${PROTO_HDRS}) + +target_include_directories(host_cpu_engine + PUBLIC + "${INCLUDE_DIRECTORIES}" + "${GE_CODE_DIR}/ge/host_cpu_engine" +) + +target_compile_definitions(host_cpu_engine PRIVATE + google=ascend_private + FMK_SUPPORT_DUMP +) + +target_compile_options(host_cpu_engine PRIVATE + -g --coverage -fprofile-arcs -ftest-coverage + -Werror=format +) + +target_link_libraries(host_cpu_engine PUBLIC + $ ${STUB_LIBS} metadef_graph -lmmpa -L/home/hugo/Code/ge/graphengine/build/tests/depends/mmpa -lrt -ldl -lpthread -lgcov +) + +set_target_properties(host_cpu_engine PROPERTIES CXX_STANDARD 11) + +# ---- Target : engine plugin---- +# + +add_library(nnengine SHARED ${NN_ENGINE_SRC}) + +target_include_directories(nnengine + PUBLIC + "${INCLUDE_DIRECTORIES}" + "${GE_CODE_DIR}/ge/plugin/engine" +) + +target_compile_definitions(nnengine PRIVATE + google=ascend_private +) + +target_compile_options(nnengine PRIVATE + -g --coverage -fprofile-arcs -ftest-coverage + -Werror=format +) + +target_link_libraries(nnengine PUBLIC + $ ${STUB_LIBS} -lrt -ldl -lpthread -lgcov +) + +set_target_properties(nnengine PROPERTIES CXX_STANDARD 11) + +# Targe: engine_conf +add_custom_target( + engine_conf.json ALL + DEPENDS ${CMAKE_BINARY_DIR}/engine_conf.json +) +add_custom_command( + OUTPUT ${CMAKE_BINARY_DIR}/engine_conf.json + COMMAND cp 
${GE_CODE_DIR}/ge/engine_manager/engine_conf.json ${CMAKE_BINARY_DIR}/ +) +# Targe: optimizer priority +add_custom_target( + optimizer_priority.pbtxt ALL + DEPENDS ${CMAKE_BINARY_DIR}/optimizer_priority.pbtxt +) +add_custom_command( + OUTPUT ${CMAKE_BINARY_DIR}/optimizer_priority.pbtxt + COMMAND cp ${GE_CODE_DIR}/ge/opskernel_manager/optimizer_priority.pbtxt ${CMAKE_BINARY_DIR}/ +) + +# ---- Target : Graph engine ---- + +add_library(graphengine STATIC ${PARSER_SRCS} ${GE_SRCS} ${PROTO_HDRS}) + +target_include_directories(graphengine + PUBLIC + "${INCLUDE_DIRECTORIES}" + "${GE_CODE_DIR}/ge/host_cpu_engine" +) + +target_compile_definitions(graphengine PRIVATE + google=ascend_private + FMK_SUPPORT_DUMP +) + +target_compile_options(graphengine PRIVATE + -g --coverage -fprofile-arcs -ftest-coverage + -Werror=format +) + +target_link_libraries(graphengine PUBLIC + $ ${STUB_LIBS} + metadef_graph + localengine + host_cpu_engine + nnengine + mmpa -L${GE_CODE_DIR}/third_party/prebuild/x86_64 -lrt -ldl -lpthread -lgcov +) + +set_target_properties(graphengine PROPERTIES CXX_STANDARD 11) +add_dependencies(graphengine engine_conf.json optimizer_priority.pbtxt) diff --git a/tests/st/framework/CMakeLists.txt b/tests/st/framework/CMakeLists.txt new file mode 100644 index 00000000..b6a8752a --- /dev/null +++ b/tests/st/framework/CMakeLists.txt @@ -0,0 +1,16 @@ +file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++") +#todo +file(GLOB_RECURSE stub_engine CONFIGURE_DEPENDS + "stub_engine/*.cc" + ) +list(REMOVE_ITEM SOURCES ${stub_engine}) + +add_library(framework STATIC ${SOURCES}) + +target_include_directories(framework + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_target_properties(framework PROPERTIES CXX_STANDARD 11) + +target_link_libraries(framework PUBLIC graphengine) diff --git a/tests/st/framework/framework.cc b/tests/st/framework/framework.cc new file mode 100644 index 00000000..5ec905cd --- /dev/null +++ b/tests/st/framework/framework.cc @@ -0,0 
+1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include +#include "framework.h" + +namespace ge { +namespace st { +Status Framework::SetUp() { + +} +} // namespace st +} // namespace ge \ No newline at end of file diff --git a/tests/st/framework/framework.h b/tests/st/framework/framework.h new file mode 100644 index 00000000..d693e4c8 --- /dev/null +++ b/tests/st/framework/framework.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +#ifndef GRAPHENGINE_LLT_ST_FRAMEWORK_H_ +#define GRAPHENGINE_LLT_ST_FRAMEWORK_H_ +#include +#include "common/ge_inner_error_codes.h" + +namespace ge { +namespace st { +class Framework { + public: + explicit Framework() {}; + Status SetUp(); + Status TearDown(); +}; +} // namespace st +}// namespace ge + +#endif // GRAPHENGINE_LLT_ST_FRAMEWORK_H_ diff --git a/tests/st/framework/stub_engine/CMakeLists.txt b/tests/st/framework/stub_engine/CMakeLists.txt new file mode 100644 index 00000000..74890484 --- /dev/null +++ b/tests/st/framework/stub_engine/CMakeLists.txt @@ -0,0 +1,259 @@ +set(PROTO_LIST + "${METADEF_DIR}/proto/task.proto" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST}) + +set(SRC_LIST + "engine/stub_engine.cc" + "ops_kernel_store/host_cpu_ops_kernel_info.cc" + "ops_kernel_store/op/op_factory.cc" + "ops_kernel_store/op/host_op.cc" +) + +set(CPU_OPS_KERNEL_LIST + "ops_kernel_store/host_cpu_ops_kernel_builder.cc" +) + +############ libfe.so ############ +add_library(fe SHARED ${SRC_LIST} ${PROTO_HDRS}) + +target_compile_options(fe PRIVATE + -Werror + -fno-common + -fvisibility=hidden +) + +target_compile_definitions(fe PRIVATE + google=ascend_private + FUNC_VISIBILITY +) + +target_include_directories(fe PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc +) + +target_link_options(fe PRIVATE + -Wl,-Bsymbolic +) + +target_link_libraries(fe PRIVATE + $ + -Wl,--no-as-needed + ascend_protobuf + c_sec + graph + slog + -Wl,--as-needed +) + +############ atcstub/libfe.so ############ +add_library(atc_fe SHARED ${SRC_LIST} 
${PROTO_ATCSTUB_HDRS}) + +target_compile_options(atc_fe PRIVATE + -Werror + -fno-common + -fvisibility=hidden +) + +target_compile_definitions(atc_fe PRIVATE + google=ascend_private + FUNC_VISIBILITY +) + +target_include_directories(atc_fe PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge_atcstub + #### yellow zone #### + ${GE_CODE_DIR}/../inc + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc +) + +target_link_options(atc_fe PRIVATE + -Wl,-Bsymbolic +) + +target_link_libraries(atc_fe PRIVATE + $ + -Wl,--no-as-needed + ascend_protobuf + c_sec + graph + slog + -Wl,--as-needed +) + +set_target_properties(atc_fe PROPERTIES + OUTPUT_NAME fe + LIBRARY_OUTPUT_DIRECTORY atclib +) + +############ libhost_cpu_opskernel_builder.so ############ +add_library(host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST}) + +target_compile_options(host_cpu_opskernel_builder PRIVATE + -Werror + -fno-common + -fvisibility=hidden +) + +target_compile_definitions(host_cpu_opskernel_builder PRIVATE + google=ascend_private + FUNC_VISIBILITY +) + +target_include_directories(host_cpu_opskernel_builder PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc +) + +target_link_options(host_cpu_opskernel_builder PRIVATE + -Wl,-Bsymbolic +) + +target_link_libraries(host_cpu_opskernel_builder PRIVATE + $ + -Wl,--no-as-needed + ascend_protobuf + c_sec + slog + graph + register + -Wl,--as-needed +) + +############ 
atclib/libhost_cpu_opskernel_builder.so ############ +add_library(atc_host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST}) + +target_compile_options(atc_host_cpu_opskernel_builder PRIVATE + -Werror + -fno-common + -fvisibility=hidden +) + +target_compile_definitions(atc_host_cpu_opskernel_builder PRIVATE + google=ascend_private + FUNC_VISIBILITY +) + +target_include_directories(atc_host_cpu_opskernel_builder PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + #### blue zone #### + ${GE_CODE_DIR}/third_party/fwkacllib/inc +) + +target_link_options(atc_host_cpu_opskernel_builder PRIVATE + -Wl,-Bsymbolic +) + +target_link_libraries(atc_host_cpu_opskernel_builder PRIVATE + $ + -Wl,--no-as-needed + ascend_protobuf + c_sec + slog + graph + register + -Wl,--as-needed +) + +set_target_properties(atc_host_cpu_opskernel_builder PROPERTIES + OUTPUT_NAME host_cpu_opskernel_builder + LIBRARY_OUTPUT_DIRECTORY atclib +) + +############ libhost_cpu_opskernel_builder.a ############ +add_library(host_cpu_opskernel_builder_static STATIC ${CPU_OPS_KERNEL_LIST}) + +target_compile_options(host_cpu_opskernel_builder_static PRIVATE + -Werror + -fno-common + -fvisibility=hidden +) + +target_compile_definitions(host_cpu_opskernel_builder_static PRIVATE + google=ascend_private + LOG_CPP + FUNC_VISIBILITY +) + +target_include_directories(host_cpu_opskernel_builder_static PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${GE_CODE_DIR}/ge + ${GE_CODE_DIR}/inc + ${GE_CODE_DIR}/inc/external + ${GE_CODE_DIR}/inc/framework + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${METADEF_DIR}/inc/external/graph + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge + #### yellow zone #### + ${GE_CODE_DIR}/../inc + #### blue zone #### + 
${GE_CODE_DIR}/third_party/fwkacllib/inc +) + +target_link_libraries(host_cpu_opskernel_builder_static PRIVATE + $ + ascend_protobuf + c_sec +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS fe host_cpu_opskernel_builder OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) + +install(TARGETS atc_fe atc_host_cpu_opskernel_builder OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/atclib +) diff --git a/tests/st/framework/stub_engine/common/constant/constant.h b/tests/st/framework/stub_engine/common/constant/constant.h new file mode 100644 index 00000000..86ecf2f6 --- /dev/null +++ b/tests/st/framework/stub_engine/common/constant/constant.h @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ +#define GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ + +#include + +namespace ge { +namespace host_cpu { +// engine name +const char kHostCpuEngineName[] = "DNN_VM_HOST_CPU"; +const char kHostCpuOpKernelLibName[] = "DNN_VM_HOST_CPU_OP_STORE"; +} // namespace host_cpu +} // namespace ge + +#endif // GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ diff --git a/tests/st/framework/stub_engine/engine/stub_engine.cc b/tests/st/framework/stub_engine/engine/stub_engine.cc new file mode 100644 index 00000000..8ee5c25f --- /dev/null +++ b/tests/st/framework/stub_engine/engine/stub_engine.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "stub_engine.h" +#include +#include +#include +#include +#include "framework/common/debug/ge_log.h" +#include "common/ge/ge_util.h" +#include "host_cpu_engine/common/constant/constant.h" +#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h" + +namespace ge { +StubEngine &StubEngine::Instance() { + static StubEngine instance; + return instance; +} + +Status StubEngine::Initialize(const std::map &options) { + if (ops_kernel_store_ == nullptr) { + ops_kernel_store_ = MakeShared(); + if (ops_kernel_store_ == nullptr) { + GELOGE(FAILED, "[Create][StubEngine]Make HostCpuOpsKernelInfoStore failed."); + REPORT_INNER_ERROR("E19999", "StubEngine::Initialize failed for new StubEngine."); + return FAILED; + } + } + return SUCCESS; +} + +void StubEngine::GetOpsKernelInfoStores(std::map &ops_kernel_map) { + if (ops_kernel_store_ != nullptr) { + // add buildin opsKernel to opsKernelInfoMap + ops_kernel_map[host_cpu::kHostCpuOpKernelLibName] = ops_kernel_store_; + } +} + +void StubEngine::GetGraphOptimizerObjs(std::map &) { + // no optimizer for host cpu engine +} + +Status StubEngine::Finalize() { + ops_kernel_store_ = nullptr; + return SUCCESS; +} +} // namespace ge + +ge::Status Initialize(const std::map &options) { + return ge::StubEngine::Instance().Initialize(options); +} + +void GetOpsKernelInfoStores(std::map &ops_kernel_map) { + ge::StubEngine::Instance().GetOpsKernelInfoStores(ops_kernel_map); +} + +void GetGraphOptimizerObjs(std::map &graph_optimizers) { + ge::StubEngine::Instance().GetGraphOptimizerObjs(graph_optimizers); +} + +ge::Status Finalize() { return ge::StubEngine::Instance().Finalize(); } diff --git a/tests/st/framework/stub_engine/engine/stub_engine.h b/tests/st/framework/stub_engine/engine/stub_engine.h new file mode 100644 index 00000000..65f23333 --- /dev/null +++ b/tests/st/framework/stub_engine/engine/stub_engine.h @@ -0,0 +1,126 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_ +#define GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY _declspec(dllexport) +#else +#define GE_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#include +#include +#include "common/opskernel/ops_kernel_info_store.h" +#include "common/optimizer/graph_optimizer.h" + +using OpsKernelInfoStorePtr = std::shared_ptr; +using GraphOptimizerPtr = std::shared_ptr; + +namespace ge { +namespace { + std::vector extern_engine_name_vec = {"fe","rts_engine","aicpu_ascend_engine","aicpu_tf_engine",}; +} // namespace +/** + * host cpu engine. + * Used for the ops which executes on host. + */ +class GE_FUNC_VISIBILITY StubEngine { + public: + /** + * get HostCpuEngine instance. + * @return HostCpuEngine instance. + */ + static StubEngine &Instance(); + + virtual ~StubEngine() = default; + + /** + * When Ge start, GE will invoke this interface + * @return The status whether initialize successfully + */ + Status Initialize(const std::map &options); + + /** + * After the initialize, GE will invoke this interface + * to get the Ops kernel Store. 
+ * @param ops_kernel_map The host cpu's ops kernel info + */ + void GetOpsKernelInfoStores(std::map &ops_kernel_map); + + /** + * After the initialize, GE will invoke this interface + * to get the Graph Optimizer. + * @param graph_optimizers The host cpu's Graph Optimizer objs + */ + void GetGraphOptimizerObjs(std::map &graph_optimizers); + + /** + * When the graph finished, GE will invoke this interface + * @return The status whether initialize successfully + */ + Status Finalize(); + + StubEngine(const StubEngine &StubEngine) = delete; + StubEngine(const StubEngine &&StubEngine) = delete; + StubEngine &operator=(const StubEngine &StubEngine) = delete; + StubEngine &operator=(StubEngine &&StubEngine) = delete; + + private: + StubEngine() = default; + OpsKernelInfoStorePtr ops_kernel_store_ = nullptr; +}; +} // namespace ge + +extern "C" { + +/** + * When Ge start, GE will invoke this interface + * @return The status whether initialize successfully + */ +GE_FUNC_VISIBILITY ge::Status Initialize(const std::map &options); + +/** + * After the initialize, GE will invoke this interface to get the Ops kernel Store + * @param ops_kernel_map The host cpu's ops kernel info + */ +GE_FUNC_VISIBILITY void GetOpsKernelInfoStores(std::map &ops_kernel_map); + +/** + * After the initialize, GE will invoke this interface to get the Graph Optimizer + * @param graph_optimizers The host cpu's Graph Optimizer objs + */ +GE_FUNC_VISIBILITY void GetGraphOptimizerObjs(std::map &graph_optimizers); + +/** + * When the graph finished, GE will invoke this interface + * @return The status whether initialize successfully + */ +GE_FUNC_VISIBILITY ge::Status Finalize(); +} + +#endif // GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_ diff --git a/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc b/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc new file mode 100644 index 00000000..8246a85d --- /dev/null +++ 
b/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "host_cpu_ops_kernel_builder.h" +#include +#include "common/ge_inner_error_codes.h" +#include "ge/ge_api_types.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" +#include +#include "framework/common/debug/ge_log.h" +#include "host_cpu_engine/common/constant/constant.h" +#include "register/ops_kernel_builder_registry.h" + +namespace ge { +namespace host_cpu { +REGISTER_OPS_KERNEL_BUILDER(kHostCpuOpKernelLibName, HostCpuOpsKernelBuilder); + +Status HostCpuOpsKernelBuilder::Finalize() { + return SUCCESS; +} +Status HostCpuOpsKernelBuilder::Initialize(const map &options) { + return SUCCESS; +} + +Status HostCpuOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) { + OpDescPtr op_desc = ge_node.GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(FAILED, "[Get][OpDesc]CalcOpRunningParam failed, as op desc is null"); + REPORT_INNER_ERROR("E19999", "GetOpDesc failed."); + return FAILED; + } + + bool is_shape_unknown = false; + if (NodeUtils::GetNodeUnknownShapeStatus(ge_node, is_shape_unknown) == GRAPH_SUCCESS) { + if (is_shape_unknown) { + GELOGI("op:%s is unknown shape, does not need to calc output size.", ge_node.GetName().c_str()); + return SUCCESS; + } + } + + const string name = 
ge_node.GetName(); + const string type = ge_node.GetType(); + GELOGD("Calc op[%s:%s] running param, output size=%zu.", name.c_str(), type.c_str(), op_desc->GetOutputsSize()); + + for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { + GeTensorDesc output_tensor = op_desc->GetOutputDesc(static_cast(i)); + Format format = output_tensor.GetFormat(); + DataType data_type = output_tensor.GetDataType(); + + int64_t mem_size = 0; + // If mem size has been set, no need reset. + if ((TensorUtils::GetSize(output_tensor, mem_size) == GRAPH_SUCCESS) && (mem_size > 0)) { + GELOGD("Op[%s:%s] out[%zu] mem size has been set, no need calc again, format=%s, data_type=%s, mem_size=%ld.", + name.c_str(), type.c_str(), i, TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::DataTypeToSerialString(data_type).c_str(), mem_size); + continue; + } + + int64_t output_mem_size = 0; + GeShape output_shape = output_tensor.GetShape(); + if ((TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_mem_size) != GRAPH_SUCCESS) || + (output_mem_size < 0)) { + GELOGE(FAILED, + "[Calc][TensorMemSize] fail for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.", + name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::DataTypeToSerialString(data_type).c_str()); + REPORT_CALL_ERROR("E19999", + "CalcTensorMemSize failed for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.", + name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::DataTypeToSerialString(data_type).c_str()); + return FAILED; + } + GELOGI("Calc op[%s:%s] out[%zu] mem size is %ld, format=%s, data_type=%s.", + name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::DataTypeToSerialString(data_type).c_str()); + + TensorUtils::SetSize(output_tensor, output_mem_size); + if (op_desc->UpdateOutputDesc(static_cast(i), 
output_tensor) != GRAPH_SUCCESS) { + GELOGE(FAILED, + "[Update][OutputDesc] fail for op[%s:%s] out[%zu] desc , format=%s, data_type=%s.", + name.c_str(), type.c_str(), i, + TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); + REPORT_CALL_ERROR("E19999", "UpdateOutputDesc failed for op[%s:%s] out[%zu] desc , format=%s, data_type=%s.", + name.c_str(), type.c_str(), i, + TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); + return FAILED; + } + } + + GELOGD("Calc op[%s:%s] running param success.", name.c_str(), type.c_str()); + return SUCCESS; +} + +Status HostCpuOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, vector &tasks) { + // no need to generate device task + return SUCCESS; +} +} // namespace host_cpu +} // namespace ge \ No newline at end of file diff --git a/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h b/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h new file mode 100644 index 00000000..0ffc069b --- /dev/null +++ b/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_ +#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY _declspec(dllexport) +#else +#define GE_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_VISIBILITY +#endif +#endif + +#include "common/opskernel/ops_kernel_builder.h" + +namespace ge { +namespace host_cpu { +class GE_FUNC_VISIBILITY HostCpuOpsKernelBuilder : public OpsKernelBuilder { + public: + Status Initialize(const map &options) override; + + Status Finalize() override; + + Status CalcOpRunningParam(Node &node) override; + + Status GenerateTask(const Node &node, RunContext &context, std::vector &tasks) override; +}; +} // namespace host_cpu +} // namespace ge + +#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_ diff --git a/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc b/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc new file mode 100644 index 00000000..df81f5c0 --- /dev/null +++ b/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h" +#include +#include "common/constant/constant.h" +#include "ge/ge_api_types.h" +#include "framework/common/debug/ge_log.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" +#include "op/op_factory.h" + +namespace ge { +namespace host_cpu { +using domi::TaskDef; +using std::map; +using std::string; +using std::vector; + +Status HostCpuOpsKernelInfoStore::Initialize(const map &options) { + GELOGI("HostCpuOpsKernelInfoStore init start."); + OpInfo default_op_info = {.engine = kHostCpuEngineName, + .opKernelLib = kHostCpuOpKernelLibName, + .computeCost = 0, + .flagPartial = false, + .flagAsync = false, + .isAtomic = false}; + // Init op_info_map_ + auto all_ops = OpFactory::Instance().GetAllOps(); + for (auto &op : all_ops) { + op_info_map_[op] = default_op_info; + } + + GELOGI("HostCpuOpsKernelInfoStore inited success. op num=%zu", op_info_map_.size()); + + return SUCCESS; +} + +Status HostCpuOpsKernelInfoStore::Finalize() { + op_info_map_.clear(); + return SUCCESS; +} + +void HostCpuOpsKernelInfoStore::GetAllOpsKernelInfo(map &infos) const { infos = op_info_map_; } + +bool HostCpuOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const { + if (op_desc == nullptr) { + return false; + } + return op_info_map_.count(op_desc->GetType()) > 0; +} +} // namespace host_cpu +} // namespace ge diff --git a/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h b/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h new file mode 100644 index 00000000..6ce597e7 --- /dev/null +++ b/tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h @@ -0,0 +1,86 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ +#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY _declspec(dllexport) +#else +#define GE_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#include + +#include "common/opskernel/ops_kernel_info_store.h" + +namespace ge { +namespace host_cpu { +class GE_FUNC_VISIBILITY HostCpuOpsKernelInfoStore : public OpsKernelInfoStore { + public: + HostCpuOpsKernelInfoStore() {} + ~HostCpuOpsKernelInfoStore() override = default; + + /** + * Initialize related resources of the host cpu kernelinfo store + * @return status whether this operation success + */ + Status Initialize(const std::map &options) override; + + /** + * Release related resources of the host cpu kernel info store + * @return status whether this operation success + */ + Status Finalize() override; + + /** + * Check to see if an operator is fully supported or partially supported. + * @param op_desc OpDesc information + * @param reason unsupported reason + * @return bool value indicate whether the operator is fully supported + */ + bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override; + + /** + * Returns the full operator information. 
+ * @param infos reference of a map, + * contain operator's name and detailed information + */ + void GetAllOpsKernelInfo(std::map &infos) const override; + + HostCpuOpsKernelInfoStore(const HostCpuOpsKernelInfoStore &ops_kernel_store) = delete; + HostCpuOpsKernelInfoStore(const HostCpuOpsKernelInfoStore &&ops_kernel_store) = delete; + HostCpuOpsKernelInfoStore &operator=(const HostCpuOpsKernelInfoStore &ops_kernel_store) = delete; + HostCpuOpsKernelInfoStore &operator=(HostCpuOpsKernelInfoStore &&ops_kernel_store) = delete; + + private: + // store op name and OpInfo key-value pair + std::map op_info_map_; +}; +} // namespace host_cpu +} // namespace ge + +#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ diff --git a/tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc b/tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc new file mode 100644 index 00000000..56b3970e --- /dev/null +++ b/tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "host_cpu_engine/ops_kernel_store/op/host_op.h" +#include "framework/common/util.h" +#include "host_cpu_engine/ops_kernel_store/op/op_factory.h" + +namespace ge { +namespace host_cpu { +Status HostOp::Run() { + // no need to generate device task + return SUCCESS; +} + +REGISTER_OP_CREATOR(NoOp, HostOp); +REGISTER_OP_CREATOR(Variable, HostOp); +REGISTER_OP_CREATOR(Constant, HostOp); +REGISTER_OP_CREATOR(Assign, HostOp); +REGISTER_OP_CREATOR(RandomUniform, HostOp); +REGISTER_OP_CREATOR(Add, HostOp); +REGISTER_OP_CREATOR(Mul, HostOp); +REGISTER_OP_CREATOR(ConcatV2, HostOp); +REGISTER_OP_CREATOR(Data, HostOp); +REGISTER_OP_CREATOR(Fill, HostOp); +REGISTER_OP_CREATOR(NetOutput, HostOp); +} // namespace host_cpu +} // namespace ge diff --git a/tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h b/tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h new file mode 100644 index 00000000..27cf7dae --- /dev/null +++ b/tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ +#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ + +#include "host_cpu_engine/ops_kernel_store/op/op.h" + +namespace ge { +namespace host_cpu { +class GE_FUNC_VISIBILITY HostOp : public Op { + public: + HostOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} + ~HostOp() override = default; + HostOp &operator=(const HostOp &op) = delete; + HostOp(const HostOp &op) = delete; + + Status Run() override; +}; +} // namespace host_cpu +} // namespace ge + +#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ diff --git a/tests/st/framework/stub_engine/ops_kernel_store/op/op.h b/tests/st/framework/stub_engine/ops_kernel_store/op/op.h new file mode 100644 index 00000000..bd227ccd --- /dev/null +++ b/tests/st/framework/stub_engine/ops_kernel_store/op/op.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ +#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ + +#include +#include +#include +#include "common/ge_inner_error_codes.h" +#include "common/opskernel/ops_kernel_info_types.h" +#include "graph/node.h" + +namespace ge { +namespace host_cpu { +/** + * The base class for all op. 
+ */ +class GE_FUNC_VISIBILITY Op { + public: + Op(const Node &node, RunContext &run_context) : run_context_(run_context), node_(node) {} + virtual ~Op() = default; + virtual Status Run() = 0; + + protected: + const RunContext &run_context_; + const Node &node_; +}; +} // namespace host_cpu +} // namespace ge + +#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ diff --git a/tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc b/tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc new file mode 100644 index 00000000..af9d0010 --- /dev/null +++ b/tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "host_cpu_engine/ops_kernel_store/op/op_factory.h" +#include "framework/common/debug/ge_log.h" +#include "common/ge_inner_error_codes.h" +#include "graph/op_desc.h" + +namespace ge { +namespace host_cpu { +OpFactory &OpFactory::Instance() { + static OpFactory instance; + return instance; +} + +std::shared_ptr OpFactory::CreateOp(const Node &node, RunContext &run_context) { + auto iter = op_creator_map_.find(node.GetType()); + if (iter != op_creator_map_.end()) { + return iter->second(node, run_context); + } + + GELOGE(FAILED, "Not supported OP, type = %s, name = %s", node.GetType().c_str(), node.GetName().c_str()); + return nullptr; +} + +void OpFactory::RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func) { + if (func == nullptr) { + GELOGW("Func is NULL."); + return; + } + + auto iter = op_creator_map_.find(type); + if (iter != op_creator_map_.end()) { + GELOGW("%s creator already exist", type.c_str()); + return; + } + + op_creator_map_[type] = func; + all_ops_.emplace_back(type); +} +} // namespace host_cpu +} // namespace ge diff --git a/tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h b/tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h new file mode 100644 index 00000000..bf67c51e --- /dev/null +++ b/tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h @@ -0,0 +1,94 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ +#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ + +#include +#include +#include +#include +#include +#include "common/ge/ge_util.h" +#include "host_cpu_engine/ops_kernel_store/op/op.h" + +namespace ge { +namespace host_cpu { +using OP_CREATOR_FUNC = std::function(const Node &, RunContext &)>; + +/** + * manage all the op, support create op. + */ +class GE_FUNC_VISIBILITY OpFactory { + public: + static OpFactory &Instance(); + + /** + * @brief create Op. + * @param [in] node share ptr of node + * @param [in] run_context run context + * @return not nullptr success + * @return nullptr fail + */ + std::shared_ptr CreateOp(const Node &node, RunContext &run_context); + + /** + * @brief Register Op create function. + * @param [in] type Op type + * @param [in] func Op create func + */ + void RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func); + + const std::vector &GetAllOps() const { return all_ops_; } + + bool CheckSupported(const std::string &type) { return op_creator_map_.find(type) != op_creator_map_.end(); } + + OpFactory(const OpFactory &) = delete; + OpFactory &operator=(const OpFactory &) = delete; + OpFactory(OpFactory &&) = delete; + OpFactory &operator=(OpFactory &&) = delete; + + private: + OpFactory() = default; + ~OpFactory() = default; + + // the op creator function map + std::map op_creator_map_; + std::vector all_ops_; +}; + +class GE_FUNC_VISIBILITY OpRegistrar { + public: + OpRegistrar(const std::string &type, const OP_CREATOR_FUNC &func) { + OpFactory::Instance().RegisterCreator(type, func); + } + ~OpRegistrar() = default; + + OpRegistrar(const OpRegistrar &) = delete; + OpRegistrar &operator=(const OpRegistrar &) = delete; + OpRegistrar(OpRegistrar &&) = delete; + OpRegistrar &operator=(OpRegistrar &&) = delete; +}; + +#define REGISTER_OP_CREATOR(type, clazz) \ + std::shared_ptr Creator_##type##Op(const Node &node, RunContext &run_context) { \ + 
return MakeShared(node, run_context); \ + } \ + OpRegistrar g_##type##Op_creator(#type, Creator_##type##Op) +} // namespace host_cpu +} // namespace ge + +#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ diff --git a/tests/st/framework/stub_engine/proto/task.proto b/tests/st/framework/stub_engine/proto/task.proto new file mode 100644 index 00000000..edda1068 --- /dev/null +++ b/tests/st/framework/stub_engine/proto/task.proto @@ -0,0 +1,179 @@ +/* Copyright 2021. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef 
stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; + KernelDefWithHandle kernel_with_handle = 40; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelDefWithHandle { + KernelContext context = 1; + + uint64 handle = 10; + string dev_func = 11; + uint32 block_dim = 12; + uint32 args_size = 13; + bytes args = 14; + bytes sm_desc = 15; + string original_kernel_key = 16; + string node_info = 17; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message 
StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/tests/st/framework/stub_op_proto/array_ops.cc b/tests/st/framework/stub_op_proto/array_ops.cc new file mode 100644 index 00000000..2fd8f19c --- /dev/null +++ b/tests/st/framework/stub_op_proto/array_ops.cc @@ -0,0 +1,1763 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
+ * \file array_ops.cpp + * \brief + */ +#include "array_ops.h" +#include +#include +#include + +#include "./util/op_log.h" +#include "./util/common_shape_fns.h" +#include "./util/array_ops_shape_fns.h" +#include "graph/utils/tensor_adapter.h" +#include "graph/utils/node_utils.h" +#include "./util/error_util.h" +#include "util/util.h" + +namespace ge { +const char* const kShape = "shape"; +const char* const kShapeDtype = "shape dtype"; +const char* const kAttrShape = "attr shape"; +const char* const kAttrDtype = "attr dtype"; +const char* const kAttrAxis = "attr axis"; +const char* const kAttrNumAxes = "attr num_axes"; +const char* const kPreOpInputShapeRange = "_pre_op_in_range"; +const int64_t kMaxDimNum = 8; + + +IMPLEMT_INFERFUNC(Unique, UniqueInfer) { + OpDescPtr op_desc = OpDescUtils::GetOpDescFromOperator(op); + GeTensorDescPtr x_input = op_desc->MutableInputDesc(0); + + GeShape x_shape; + if (WithRank(x_input, 1, x_shape) != GRAPH_SUCCESS) { + ShapeErrReport(0, op.GetName(), DebugString(x_input->GetShape().GetDims()), "1D"); + OP_LOGE(op.GetName().c_str(), "input x must be 1-D"); + return GRAPH_FAILED; + } + + DataType idx_type; + if (op.GetAttr("out_idx", idx_type) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "Op get attr out_idx failed"); + return GRAPH_FAILED; + } + + GeTensorDescPtr idx_desc = op_desc->MutableOutputDesc(1); + idx_desc->SetShape(x_shape); + idx_desc->SetOriginShape(x_shape); + idx_desc->SetDataType(idx_type); + + GeTensorDescPtr y_desc = op_desc->MutableOutputDesc(0); + y_desc->SetShape(GeShape({UNKNOWN_DIM})); + y_desc->SetOriginShape(GeShape({UNKNOWN_DIM})); + y_desc->SetDataType(x_input->GetDataType()); + if (x_shape.GetShapeSize() == UNKNOWN_DIM) { + return GRAPH_SUCCESS; + } else { + std::vector> range; + int64_t max_dim = x_shape.GetDim(0); + range.emplace_back(std::make_pair(1, max_dim)); + y_desc->SetShapeRange(range); + return GRAPH_SUCCESS; + } +} + +INFER_FUNC_REG(Unique, UniqueInfer); + +IMPLEMT_INFERFUNC(Const, 
ConstInfer) { + auto value = op.get_attr_value(); + auto valDesc = value.GetTensorDesc(); + auto dims = valDesc.GetShape().GetDims(); + auto attrDtype = valDesc.GetDataType(); + + TensorDesc outDesc = op.get_output_desc_y(); + outDesc.SetDataType(ge::DataType(attrDtype)); + outDesc.SetShape(Shape(dims)); + (void)op.update_output_desc_y(outDesc); + + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Const, ConstInfer); + +IMPLEMT_INFERFUNC(Constant, ConstantInfer) { + auto value = op.get_attr_value(); + auto valDesc = value.GetTensorDesc(); + auto dims = valDesc.GetShape().GetDims(); + auto attrDtype = valDesc.GetDataType(); + + TensorDesc outDesc = op.get_output_desc_y(); + outDesc.SetDataType(ge::DataType(attrDtype)); + outDesc.SetShape(Shape(dims)); + (void)op.update_output_desc_y(outDesc); + + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Constant, ConstantInfer); + +graphStatus ConstAndConstantInferFormat(ge::Operator& op) { + OP_LOGI(op.GetName().c_str(), "Const infer format start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto format = op_desc->MutableOutputDesc(0)->GetOriginFormat(); + ConstGeTensorPtr tensor_value; + if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { + OP_LOGE(op.GetName().c_str(), "Get attr value failed!"); + return GRAPH_FAILED; + } + if (!tensor_value) { + OP_LOGE(op.GetName().c_str(), "attr tensor is not exist!"); + return GRAPH_FAILED; + } + auto tensor_ptr = const_cast(tensor_value.get()); + tensor_ptr->MutableTensorDesc().SetOriginFormat(format); + tensor_ptr->MutableTensorDesc().SetFormat(format); + return GRAPH_SUCCESS; +} + +IMPLEMT_INFERFORMAT_FUNC(Const, ConstInferFormat) { + return ConstAndConstantInferFormat(op); +} + +INFER_FORMAT_FUNC_REG(Const, ConstInferFormat); + + +IMPLEMT_INFERFUNC(Snapshot, SnapshotInferFunc) { + OP_LOGI(op.GetName().c_str(), "Snapshot infershape start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto input_desc_x = op_desc->MutableInputDesc("x"); + auto 
output_desc_y = op_desc->MutableOutputDesc("y"); + + auto x_dims = input_desc_x->MutableShape().GetDims(); + auto x_type = input_desc_x->GetDataType(); + std::vector> x_range; + input_desc_x->GetShapeRange(x_range); + output_desc_y->SetShape(GeShape(x_dims)); + output_desc_y->SetOriginShape(GeShape(x_dims)); + output_desc_y->SetShapeRange(x_range); + output_desc_y->SetDataType(x_type); + OP_LOGI(op.GetName().c_str(), "Snapshot infershape end"); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Snapshot, SnapshotInferFunc); + +IMPLEMT_INFERFUNC(GuaranteeConst, GuaranteeConstInfer) { + TensorDesc tensorDesc = op.GetInputDesc("x"); + (void)op.UpdateOutputDesc("y", tensorDesc); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(GuaranteeConst, GuaranteeConstInfer); + +IMPLEMT_INFERFUNC(BroadcastArgs, BroadcastArgsInferFunc) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto x1_desc = op_desc->MutableInputDesc("x1"); + auto x2_desc = op_desc->MutableInputDesc("x2"); + auto y_desc = op_desc->MutableOutputDesc("y"); + auto x1_dims = x1_desc->GetShape().GetDims(); + auto x2_dims = x2_desc->GetShape().GetDims(); + auto data_type = x1_desc->GetDataType(); + std::vector> x1_range; + std::vector> x2_range; + std::vector> out_range; + x1_desc->GetShapeRange(x1_range); + x2_desc->GetShapeRange(x2_range); + + + bool data_type_check = ((x1_desc->GetDataType() != DT_INT32 && x1_desc->GetDataType() != DT_INT64) || + (x2_desc->GetDataType() != DT_INT32 && x2_desc->GetDataType() != DT_INT64)); + if (data_type_check) { + string reason = "x1[" + std::to_string(x1_desc->GetDataType()) + "] + and + x2[" + + std::to_string(x1_desc->GetDataType()) + "] must DT_INT32 or DT_INT64"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", reason); + GE_OP_LOGE(op.GetName().c_str(), "Data type check fail. 
x1[%u] and x2[%u] must DT_INT32 or DT_INT64", + x1_desc->GetDataType(), x2_desc->GetDataType()); + return GRAPH_PARAM_INVALID; + } + + if (x1_dims.size() > 1 || x2_dims.size() > 1) { + string reason = "x1[" + std::to_string(x1_dims.size()) + "] + and + x2[" + std::to_string(x2_dims.size()) + + "] must be less than or equal to 1"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dims", reason); + GE_OP_LOGE(op.GetName().c_str(), "Size check fail. x1[%u] and x2[%u] must be less than or equal to 1", + x1_dims.size(), x2_dims.size()); + return GRAPH_PARAM_INVALID; + } + + if (x1_dims == UNKNOWN_RANK || x2_dims == UNKNOWN_RANK) { + GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown rank!"); + y_desc->SetShape(GeShape(UNKNOWN_SHAPE)); + y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + y_desc->SetDataType(data_type); + return GRAPH_SUCCESS; + } + + if (x1_dims == UNKNOWN_SHAPE && x2_dims == UNKNOWN_SHAPE) { + GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown shape!"); + y_desc->SetShape(GeShape(UNKNOWN_SHAPE)); + y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + y_desc->SetDataType(data_type); + y_desc->SetShapeRange(x1_range); + return GRAPH_SUCCESS; + } else if (x1_dims == UNKNOWN_SHAPE) { + GE_OP_LOGD(op.GetName().c_str(), "x1 is unknown shape!"); + int64_t range_max = x2_dims.size(); + std::pair pair({1, range_max}); + out_range.emplace_back(pair); + y_desc->SetShape(GeShape(UNKNOWN_SHAPE)); + y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + y_desc->SetDataType(data_type); + y_desc->SetShapeRange(out_range); + return GRAPH_SUCCESS; + } else if (x2_dims == UNKNOWN_SHAPE) { + GE_OP_LOGD(op.GetName().c_str(), "x2 is unknown shape!"); + int64_t range_max = x2_dims.size(); + std::pair pair({1, range_max}); + out_range.emplace_back(pair); + y_desc->SetShape(GeShape(UNKNOWN_SHAPE)); + y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + y_desc->SetDataType(data_type); + y_desc->SetShapeRange(out_range); + return GRAPH_SUCCESS; + } + + if 
(x1_dims.empty()) { + y_desc->SetShape(GeShape(x2_dims)); + } else if (x2_dims.empty()) { + y_desc->SetShape(GeShape(x1_dims)); + } else { + auto dims = x1_dims[0] > x2_dims[0] ? x1_dims : x2_dims; + y_desc->SetShape(GeShape(dims)); + } + + int64_t range_max = x1_dims.size() > x2_dims.size() ? x1_dims.size() : x2_dims.size(); + std::pair pair({1, range_max}); + out_range.emplace_back(pair); + y_desc->SetShapeRange(out_range); + y_desc->SetDataType(x1_desc->GetDataType()); + + + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(BroadcastArgs, BroadcastArgsInferFunc); + +IMPLEMT_INFERFUNC(BroadcastGradientArgs, BroadcastGradientArgsInfer) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + + auto input_desc_x1 = op_desc->MutableInputDesc("x1"); + auto input_desc_x2 = op_desc->MutableInputDesc("x2"); + auto output_desc_y1 = op_desc->MutableOutputDesc("y1"); + auto output_desc_y2 = op_desc->MutableOutputDesc("y2"); + auto dims_x1 = input_desc_x1->MutableShape().GetDims(); + auto dims_x2 = input_desc_x2->MutableShape().GetDims(); + auto x1_type = input_desc_x1->GetDataType(); + auto x2_type = input_desc_x2->GetDataType(); + std::vector> x1_range; + std::vector> x2_range; + std::vector> out_range; + input_desc_x1->GetShapeRange(x1_range); + input_desc_x2->GetShapeRange(x2_range); + + if (dims_x1 == UNKNOWN_RANK || dims_x2 == UNKNOWN_RANK) { + GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown rank!"); + output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y1->SetDataType(x1_type); + output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y2->SetDataType(x2_type); + return GRAPH_SUCCESS; + } + // Input Dim Num must be equal or smaller than 1 + if (dims_x1 == UNKNOWN_SHAPE && dims_x2 == UNKNOWN_SHAPE) { + GE_OP_LOGD(op.GetName().c_str(), "all two inputs are unknown shape!"); + output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE)); + 
output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y1->SetDataType(x1_type); + output_desc_y1->SetShapeRange(x1_range); + output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y2->SetDataType(x2_type); + output_desc_y2->SetShapeRange(x2_range); + return GRAPH_SUCCESS; + } else if (dims_x1 == UNKNOWN_SHAPE) { + GE_OP_LOGD(op.GetName().c_str(), "x1 is unknown shape!"); + int64_t range_max = dims_x2.size(); + std::pair pair({1, range_max}); + out_range.emplace_back(pair); + output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y1->SetDataType(x1_type); + output_desc_y1->SetShapeRange(out_range); + output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y2->SetDataType(x2_type); + output_desc_y2->SetShapeRange(out_range); + return GRAPH_SUCCESS; + } else if (dims_x2 == UNKNOWN_SHAPE) { + GE_OP_LOGD(op.GetName().c_str(), "x2 is unknown shape!"); + int64_t range_max = dims_x1.size(); + std::pair pair({1, range_max}); + out_range.emplace_back(pair); + output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y1->SetDataType(x1_type); + output_desc_y1->SetShapeRange(out_range); + output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y2->SetDataType(x2_type); + output_desc_y2->SetShapeRange(out_range); + return GRAPH_SUCCESS; + } + + GE_OP_LOGD(op.GetName().c_str(), "all two inputs are known shape!"); + int64_t range_max = dims_x1.size() == 0 ? 
1 : dims_x1.size(); + std::pair pair({1, range_max}); + out_range.emplace_back(pair); + output_desc_y1->SetDataType(x1_type); + output_desc_y2->SetDataType(x2_type); + output_desc_y1->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y1->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y2->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y2->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y1->SetShapeRange(out_range); + output_desc_y2->SetShapeRange(out_range); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(BroadcastGradientArgs, BroadcastGradientArgsInfer); + +IMPLEMT_INFERFUNC(PreventGradient, PreventGradientInferFunc) { + OP_LOGI(op.GetName().c_str(), "PreventGradient infershape start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto input_desc_x = op_desc->MutableInputDesc("x"); + auto output_desc_y = op_desc->MutableOutputDesc("y"); + + auto x_dims = input_desc_x->MutableShape().GetDims(); + auto x_type = input_desc_x->GetDataType(); + std::vector> x_range; + input_desc_x->GetShapeRange(x_range); + output_desc_y->SetShape(GeShape(x_dims)); + output_desc_y->SetOriginShape(GeShape(x_dims)); + output_desc_y->SetShapeRange(x_range); + output_desc_y->SetDataType(x_type); + OP_LOGI(op.GetName().c_str(), "PreventGradient infershape end"); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(PreventGradient, PreventGradientInferFunc); + +IMPLEMT_INFERFUNC(StopGradient, StopGradientInferFunc) { + OP_LOGI(op.GetName().c_str(), "StopGradient infershape start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto input_desc_x = op_desc->MutableInputDesc("x"); + auto output_desc_y = op_desc->MutableOutputDesc("y"); + + auto x_dims = input_desc_x->MutableShape().GetDims(); + auto x_type = input_desc_x->GetDataType(); + std::vector> x_range; + input_desc_x->GetShapeRange(x_range); + output_desc_y->SetShape(GeShape(x_dims)); + output_desc_y->SetOriginShape(GeShape(x_dims)); + output_desc_y->SetShapeRange(x_range); + 
output_desc_y->SetShapeRange(x_range); + output_desc_y->SetDataType(x_type); + OP_LOGI(op.GetName().c_str(), "StopGradient infershape end"); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(StopGradient, StopGradientInferFunc); + +IMPLEMT_INFERFUNC(ExpandDims, ExpandDimsInfer) { + std::vector dep_inputs = {"axis"}; + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto node = NodeUtils::GetNodeFromOperator(op); + if (node == nullptr) { + GE_OP_LOGE(op.GetName().c_str(), "get null node ptr"); + return GRAPH_FAILED; + } + auto x_desc = op_desc->MutableInputDesc("x"); + auto axis_desc = op_desc->MutableInputDesc("axis"); + auto y_desc = op_desc->MutableOutputDesc("y"); + + op_desc->SetOpInferDepends(dep_inputs); + auto axis_type = axis_desc->GetDataType(); + auto x_type = x_desc->GetDataType(); + + if (axis_type != DT_INT32 && axis_type != DT_INT64) { + string reason = "axis dtype[" + std::to_string(axis_type) + "] must int32 or int64"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrDtype, reason); + GE_OP_LOGE(op.GetName().c_str(), "axis dtype[%d] must int32 or int64", axis_type); + return GRAPH_PARAM_INVALID; + } + + bool is_x_unknonwn_rank = x_desc->MutableShape().GetDims() == UNKNOWN_RANK ? true : false; + if (is_x_unknonwn_rank) { + GE_OP_LOGD("input x shape is unknown rank!"); + y_desc->SetUnknownDimNumShape(); + y_desc->SetDataType(x_type); + y_desc->SetOriginDataType(x_type); + return GRAPH_SUCCESS; + } + + int64_t axis_nums = axis_desc->MutableShape().GetShapeSize(); + + if (axis_nums != 1) { + // Shape::GetDims().size() == 0, means it's a scalar, its shape is []. 
+ if (!(axis_nums == 0 && axis_desc->MutableShape().GetDims().size() == 0)) { + string reason = "axis input must be a tensor with a single value, but [" + std::to_string(axis_nums) + "] nums"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "axis", reason); + GE_OP_LOGE(op.GetName().c_str(), "'axis' input must be a tensor with a single value, but %d nums", axis_nums); + return GRAPH_PARAM_INVALID; + } + } + + GeTensorPtr tensor_axis = nullptr; + graphStatus status = NodeUtils::GetInputConstData(node, "axis", tensor_axis); + if (status != GRAPH_SUCCESS) { + GE_OP_LOGI(op.GetName().c_str(), "Op get input const data of axis failed"); + auto x_shape_size = x_desc->MutableShape().GetDims().size(); + std::vector out_dims(x_shape_size + 1, UNKNOWN_DIM); + y_desc->SetShape(GeShape(out_dims)); + y_desc->SetOriginShape(GeShape(out_dims)); + y_desc->SetDataType(x_type); + y_desc->SetOriginDataType(x_type); + // infer shape range + std::vector> x_range; + (void)x_desc->GetShapeRange(x_range); + if (x_range.empty()) { + GE_OP_LOGD(op.GetName().c_str(), "last op does not set shape range!"); + return GRAPH_SUCCESS; + } + if (x_range.size() != x_shape_size) { + GE_OP_LOGE(op.GetName().c_str(), + "input range size num[%zu] should be same with input shape size[%zu]", x_range.size(), x_shape_size); + return GRAPH_FAILED; + } + int64_t max_range_value = 1; + for (const auto &ele : x_range) { + if (ele.second > max_range_value) { + max_range_value = ele.second; + } + } + std::vector> y_range(x_shape_size + 1, std::pair({1, max_range_value})); + y_desc->SetShapeRange(y_range); + return GRAPH_SUCCESS; + } + + auto pbuff = tensor_axis->GetData().GetData(); + if (pbuff == nullptr) { + GE_OP_LOGE(op.GetName().c_str(), "no const data when get data from tensor!"); + return GRAPH_FAILED; + } + int64_t axis; + if (axis_type == DT_INT32) { + axis = *const_cast(reinterpret_cast(pbuff)); + } else if (axis_type == DT_INT64) { + axis = *const_cast(reinterpret_cast(pbuff)); + } + + std::vector 
vec_dim; + int32_t dim_num = x_desc->MutableShape().GetDimNum(); + if (axis < -1 - dim_num || axis > dim_num) { + string reason = "axis[" + std::to_string(axis) + "] is not in [" + std::to_string(-1 - dim_num) + " , " + + std::to_string(dim_num) + "]"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "axis", reason); + GE_OP_LOGE(op.GetName().c_str(), "axis[%d] is not in [%d, %d]", axis, -1 - dim_num, dim_num); + return GRAPH_PARAM_INVALID; + } + + if (axis < 0) { + axis += dim_num + 1; + } + for (int i = 0; i < dim_num; i++) { + vec_dim.push_back(x_desc->MutableShape().GetDim(i)); + } + vec_dim.emplace(vec_dim.begin() + axis, 1); + y_desc->SetShape(GeShape(vec_dim)); + y_desc->SetOriginShape(GeShape(vec_dim)); + y_desc->SetDataType(x_type); + y_desc->SetOriginDataType(x_type); + // infer shape range + auto x_shape_size = x_desc->MutableShape().GetDims().size(); + std::vector> x_range; + (void)x_desc->GetShapeRange(x_range); + if (x_range.empty()) { + GE_OP_LOGD(op.GetName().c_str(), "last op does not set shape range, so break!"); + return GRAPH_SUCCESS; + } + if (x_range.size() != x_shape_size) { + GE_OP_LOGE(op.GetName().c_str(), + "input range size num[%zu] should be same with input shape size[%zu]", x_range.size(), x_shape_size); + return GRAPH_FAILED; + } + x_range.emplace(x_range.begin() + axis, std::pair{1, 1}); + y_desc->SetShapeRange(x_range); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(ExpandDims, ExpandDimsInfer); + +template +static graphStatus ValidateShape(const GeTensorPtr& tenosr, int64_t& product, int& unknow_index, GeShape& output, + Operator& op) { + int64_t dim_num = tenosr->MutableTensorDesc().MutableShape().GetDim(0); + T* shape_data = const_cast(reinterpret_cast(tenosr->GetData().GetData())); + std::vector out_dims = output.GetDims(); + if (shape_data == nullptr) { + GE_OP_LOGE(op.GetName().c_str(), "truth shape data is invalid"); + return GRAPH_PARAM_INVALID; + } + + for (int64_t i = 0; i < dim_num; i++) { + if (shape_data[i] == -1) { 
+ if (unknow_index != -1) { + string reason = "only one dim may be -1, not both dim[ " + std::to_string(unknow_index) + "] and dim[" + + std::to_string(i) + "]"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); + GE_OP_LOGE(op.GetName().c_str(), "Only one dim may be -1, not both dim[%lld] and dim[%lld]", unknow_index, i); + return GRAPH_PARAM_INVALID; + } + unknow_index = i; + out_dims.push_back(1); + } else if (shape_data[i] < 0) { + string reason = "Size[" + std::to_string(i) + "] must be non-negative"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); + GE_OP_LOGE(op.GetName().c_str(), "Size[%lld] must be non-negative", i); + return GRAPH_PARAM_INVALID; + } else { + if (shape_data[i] != 0 && product > (INT64_MAX / shape_data[i])) { + string reason = "Mul overflow of int64, product[" + std::to_string(product) + "] shape_data[" + + std::to_string((int64_t)shape_data[i]) + "]"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); + GE_OP_LOGE(op.GetName().c_str(), "Mul overflow of int64, product[%lld] shape_data[%lld]", product, + (int64_t)shape_data[i]); + return GRAPH_PARAM_INVALID; + } + out_dims.push_back(shape_data[i]); + product *= shape_data[i]; + } + } + + output = GeShape(out_dims); + return GRAPH_SUCCESS; +} + +static graphStatus CaffeReshapeInferShape(const vector& dims, const int64_t& axis, const int64_t& num_axes, + Operator& op) { + GE_OP_LOGI(op.GetName().c_str(), "Reshape infer shape start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto x_desc = op_desc->MutableInputDesc("x"); + auto shape_desc = op_desc->MutableInputDesc("shape"); + auto y_desc = op_desc->MutableOutputDesc("y"); + auto x_dims = x_desc->GetShape().GetDims(); + auto data_type = x_desc->GetDataType(); + + if (x_dims == UNKNOWN_RANK || dims == UNKNOWN_RANK) { + GE_OP_LOGD("Input data is unknown_rank"); + y_desc->SetShape(GeShape(UNKNOWN_RANK)); + y_desc->SetOriginShape(GeShape(UNKNOWN_RANK)); + 
y_desc->SetDataType(data_type); + return GRAPH_SUCCESS; + } + + if (x_dims == UNKNOWN_SHAPE) { + GE_OP_LOGD("Input data is unknown_shape."); + y_desc->SetShape(GeShape(UNKNOWN_SHAPE)); + y_desc->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + y_desc->SetDataType(data_type); + return GRAPH_SUCCESS; + } + + int64_t inferred_axis = -1; + int64_t constant_count = 1; + vector copy_axes; + + // parsing dims + for (size_t i = 0; i < dims.size(); ++i) { + const int64_t shape_dim_i = dims[i]; + if (shape_dim_i == 0) { + copy_axes.push_back(i); + } else if (shape_dim_i == -1) { + if (inferred_axis != -1) { + string reason = "only one dim may be -1, not both dim[ " + std::to_string(inferred_axis) + "] and dim[" + + std::to_string(i) + "]"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, reason); + GE_OP_LOGE(op.GetName().c_str(), "Only one dim may be -1, not both dim[%ld] and dim[%zu]", inferred_axis, i); + return GRAPH_PARAM_INVALID; + } + inferred_axis = i; + } else { + constant_count *= shape_dim_i; + } + } + + // parsing start axis and end axis + Shape bottom_shape = op.GetInputDesc("x").GetShape(); + const int64_t bottom_shape_size = bottom_shape.GetDims().size(); + int64_t start_axis = 0; + if (axis >= 0) { + start_axis = axis; + } else { + start_axis = axis + bottom_shape_size + 1; + } + if (start_axis < 0 || start_axis > bottom_shape_size) { + int64_t range = -1 - bottom_shape_size; + // if axis >=0 , axis range [0, bottom_shape_size], else axis < 0, axis range [-1 - bottom_shape_size, -1] + // axis range [-1 - bottom_shape_size, bottom_shape_size] + string reason = "axis's range is not in [" + std::to_string(range) + ", " + std::to_string(bottom_shape_size) + "]"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason); + GE_OP_LOGE(op.GetName().c_str(), "reshape param axis is invalid, axis's range is not in [%ld, %ld]", range, + bottom_shape_size); + return GRAPH_PARAM_INVALID; + } + + int64_t end_axis = 0; + if (num_axes < -1) { + 
GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrNumAxes, "it must be greater than or equal to -1"); + GE_OP_LOGE(op.GetName().c_str(), "reshape param num_axes is invalid, it must be greater than or equal to -1"); + return GRAPH_PARAM_INVALID; + } else if (num_axes == -1) { + end_axis = bottom_shape_size; + } else { + end_axis = start_axis + num_axes; + } + if (end_axis > bottom_shape_size) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrNumAxes, + "num_axes must be less than or equal to " + std::to_string((bottom_shape_size - start_axis))); + GE_OP_LOGE(op.GetName().c_str(), "reshape param num_axes is invalid, it must be less than or equal to %ld", + bottom_shape_size - start_axis); + return GRAPH_PARAM_INVALID; + } + + // construct top shape + vector bottom_dims = bottom_shape.GetDims(); + const int64_t num_axes_replaced = end_axis - start_axis; + const int64_t num_axes_retained = bottom_shape_size - num_axes_replaced; + const int64_t num_new_axes = dims.size(); + vector top_shape(num_axes_retained + num_new_axes); + size_t top_shape_index = 0; + for (int64_t i = 0; i < start_axis; ++i) { + top_shape[top_shape_index] = bottom_dims[i]; + top_shape_index++; + } + for (int64_t i = 0; i < num_new_axes; ++i) { + top_shape[top_shape_index] = dims[i]; + top_shape_index++; + } + for (int64_t i = end_axis; i < bottom_shape_size; ++i) { + top_shape[top_shape_index] = bottom_dims[i]; + top_shape_index++; + } + if (top_shape_index != top_shape.size()) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "infer shape size", + "top_shape_index not equal to top_shape size"); + GE_OP_LOGE(op.GetName().c_str(), "reshape infer shape faied, top_shape_index not equal to top_shape size"); + return GRAPH_FAILED; + } + + // product of [0,start_axis) + [end_axis, bottom_shape_size) + int64_t explicit_count = constant_count; + int64_t bottom_count_all = 1; + for (int i = 0; i < bottom_shape_size; ++i) { + bottom_count_all *= bottom_dims[i]; + if (i < start_axis || 
i >= end_axis) { + explicit_count *= bottom_dims[i]; + } + } + + // parsing dim 0 and -1 + for (size_t i = 0; i < copy_axes.size(); ++i) { + const int64_t copy_axis_index = copy_axes[i]; + if ((start_axis + copy_axis_index) >= bottom_shape_size) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, + "there was no corresponding bottom axis for dim 0"); + GE_OP_LOGE(op.GetName().c_str(), "there was no corresponding bottom axis for dim 0."); + return GRAPH_FAILED; + } + top_shape[start_axis + copy_axis_index] = bottom_dims[start_axis + copy_axis_index]; + explicit_count *= bottom_dims[start_axis + copy_axis_index]; + } + if (inferred_axis >= 0) { + if (bottom_count_all % explicit_count != 0) { + string reason = + "The shape of the input cannot be divisible by the product " + "of the specified dimensions, the product is [" + + std::to_string(explicit_count) + "]"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, reason); + GE_OP_LOGE( + op.GetName().c_str(), + "The shape of the input cannot be divisible by the product of the specified dimensions, the product is %ld", + explicit_count); + return GRAPH_FAILED; + } + const int64_t inferred_dim = bottom_count_all / explicit_count; + top_shape[start_axis + inferred_axis] = inferred_dim; + } + + int64_t top_count_all = 1; + for (size_t i = 0; i < top_shape.size(); ++i) { + top_count_all *= top_shape[i]; + } + if (top_count_all != bottom_count_all) { + string reason = "output tensor count [ " + std::to_string(top_count_all) + "] does not match input tensor count [" + + std::to_string(bottom_count_all) + "]."; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrShape, reason); + GE_OP_LOGE(op.GetName().c_str(), "output tensor count %lld does not match input tensor count %ld.", top_count_all, + bottom_count_all); + return GRAPH_FAILED; + } + + // updata output shape info + TensorDesc td = op.GetOutputDesc("y"); + td.SetShape(Shape(top_shape)); + 
td.SetDataType(op.GetInputDesc("x").GetDataType()); + (void)op.UpdateOutputDesc("y", td); + return GRAPH_SUCCESS; +} + +bool IsEmptyTensor(GeTensorDescPtr tensor_desc) { + bool is_empty = false; + for (const auto &dim : tensor_desc->MutableShape().GetDims()) { + if (dim == 0) { + is_empty = true; + break; + } + } + return is_empty; +} + +template +graphStatus GetOutShapeFromTensor(OpDescPtr op_desc, GeTensorPtr tensor, std::vector &v_out) { + auto shape_desc = tensor->MutableTensorDesc(); + T* shape_data = const_cast(reinterpret_cast(tensor->GetData().GetData())); + if (shape_data == nullptr) { + GE_OP_LOGE(op_desc->GetName().c_str(), "const shape data is invalid"); + return GRAPH_PARAM_INVALID; + } + for (int i = 0; i < shape_desc.MutableShape().GetDim(0); i++) { + v_out.emplace_back(shape_data[i]); + } + return GRAPH_SUCCESS; +} + +graphStatus EmptyTensorProcess(const Operator &op, const GeTensorDesc &x_desc, const GeTensorPtr &shape_tensor, + GeTensorDesc &out_desc) { + GE_OP_LOGD("Start empty-tensor preprocess!"); + + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto shape_type = op_desc->MutableInputDesc("shape")->GetDataType(); + std::vector shape_shape; + graphStatus ret = GRAPH_SUCCESS; + + if (shape_type == DT_INT32) { + ret = GetOutShapeFromTensor(op_desc, shape_tensor, shape_shape); + } else if (shape_type == DT_INT64) { + ret = GetOutShapeFromTensor(op_desc, shape_tensor, shape_shape); + } else { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShapeDtype, + "Dim type must be DT_INT32 or DT_INT64."); + GE_OP_LOGE(op.GetName().c_str(), "Dim type must be DT_INT32 or DT_INT64."); + return GRAPH_PARAM_INVALID; + } + + if (ret != GRAPH_SUCCESS) { + return ret; + } + + GE_OP_LOGD(op.GetName().c_str(), "x shape: %s shape shape: %s", x_desc.GetShape().ToString().c_str(), + GeShape(shape_shape).ToString().c_str()); + + int64_t num_of_neg_1 = 0; + int64_t product = 1; + for (auto &dim : shape_shape) { + if (dim == -1) { // -1 stand for highest 
dim here + num_of_neg_1++; + dim = 0; + } + product *= dim; + } + + // check valid + if ((num_of_neg_1 == 0 && product == 0) || (num_of_neg_1 == 1)) { + out_desc.SetShape(GeShape(shape_shape)); + out_desc.SetOriginShape(GeShape(shape_shape)); + out_desc.SetDataType(x_desc.GetDataType()); + out_desc.SetOriginDataType(x_desc.GetDataType()); + return GRAPH_SUCCESS; + } + GE_OP_LOGE(op.GetName().c_str(), + "Param is invalid!.Please check!Input shape contains -1 num is %ld, product is %ld", num_of_neg_1, product); + return GRAPH_FAILED; +} + +IMPLEMT_INFERFUNC(Reshape, ReshapeInfer) { + bool zero_flag = false; + vector attr_dims; + if (op.GetAttr("shape", attr_dims) == GRAPH_SUCCESS) { + for (size_t i = 0; i < attr_dims.size(); ++i) { + if (attr_dims[i] == 0) { + zero_flag = true; + break; + } + } + } + + std::vector dep_inputs = {"shape"}; + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + op_desc->SetOpInferDepends(dep_inputs); + auto x_desc = op_desc->MutableInputDesc("x"); + auto y_desc = op_desc->MutableOutputDesc("y"); + + int64_t attr_axis = 0; + op.GetAttr("axis", attr_axis); + int64_t attr_num_axes = -1; + op.GetAttr("num_axes", attr_num_axes); + + if (attr_axis != 0 || attr_num_axes != -1 || zero_flag) { + GE_OP_LOGI(op.GetName().c_str(), "Get reshape_param successfully, shape size is %u, axis is %ld, num_axes is %ld", + attr_dims.size(), attr_axis, attr_num_axes); + graphStatus caffe_reshape_ret = CaffeReshapeInferShape(attr_dims, attr_axis, attr_num_axes, op); + return caffe_reshape_ret; + } + + GE_OP_LOGI(op.GetName().c_str(), "Reshape infer shape start"); + GeTensorPtr tensor = nullptr; + auto node = NodeUtils::GetNodeFromOperator(op); + if (node == nullptr) { + OP_LOGE(op.GetName().c_str(), "get null node ptr!"); + return GRAPH_PARAM_INVALID; + } + graphStatus state = NodeUtils::GetInputConstData(node, "shape", tensor); + if (state != GRAPH_SUCCESS) { + GE_OP_LOGW(op.GetName().c_str(), "Op get input const data of shape failed"); + auto input_shape 
= op_desc->MutableInputDesc("x")->MutableShape(); + auto shape_input_desc = op_desc->MutableInputDesc("shape"); + auto shape_shape = shape_input_desc->MutableShape(); + // because shape's value stand for output shape, so it should be smaller than 1 dim + auto shape_rank = shape_shape.GetDims().size(); + if (shape_rank > 1) { + string reason = + "shape dim[" + std::to_string(shape_shape.GetDims().size()) + "] should be smaller or equal than 1"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); + GE_OP_LOGE(op.GetName().c_str(), "shape dim[%zu] should be smaller or equal than 1", + shape_shape.GetDims().size()); + return GRAPH_PARAM_INVALID; + } + if (shape_shape.GetDims() != UNKNOWN_RANK && shape_shape.GetDims() != UNKNOWN_SHAPE) { + auto x_type = op_desc->MutableInputDesc("x")->GetDataType(); + auto td = op_desc->MutableOutputDesc("y"); + int64_t rank = (shape_rank == 0) ? 0 : shape_shape.GetDims().at(0); + td->SetShape(GeShape(std::vector(rank, UNKNOWN_DIM))); + td->SetOriginShape(GeShape(std::vector(rank, UNKNOWN_DIM))); + td->SetDataType(x_type); + // calc shape range + if (input_shape.GetDims() == UNKNOWN_RANK) { + GE_OP_LOGD("input x is unknown rank!no way to set shape range!"); + return GRAPH_SUCCESS; + } + + auto input_shape_size = input_shape.GetShapeSize(); + int64_t range_max = 1; + if (input_shape_size <= 0) { + // unknown dim , by input shape range calc output range + std::vector> x_range; + (void)op_desc->MutableInputDesc("x")->GetShapeRange(x_range); + if (x_range.empty()) { + return GRAPH_SUCCESS; + } + ge::array_ops::ReshapeRangeInfer(op, x_range, range_max); + } else { + // known dim, shape size as range_max + range_max = input_shape_size; + } + range_max = (range_max > INT32_MAX) ? 
INT32_MAX : range_max; + std::vector> y_range(rank, {1, range_max}); + td->SetShapeRange(y_range); + return GRAPH_SUCCESS; + } + auto x_type = op_desc->MutableInputDesc("x")->GetDataType(); + auto td = op_desc->MutableOutputDesc("y"); + td->SetShape(GeShape({-2})); + td->SetOriginShape(GeShape({-2})); + td->SetDataType(x_type); + return GRAPH_SUCCESS; + } + + if (IsEmptyTensor(x_desc)) { + return EmptyTensorProcess(op, *x_desc, tensor, *y_desc); + } + std::vector> x_range; + std::vector> y_range; + op_desc->MutableInputDesc("x")->GetShapeRange(x_range); + int64_t product = 1; + int unknow_index = -1; + GeShape output_shape; + + DataType shape_type = op_desc->MutableInputDesc("shape")->GetDataType(); + int64_t shape_size = op_desc->MutableInputDesc("shape")->MutableShape().GetShapeSize(); + graphStatus ret = GRAPH_SUCCESS; + if (shape_type == DT_INT32) { + ret = ValidateShape(tensor, product, unknow_index, output_shape, op); + } else if (shape_type == DT_INT64) { + ret = ValidateShape(tensor, product, unknow_index, output_shape, op); + } else if (shape_size > 0) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShapeDtype, "Dim type must be DT_INT32 or DT_INT64."); + GE_OP_LOGE(op.GetName().c_str(), "Dim type must be DT_INT32 or DT_INT64."); + return GRAPH_PARAM_INVALID; + } + if (ret != GRAPH_SUCCESS) { + GE_OP_LOGE(op.GetName().c_str(), "ValidateShape failed, ret: %d", ret); + return ret; + } + + auto input_shape = op_desc->MutableInputDesc("x")->MutableShape(); + int64_t input_size = input_shape.GetShapeSize(); + + // If input tensor is scalar,then input_size will return 0, assign to 1, which means convert scalar to vector. 
+ if (input_size == 0 && output_shape.GetShapeSize() == 1) { + input_size = 1; + } + + if (unknow_index != -1) { + if (product <= 0) { + GE_OP_LOGE(op.GetName().c_str(), "Reshape Op can't infer an empty tensor"); + return GRAPH_PARAM_INVALID; + } + if (input_shape.GetShapeSize() < 0) { + GE_OP_LOGI("input x and input shape is all unknown!"); + auto td = op_desc->MutableOutputDesc("y"); + output_shape.SetDim(unknow_index, -1); + td->SetOriginDataType(op_desc->MutableInputDesc("x")->GetDataType()); + td->SetShape(output_shape); + td->SetOriginShape(output_shape); + td->SetDataType(op_desc->MutableInputDesc("x")->GetDataType()); + auto max_input_dims = 1; + // If last op does not set shape range ,do not set shape range + if (x_range.empty()) { + GE_OP_LOGI(op.GetName().c_str(), "input x doesnot have shape range!"); + } else { + // If last op have already set shape range, try best to infer shape range + ge::array_ops::ReshapeRangeInfer(op, x_range, y_range, output_shape); + } + + td->SetShapeRange(y_range); + return GRAPH_SUCCESS; + } + int64_t missing = input_size / product; + if (product * missing != input_size) { + string reason = "The shape of the input cannot be divisible from [" + std::to_string(product) + "]"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); + GE_OP_LOGE(op.GetName().c_str(), "The shape of the input cannot be divisible from %lld", product); + return GRAPH_PARAM_INVALID; + } + output_shape.SetDim(unknow_index, missing); + } + auto dims = input_shape.GetDims(); + bool is_exist_unknown_shape = false; + for (auto ele : dims) { + is_exist_unknown_shape = (ele == -1) ? 
true : false; + if (!is_exist_unknown_shape) { + continue; + } + } + + if (SetScalarOutputDesc(string("x"), string("y"), op_desc, output_shape)) { + return GRAPH_SUCCESS; + } + + // Shape_size is 0, means shape tensor value is [], implying convert vector/scalar to scalar + bool convert_to_scalar = + (shape_size == 0 && (input_size == 1 || (input_size == 0 && input_shape.GetDims().size() == 0))); + + // Output_shape.GetShapeSize() > 0 and input_size <= 0 for dynamic shape + bool shape_check_ok = + ((input_size == output_shape.GetShapeSize()) || ((output_shape.GetShapeSize() > 0) && (input_size <= 0)) || + (is_exist_unknown_shape && (output_shape.GetShapeSize() > 0))); + if (!shape_check_ok && !convert_to_scalar) { + string reason = "Shape size is [" + std::to_string(shape_size) + "], input tensor with [" + + std::to_string(input_size) + "] values, is input dynamic shape [" + + std::to_string(is_exist_unknown_shape) + "], but requested shape has [" + + std::to_string(output_shape.GetShapeSize()) + "] values"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); + GE_OP_LOGE(op.GetName().c_str(), + "Shape size is %lld, input tensor with %lld values, is input dynamic shape :%d, but \ + requested shape has %lld values", + shape_size, input_size, is_exist_unknown_shape, output_shape.GetShapeSize()); + return GRAPH_PARAM_INVALID; + } + + auto td = op_desc->MutableOutputDesc("y"); + td->SetShape(output_shape); + td->SetOriginShape(output_shape); + td->SetDataType(op_desc->MutableInputDesc("x")->GetDataType()); + td->SetOriginDataType(op_desc->MutableInputDesc("x")->GetDataType()); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Reshape, ReshapeInfer); + +IMPLEMT_INFERFORMAT_FUNC(Reshape, ReshapeInferFormat) { + GE_OP_LOGI(op.GetName().c_str(), "Reshape infer format start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto input_descs = op_desc->GetAllInputsDescPtr(); + auto output_descs = op_desc->GetAllOutputsDescPtr(); + for (const auto& 
input_desc : input_descs) { + if (input_desc->GetShape().GetDimNum() < 4) { + input_desc->SetOriginFormat(FORMAT_ND); + input_desc->SetFormat(FORMAT_ND); + } + } + for (const auto& output_desc : output_descs) { + if (output_desc->GetShape().GetDimNum() < 4) { + output_desc->SetOriginFormat(FORMAT_ND); + output_desc->SetFormat(FORMAT_ND); + } + } + (void)op_desc->DefaultInferFormat(); + for (const auto& input_desc : input_descs) { + if (input_desc->GetShape().GetDimNum() < 4) { + input_desc->SetOriginFormat(FORMAT_ND); + input_desc->SetFormat(FORMAT_ND); + } + } + for (const auto& output_desc : output_descs) { + if (output_desc->GetShape().GetDimNum() < 4) { + output_desc->SetOriginFormat(FORMAT_ND); + output_desc->SetFormat(FORMAT_ND); + } + } + return GRAPH_SUCCESS; +} +INFER_FORMAT_FUNC_REG(Reshape, ReshapeInferFormat); + +IMPLEMT_VERIFIER(Squeeze, SqueezeVerify) { + GE_OP_LOGD("Enter SqueezeVerify"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto axis = op.get_attr_axis(); + auto input_desc_x = op_desc->MutableInputDesc("x"); + auto xShape = input_desc_x->MutableShape().GetDims(); + + std::vector> x_range; + input_desc_x->GetShapeRange(x_range); + if ((xShape != UNKNOWN_RANK) && (!x_range.empty()) && (x_range.size() != xShape.size())) { + // if it has set shape range, it should be same with input dim num + GE_OP_LOGE("x_shape_range num [%zu] does not match x dims_num [%zu]", x_range.size(), xShape.size()); + return GRAPH_FAILED; + } + + auto node = NodeUtils::GetNodeFromOperator(op); + if (node == nullptr) { + GE_OP_LOGE("node pointer is nullptr"); + return GRAPH_FAILED; + } + bool is_unknow = false; + auto status = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknow); + if (status != GRAPH_SUCCESS) { + GE_OP_LOGE("Get node unknown shape status failed!"); + return GRAPH_FAILED; + } + if (is_unknow) { + // when input is unknown , no way to check param "axis" whether valid. 
Do check when running + return GRAPH_SUCCESS; + } + + if (axis.size() > 0) { + for (unsigned i = 0; i < axis.size(); i++) { + if (axis[i] < 0) + axis[i] += xShape.size(); + bool flag = (0 <= axis[i]) && (axis[i] < static_cast(xShape.size())); + if (!flag) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, + "axis value is out of range of [-rank(input), rank(input))."); + GE_OP_LOGE(op.GetName().c_str(), "axis value is out of range of [-rank(input), rank(input))."); + return GRAPH_FAILED; + } + if (!(xShape[axis[i]] == 1)) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, "input shape has dim not equal to 1."); + GE_OP_LOGE(op.GetName().c_str(), "input shape has dim not equal to 1."); + return GRAPH_FAILED; + } + } + } + GE_OP_LOGD("SqueezeVerify Success!"); + return GRAPH_SUCCESS; +} + +VERIFY_FUNC_REG(Squeeze, SqueezeVerify); + +IMPLEMT_INFERFUNC(Squeeze, SqueezeInfer) { + GE_OP_LOGD("Enter Squeeze Infershape!"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto axis = op.get_attr_axis(); + auto input_desc_x = op_desc->MutableInputDesc("x"); + auto output_desc_y = op_desc->MutableOutputDesc("y"); + auto input_shape = input_desc_x->MutableShape(); + int64_t dim_size = input_shape.GetDimNum(); + auto x_data_type = input_desc_x->GetDataType(); + int32_t axis_num = axis.size(); + + // process -2(UnknownRank) + if (input_shape.GetDims() == UNKNOWN_RANK) { + GE_OP_LOGD("Input x shape is -2!"); + output_desc_y->SetShape(GeShape(UNKNOWN_RANK)); + output_desc_y->SetOriginShape(GeShape(UNKNOWN_RANK)); + output_desc_y->SetDataType(x_data_type); + return GRAPH_SUCCESS; + } + + std::vector> x_range; + std::vector> y_range; + input_desc_x->GetShapeRange(x_range); + + std::unordered_set squeeze_dims; + for (int32_t i = 0; i < axis_num; ++i) { + int32_t dim = axis[i]; + if (dim < -dim_size || dim >= dim_size) { + string reason = "Tried to squeeze dim index[" + std::to_string(dim) + "] for tensor with [" + + std::to_string(dim_size) + "] 
dimensions"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason); + GE_OP_LOGE(op.GetName().c_str(), "Tried to squeeze dim index[%d] for tensor with [%lld] dimensions", dim, + dim_size); + return GRAPH_FAILED; + } + if (dim < 0) { + dim = dim_size + dim; + } + squeeze_dims.insert(dim); + } + + vector out_shape; + for (int i = 0; i < dim_size; i++) { + auto exist_dim = input_shape.GetDim(i); + // If squeeze_set is non-empty, only squeeze those dimensions. + if (!squeeze_dims.empty()) { + if (squeeze_dims.count(i) > 0) { + // If dim is -1 and been pointed by axis , do think -1 is 1.because no method to do verify + if (exist_dim != 1 && exist_dim != UNKNOWN_DIM) { + string reason = "Can not squeeze dim[" + std::to_string(i) + "], expected a dimension of 1, got [" + + std::to_string(exist_dim) + "]"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kShape, reason); + GE_OP_LOGE(op.GetName().c_str(), "Can not squeeze dim[%d], expected a dimension of 1, got %lld", i, + exist_dim); + return GRAPH_FAILED; + } + } else { + out_shape.emplace_back(exist_dim); + // after verified, it has ensure x_range ele num is same with dims num + if (!x_range.empty()) { + y_range.emplace_back(x_range[i]); + } + } + } else { + // Copy over all non-1-length dimensions. + // here no methed to ensure which -1 is 1, so do warning + if (exist_dim != 1) { + if (exist_dim == -1) { + GE_OP_LOGW("the [%d] dim is -1, it will not execute squeeze on it! 
maybe influence result", exist_dim); + } + out_shape.emplace_back(exist_dim); + // after verified, it has ensure x_range ele num is same with dims num + if (!x_range.empty()) { + y_range.emplace_back(x_range[i]); + } + } + } + } + + output_desc_y->SetShape(GeShape(out_shape)); + output_desc_y->SetOriginShape(GeShape(out_shape)); + output_desc_y->SetDataType(x_data_type); + if (!y_range.empty()) { + output_desc_y->SetShapeRange(y_range); + } + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Squeeze, SqueezeInfer); + +IMPLEMT_INFERFUNC(Unsqueeze, UnsqueezeInfer) { + auto axis_arr = op.get_attr_axes(); + auto axis_nums = axis_arr.size(); + if (axis_nums <= 0) { + string reason = "Axis_nums[" + std::to_string(axis_nums) + "] must be greater than 0"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason); + GE_OP_LOGE(op.GetName().c_str(), "Axis_nums[%zu] must be greater than 0", axis_nums); + return GRAPH_PARAM_INVALID; + } + std::unordered_set values(axis_arr.begin(), axis_arr.end()); + if (values.size() != axis_arr.size()) { + string reason = "Axis attribute must not contain any duplicates."; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason); + GE_OP_LOGE(op.GetName().c_str(), "Axis attribute must not contain any duplicates."); + return GRAPH_PARAM_INVALID; + } + Shape input_shape = op.get_input_desc_x().GetShape(); + int64_t dim_num = input_shape.GetDimNum() + axis_nums; + std::vector vec_dim(dim_num, 0); + + for (size_t i = 0; i < axis_nums; i++) { + int64_t axis = axis_arr[i]; + if ((axis < -dim_num) || (axis > (dim_num - 1))) { + string reason = "axis[" + std::to_string(axis_nums) + "]'s range is not in [" + std::to_string(-dim_num) + ", " + + std::to_string(dim_num - 1) + "]"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), kAttrAxis, reason); + GE_OP_LOGE(op.GetName().c_str(), "Axis %ld not in [%ld, %ld]", axis, -dim_num, dim_num); + return GRAPH_PARAM_INVALID; + } + if (axis < 0) { + axis += dim_num; + } + 
vec_dim.at(axis) = 1; + } + int64_t index = 0; + for (int64_t i = 0; i < dim_num; i++) { + if (vec_dim.at(i) != 1) { + vec_dim.at(i) = input_shape.GetDim(index); + index++; + } + } + + TensorDesc td = op.get_output_desc_y(); + td.SetShape(Shape(vec_dim)); + td.SetDataType(op.get_input_desc_x().GetDataType()); + (void)op.update_output_desc_y(td); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Unsqueeze, UnsqueezeInfer); + +IMPLEMT_INFERFUNC(Rank, RankInfer) { + OP_LOGI(op.GetName().c_str(), "Rank infershape start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto output_desc_y = op_desc->MutableOutputDesc("y"); + std::vector oShapeVector; + output_desc_y->SetShape(GeShape(oShapeVector)); + output_desc_y->SetOriginShape(GeShape(oShapeVector)); + output_desc_y->SetDataType(DT_INT32); + OP_LOGI(op.GetName().c_str(), "Rank infershape end"); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Rank, RankInfer); + +IMPLEMT_INFERFUNC(Size, SizeInfer) { + OP_LOGI(op.GetName().c_str(), "Size infershape start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto output_desc_y = op_desc->MutableOutputDesc("y"); + std::vector oShapeVector; + output_desc_y->SetShape(GeShape(oShapeVector)); + + DataType out_type = DT_INT32; + GeAttrValue out_type_value; + op_desc->GetAttr("dtype", out_type_value); + out_type_value.GetValue(out_type); + output_desc_y->SetDataType(out_type); + OP_LOGI(op.GetName().c_str(), "Size infershape end"); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Size, SizeInfer); + +COMMON_INFER_FUNC_REG(Data, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); +COMMON_INFER_FUNC_REG(PlaceHolder, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); +COMMON_INFER_FUNC_REG(End, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); + +IMPLEMT_INFERFUNC(PlaceholderWithDefault, PlaceholderWithDefaultInfer) { + TensorDesc input_desc = op.GetInputDesc("x"); + auto dims = input_desc.GetShape().GetDims(); + auto data_type = input_desc.GetDataType(); + + TensorDesc output_desc = 
op.GetOutputDesc("y"); + output_desc.SetDataType(ge::DataType(data_type)); + output_desc.SetShape(Shape(dims)); + (void)op.UpdateOutputDesc("y", output_desc); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(PlaceholderWithDefault, PlaceholderWithDefaultInfer); + +IMPLEMT_INFERFUNC(Shape, ShapeInfer) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto td = op_desc->MutableOutputDesc("y"); + auto input_dims = op_desc->MutableInputDesc("x")->MutableShape().GetDims(); + if (input_dims == UNKNOWN_RANK) { + td->SetShape(ge::GeShape(UNKNOWN_SHAPE)); + td->SetOriginShape(ge::GeShape(UNKNOWN_SHAPE)); + td->SetShapeRange(std::vector>{{1,kMaxDimNum}}); + } else { + int64_t size = static_cast(input_dims.size()); + std::vector size_v{size}; + td->SetShape(ge::GeShape(size_v)); + td->SetOriginShape(ge::GeShape(size_v)); + } + uint32_t out_type = DT_INT32; + (void)op.GetAttr("dtype", out_type); + td->SetDataType((DataType)out_type); + + std::vector> inRange; + op_desc->MutableInputDesc("x")->GetShapeRange(inRange); + if (!inRange.empty()) { + std::vector pre_op_range; + pre_op_range.resize(2*inRange.size()); + for (int i = 0; i < pre_op_range.size(); i = i + 2) { + pre_op_range[i] = inRange[i/2].first; + pre_op_range[i + 1] = inRange[i/2].second; + } + ge::AttrUtils::SetListInt(*td, kPreOpInputShapeRange, pre_op_range); + OP_LOGD(op.GetName().c_str(), "Shape op set pre_op_range success"); + } + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Shape, ShapeInfer); + +IMPLEMT_INFERFUNC(ShapeN, ShapeNInfer) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + for (size_t i = 0; i < op.GetInputsSize(); i++) { + auto td = op_desc->MutableOutputDesc(i); + auto input_dims = op_desc->MutableInputDesc(i)->MutableShape().GetDims(); + if (input_dims == UNKNOWN_RANK) { + td->SetShape(ge::GeShape(UNKNOWN_SHAPE)); + td->SetOriginShape(ge::GeShape(UNKNOWN_SHAPE)); + td->SetShapeRange(std::vector>{{1,kMaxDimNum}}); + } else { + int64_t size = static_cast(input_dims.size()); + 
GE_OP_LOGD(op.GetName().c_str(), "output value %ld", size); + std::vector size_v{size}; + td->SetShape(ge::GeShape(size_v)); + td->SetOriginShape(ge::GeShape(size_v)); + } + uint32_t out_type = DT_INT32; + (void)op.GetAttr("dtype", out_type); + td->SetDataType((DataType)out_type); + + std::vector> inRange; + op_desc->MutableInputDesc(i)->GetShapeRange(inRange); + if (!inRange.empty()) { + std::vector pre_op_range; + pre_op_range.resize(2*inRange.size()); + for (int i = 0; i < pre_op_range.size(); i = i + 2) { + pre_op_range[i] = inRange[i/2].first; + pre_op_range[i + 1] = inRange[i/2].second; + } + ge::AttrUtils::SetListInt(*td, kPreOpInputShapeRange, pre_op_range); + OP_LOGD(op.GetName().c_str(), "ShapeN op set pre_op_range success"); + } + } + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(ShapeN, ShapeNInfer); + +IMPLEMT_INFERFUNC(IdentityN, IdentityNInfer) { + OP_LOGI(op.GetName().c_str(), "IdentityN infershape start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + for (size_t i = 0; i < op.GetInputsSize(); i++) { + + auto input_desc = op_desc->MutableInputDesc(i); + auto input_dims = input_desc->MutableShape().GetDims(); + auto output_desc = op_desc->MutableOutputDesc(i); + auto intput_dtype = input_desc->GetDataType(); + + std::vector> input_range; + input_desc->GetShapeRange(input_range); + output_desc->SetShape(GeShape(input_dims)); + output_desc->SetOriginShape(GeShape(input_dims)); + output_desc->SetDataType(intput_dtype); + output_desc->SetShapeRange(input_range); + } + + OP_LOGI(op.GetName().c_str(), "IdentityN infershape end"); + + return GRAPH_SUCCESS; + +} + +INFER_FUNC_REG(IdentityN, IdentityNInfer); + +IMPLEMT_INFERFUNC(Identity, IdentityInfer) { + OP_LOGI(op.GetName().c_str(), "Identity infershape start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto input_desc_x = op_desc->MutableInputDesc("x"); + auto output_desc_y = op_desc->MutableOutputDesc("y"); + + std::vector vec_dim; + vec_dim = 
input_desc_x->MutableShape().GetDims(); + + std::vector> x_range; + input_desc_x->GetShapeRange(x_range); + + DataType data_type = input_desc_x->GetDataType(); + + output_desc_y->SetDataType(data_type); + output_desc_y->SetShape(GeShape(vec_dim)); + output_desc_y->SetOriginShape(GeShape(vec_dim)); + output_desc_y->SetShapeRange(x_range); + OP_LOGI(op.GetName().c_str(), "Identity infershape end"); + return GRAPH_SUCCESS; + +} + +INFER_FUNC_REG(Identity, IdentityInfer); + +IMPLEMT_INFERFUNC(ReadVariableOp, ReadVariableOpInfer) { + TensorDesc input_desc = op.GetInputDesc("x"); + (void)op.UpdateOutputDesc("y", input_desc); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(ReadVariableOp, ReadVariableOpInfer); + +template +static void CaclDims(const Tensor& data, std::vector& vec_dim) { + int32_t size = data.GetSize() / sizeof(T); + for (int32_t i = 0; i < size; i++) { + T dim = *((T*)data.GetData() + i); + if (dim != 0) { + vec_dim.push_back(dim); + } else { + vec_dim.clear(); + break; + } + } +} + +template +static void CaclDims(const GeTensorPtr& data, std::vector& vec_dim) { + int32_t size = data->GetData().GetSize() / sizeof(T); + for (int32_t i = 0; i < size; i++) { + void* data_ptr = (void*)data->GetData().GetData(); + if (data_ptr == nullptr) { + return; + } + T dim = *((T*)data_ptr + i); + if (dim != 0) { + vec_dim.push_back(dim); + } else { + vec_dim.clear(); + break; + } + } +} + +IMPLEMT_INFERFUNC(Empty, EmptyInfer) { + OP_LOGI(op.GetName().c_str(), "Empty infershape start"); + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + std::vector dep_inputs = {"shape"}; + op_desc->SetOpInferDepends(dep_inputs); + auto input_desc_shape = op_desc->MutableInputDesc("shape"); + auto output_desc_y = op_desc->MutableOutputDesc("y"); + auto dtype = op.get_attr_dtype(); + + std::vector> shape_range; + std::vector> y_range; + input_desc_shape->GetShapeRange(shape_range); + + DataType data_type = input_desc_shape->GetDataType(); + std::vector vec_dim; + if (data_type == 
DT_INT32) { + vec_dim = input_desc_shape->MutableShape().GetDims(); + } else { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "Empty only support shape type 'DT_INT32'"); + GE_OP_LOGE(op.GetName().c_str(), "Empty only support shape type 'DT_INT32'"); + return GRAPH_PARAM_INVALID; + } + + if (vec_dim == UNKNOWN_RANK) { + GE_OP_LOGD(op.GetName().c_str(), "all inputs are unknown rank!"); + output_desc_y->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y->SetDataType((DataType)dtype); + return GRAPH_SUCCESS; + } + + if (vec_dim == UNKNOWN_SHAPE) { + GE_OP_LOGD(op.GetName().c_str(), "shape is unknown shape!"); + std::pair pair({1, shape_range.size()}); + y_range.emplace_back(pair); + output_desc_y->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y->SetDataType((DataType)dtype); + output_desc_y->SetShapeRange(y_range); + return GRAPH_SUCCESS; + } + + auto node = NodeUtils::GetNodeFromOperator(op); + if (node == nullptr) { + OP_LOGE(op.GetName().c_str(), "Get null node ptr."); + return GRAPH_PARAM_INVALID; + } + + GeTensorPtr shape_data; + std::vector shape_dims; + auto result = NodeUtils::GetInputConstData(node, "shape", shape_data); + if(result == GRAPH_SUCCESS) { + DataType data_type = shape_data->GetTensorDesc().GetDataType(); + if (data_type == DT_INT32) { + CaclDims(shape_data,shape_dims); + } else if (data_type == DT_INT64) { + CaclDims(shape_data, shape_dims); + } + + OP_LOGD(op.GetName().c_str(), "Get input const data success."); + std::pair pair({1,shape_range.size()}); + y_range.emplace_back(pair); + output_desc_y->SetShape(GeShape(shape_dims)); + output_desc_y->SetOriginShape(GeShape(shape_dims)); + output_desc_y->SetDataType((DataType)dtype); + output_desc_y->SetShapeRange(y_range); + return GRAPH_SUCCESS; + } else { + OP_LOGD(op.GetName().c_str(), "Get input const data failed!"); + std::pair pair({1,shape_range.size()}); + 
y_range.emplace_back(pair); + output_desc_y->SetShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y->SetOriginShape(GeShape(UNKNOWN_SHAPE)); + output_desc_y->SetDataType((DataType)dtype); + output_desc_y->SetShapeRange(y_range); + return GRAPH_SUCCESS; + } + + output_desc_y->SetShape(GeShape(vec_dim)); + output_desc_y->SetOriginShape(GeShape(vec_dim)); + output_desc_y->SetDataType((DataType)dtype); + OP_LOGD(op.GetName().c_str(), "Empty infershape end"); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Empty, EmptyInfer); + +IMPLEMT_INFERFUNC(Where, WhereInfer) { + OpDescPtr op_desc = OpDescUtils::GetOpDescFromOperator(op); + GeTensorDescPtr x_desc = op_desc->MutableInputDesc(0); + + GeShape x_shape; + if (WithRankAtLeast(x_desc, 1, x_shape) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "input x must be at least 1D."); + return GRAPH_FAILED; + } + + if (WithRankAtMost(x_desc, 5, x_shape) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "input x must be at most 5D."); + return GRAPH_FAILED; + } + + GeTensorDescPtr y_desc = op_desc->MutableOutputDesc(0); + y_desc->SetDataType(DT_INT64); + + vector y_shape; + auto input_dims = x_shape.GetDims(); + int64_t input_shape_size = x_shape.GetShapeSize(); + if (input_shape_size != UNKNOWN_DIM) { + // input shape: known + y_shape.push_back(UNKNOWN_DIM); + y_shape.push_back(input_dims.size()); + + std::vector> range; + int64_t dims_num = x_shape.GetDimNum(); + range.emplace_back(std::make_pair(1, input_shape_size)); + range.emplace_back(std::make_pair(dims_num, dims_num)); + y_desc->SetShapeRange(range); + } else { + if (input_dims == UNKNOWN_RANK) { + // input shape: unknown rank + y_shape.push_back(UNKNOWN_DIM); + y_shape.push_back(UNKNOWN_DIM); + } else { + // input shape: unknown dims + y_shape.push_back(UNKNOWN_DIM); + y_shape.push_back(input_dims.size()); + } + } + + y_desc->SetShape(GeShape(y_shape)); + y_desc->SetOriginShape(GeShape(y_shape)); + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Where, WhereInfer); + 
+IMPLEMT_INFERFUNC(TransShape, TransShapeInfer) { + TensorDesc y_desc = op.GetOutputDesc("y"); + vector output_shape; + auto ret = op.GetAttr("outShape", output_shape); + if (ret != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "Failed to get attribute value."); + return GRAPH_SUCCESS; + } + y_desc.SetShape(Shape(output_shape)); + if (op.UpdateOutputDesc("y", y_desc) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(TransShape, TransShapeInfer); + +// ----------------SortV2 Begin------------------- +IMPLEMT_INFERFUNC(SortV2, SortV2InferShape) { + TensorDesc tensordesc_input = op.GetInputDesc("x"); + Shape input_shape = tensordesc_input.GetShape(); + DataType input_dtype = tensordesc_input.GetDataType(); + std::vector dims_input = input_shape.GetDims(); + + TensorDesc tensordesc_output1 = op.GetOutputDesc("y"); + + tensordesc_output1.SetShape(ge::Shape(dims_input)); + + tensordesc_output1.SetDataType(input_dtype); + + (void)op.UpdateOutputDesc("y", tensordesc_output1); + + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(SortV2, SortV2Verify) { return GRAPH_SUCCESS; } + +INFER_FUNC_REG(SortV2, SortV2InferShape); +VERIFY_FUNC_REG(SortV2, SortV2Verify); +// ----------------SortV2 END--------------------- + +// ----------------Expand Begin------------------- +template static bool ExpandCalDim(const Tensor &data, + std::vector &vec_dim, + std::vector &vec_x) { + uint32_t size_shape = data.GetSize() / sizeof(T); + uint32_t size_x = vec_x.size(); + if (size_shape < size_x) { + uint32_t diff = size_x - size_shape; + for (int32_t i = 0; i < size_x; i++) { + if (i < diff) { + vec_dim.push_back(vec_x[i]); + } else { + T dim = *((T *)data.GetData() + (i - diff)); + if ((vec_x[i] != dim) && (vec_x[i] != 1) && (dim != 1)) { + return false; + } + if (vec_x[i] > dim) { + vec_dim.push_back(vec_x[i]); + } else { + vec_dim.push_back(dim); + } + } + } + } else { + uint32_t diff = size_shape - size_x; + for (int32_t i = 0; i < size_shape; i++) { 
+ T dim = *((T *)data.GetData() + i); + if (i < diff) { + vec_dim.push_back(dim); + } else { + if ((vec_x[i - diff] != dim) && (vec_x[i-diff] != 1) && (dim != 1)) { + return false; + } + if (vec_x[i - diff] > dim) { + vec_dim.push_back(vec_x[i - diff]); + } else { + vec_dim.push_back(dim); + } + } + } + } + return true; +} + +IMPLEMT_COMMON_INFERFUNC(ExpandInferShape) { + Shape x_shape = op.GetInputDesc("x").GetShape(); + DataType x_dtype = op.GetInputDesc("x").GetDataType(); + std::vector dims_x = x_shape.GetDims(); + Tensor data; + std::vector vec_dim; + TensorDesc td = op.GetOutputDesc("y"); + if (op.GetInputConstData("shape", data) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "Get constValue failed of [shape]"); + return GRAPH_FAILED; + } else { + DataType data_type = data.GetTensorDesc().GetDataType(); + std::vector vec_dim; + if (data_type == DT_INT32) { + if (!ExpandCalDim (data, vec_dim, dims_x)) { + OP_LOGE(op.GetName().c_str(), "Data shape are not compatible!"); + return GRAPH_FAILED; + } + } else if (data_type == DT_INT64) { + if (!ExpandCalDim (data, vec_dim, dims_x)) { + OP_LOGE(op.GetName().c_str(), "Data shape are not compatible!"); + return GRAPH_FAILED; + } + } else { + OP_LOGE(op.GetName().c_str(), "Data type not supported!"); + return GRAPH_PARAM_INVALID; + } + + td.SetShape(ge::Shape(vec_dim)); + td.SetDataType(x_dtype); + (void)op.UpdateOutputDesc("y", td); + return GRAPH_SUCCESS; + } +} + +COMMON_INFER_FUNC_REG(Expand, ExpandInferShape); +// ----------------Expand END--------------------- + +// ----------------ExpandD Begin------------------- +IMPLEMT_COMMON_INFERFUNC(ExpandDInferShape) { + Shape x_shape = op.GetInputDesc("x").GetShape(); + DataType x_dtype = op.GetInputDesc("x").GetDataType(); + std::vector shape; + op.GetAttr("shape", shape); + std::vector dims_x = x_shape.GetDims(); + TensorDesc td = op.GetOutputDesc("y"); + + std::vector dim_vec; + if (shape.size() < dims_x.size()) { + std::vector dims_tmp = shape; + shape = dims_x; 
+ dims_x = dims_tmp; + } + if (shape.size() != dims_x.size()) { + int dec = shape.size() - dims_x.size(); + for (int i = 0; i < dec; i++) { + dims_x.insert(dims_x.begin(), (int64_t)1); + } + } + for (size_t i = 0; i < shape.size(); i++) { + if ((shape[i] != dims_x[i]) && (shape[i] != 1) && (dims_x[i] != 1)) { + OP_LOGE(op.GetName().c_str(), "The input shape and attr shape are not compatible."); + return GRAPH_FAILED; + } + if (shape[i] > dims_x[i]) { + dim_vec.push_back(shape[i]); + } else { + dim_vec.push_back(dims_x[i]); + } + } + td.SetShape(ge::Shape(dim_vec)); + td.SetDataType(x_dtype); + (void)op.UpdateOutputDesc("y", td); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(ExpandD, ExpandDInferShape); +// ----------------Expand END--------------------- +} // namespace ge diff --git a/tests/st/framework/stub_op_proto/array_ops.h b/tests/st/framework/stub_op_proto/array_ops.h new file mode 100644 index 00000000..0375894f --- /dev/null +++ b/tests/st/framework/stub_op_proto/array_ops.h @@ -0,0 +1,711 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file array_ops.h + * \brief + */ +#ifndef OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ +#define OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ + +#include "graph/operator_reg.h" +#include "graph/operator.h" + +namespace ge { +/** +*@brief Finds unique elements in a 1D tensor. \n + +*@par Inputs: +*x: 1D tensor. +*Input "x" is a k-dimensional tensor. 
Inputs "num_lower" and "num_upper" +are 0D scalars. \n + +*@par Attributes: +*out_idx: An optional DType from: "int32, int64". Defaults to "int32". \n + +*@par Outputs: +*@li y: "x" in the unique output "y". +*@li idx: A tensor the same size as "x". The index of each value of "x". \n + +*@attention Constraints: +*Unique runs on the Ascend AI CPU, which delivers poor performance. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Unique. +*/ + +REG_OP(Unique) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .OUTPUT(idx, TensorType({DT_INT32, DT_INT64})) + .ATTR(out_idx, Type, DT_INT32) + .OP_END_FACTORY_REG(Unique) + +/** +*@brief Creates a constant tensor from a tensor-like object. This operator is used for inference. +Operator Const has the same definition as operator Constant. \n + +*@par Attributes: +*value: Required. The value and type of the resulting tensor, and no restrictions on type. \n + +*@par Outputs: +*y: A constant tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Const. +*/ +REG_OP(Const) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(value, Tensor, Tensor()) + .OP_END_FACTORY_REG(Const) + +/** +*@brief Creates a constant tensor for training. \n + +*@par Attributes: +*value: Required. The value and type of the resulting tensor, and no restrictions on type. \n + +*@par Outputs: +*y: The constant tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Const. 
+*/ +REG_OP(Constant) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(value, Tensor, Tensor()) + .OP_END_FACTORY_REG(Constant) + +/** +*@brief Returns a copy of the input tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Snapshot. +*/ +REG_OP(Snapshot) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(Snapshot) + +/** +*@brief Gives a guarantee to the runtime that the input tensor is a constant. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: The input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator GuaranteeConst. +*/ +REG_OP(GuaranteeConst) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(GuaranteeConst) + +/** +*@brief Returns the target shape for broadcasting shapes "x1" and "x2". \n + +*@par Inputs: +*@li x1: A tensor of type int32 or int64. A shape. +*@li x2: A tensor of the same type as "x1". The other shape. \n + +*@par Outputs: +*y: A tensor. The broadcasted shape. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator BroadcastArgs. 
+*/ +REG_OP(BroadcastArgs) + .INPUT(x1, TensorType({DT_INT32, DT_INT64})) + .INPUT(x2, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(BroadcastArgs) + +/** +*@brief Outputs its input tensor as is and triggers an error if a gradient is requested. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*message: Will be printed in the error at the attempt to request a gradient. \n + +*@par Outputs: +*y: The input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator PreventGradient. +*/ +REG_OP(PreventGradient) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(message, String, "") + .OP_END_FACTORY_REG(PreventGradient) + +/** +*@brief Returns the reduction indices for computing gradients of "x1" and "x2" with broadcast. \n + +*@par Inputs: +*@li x1: A tensor of type int32 or int64. +*@li x2: A tensor of type int32 or int64. +"x2" has the same type as "x1". \n + +*@par Outputs: +*@li y1: A tensor. Reduction indices of "x1". +*@li y2: A tensor. Reduction indices of "x2". \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator BroadcastGradientArgs. +*/ +REG_OP(BroadcastGradientArgs) + .INPUT(x1, TensorType({DT_INT32, DT_INT64})) + .INPUT(x2, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y1, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y2, TensorType({DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(BroadcastGradientArgs) + +/** +*@brief Stops gradient computation. None is returned for the node where the gradient computation is stopped. + + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: The input tensor. 
\n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator StopGradient. +*/ +REG_OP(StopGradient) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(StopGradient) + +/** +*@brief Return a tensor with the same shape and contents as input. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Identity. +*/ +REG_OP(Identity) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(Identity) + +/** +*@brief Returns a list of tensors with the same shapes and contents as the input tensors. \n + +*@par Inputs: +*x: A list of input tensors. It's a dynamic input \n + +*@par Outputs: +*y: A list of Tensor objects, with the same length as the input tensor list. +It's a dynamic output. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator IdentityN. +*/ +REG_OP(IdentityN) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(IdentityN) + +/** +*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. 
\n + +*@par Inputs: +*@li x: A tensor. +*@li axis: The dimension index at which to expand. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator ExpandDims. +*/ +REG_OP(ExpandDims) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, + DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .INPUT(axis, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, + DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OP_END_FACTORY_REG(ExpandDims) + +/** +*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n + +*@par Inputs: +*@li x: Original tensor. +*@li axis: List of ints. \n + +*@par Outputs: +*y: Reshape tensor with same data as input. \n + +*@par Third-party framework compatibility +*Compatible with the Onnx operator Unsqueeze. +*/ + +REG_OP(Unsqueeze) + .INPUT(x, TensorType({DT_FLOAT32, DT_INT32, DT_UINT8, DT_BOOL})) + .OUTPUT(y, TensorType({DT_FLOAT32, DT_INT32, DT_UINT8, DT_BOOL})) + .ATTR(axes, ListInt, {}) + .OP_END_FACTORY_REG(Unsqueeze) + +/** +*@brief Reshapes a tensor. Only the tensor shape is changed, without changing the data. \n + +*@par Inputs: +*@li x: A tensor. +*@li shape: A tensor. Defines the shape of the output tensor. \n + +*@par Attributes: +*@li axis: An optional int32 or int64. The first dimension to reshape. Defaults to "0". +*@li num_axes: An optional int32 or int64. The extent of the reshape. Defaults to "-1". \n + +*@par Outputs: +*y: A tensor. \n + +*@par Attention: +*This operator cannot be directly called by the acllopExecute API. \n + +*@par Third-party framework compatibility +*@li Compatible with the TensorFlow operator Reshape. +*@li Compatible with the Caffe operator Reshape. 
+*/ +REG_OP(Reshape) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, + DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .INPUT(shape, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, + DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(axis, Int, 0) + .ATTR(num_axes, Int, -1) + .OP_END_FACTORY_REG(Reshape) + +/** +*@brief Removes dimensions of size 1 from the shape of a tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*axis: An optional list of int32 or int64. If not specified, squeezes all dimensions of size 1. If specified, only squeezes the dimensions listed. It is an error to squeeze a dimension that is not 1. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Squeeze. +*/ +REG_OP(Squeeze) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(axis, ListInt, {}) + .OP_END_FACTORY_REG(Squeeze) + +/** +*@brief Returns an integer representing the rank of input tensor. The rank of a tensor is the number of indices required to uniquely select each element of the tensor, that is, the dimension size of the tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: A tensor. The rank of input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Rank. +*/ +REG_OP(Rank) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(Rank) + +/** +*@brief Returns the size of a tensor, that is, an integer of the number of elements of the tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*out_type: An optional int32 or int64. The output data type. Defaults to "int32". \n + +*@par Outputs: +*y: A tensor. 
The size of the input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Size. +*/ +REG_OP(Size) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT32,DT_INT64})) + .ATTR(dtype, Int, DT_INT32) + .OP_END_FACTORY_REG(Size) + +/** +*@brief Input data for other operators. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*index: Index of the input tensor.The data type must be int32 or int64. +Assume that net has three data nodes, one should be set 0, another should +be set 1, and the left should be set 2. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the Caffe operator Data. +*/ +REG_OP(Data) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(index, Int, 0) + .OP_END_FACTORY_REG(Data) + +/** +*@brief Inserts a placeholder for a tensor that will be always fed. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*@li peerIndex: An integer type. The index of the corresponding "end" node connected to. +*@li parentId: A string, used to check if the nodes are from the saved parent node. +*@li parentOpType: A string. Op type of the original node. +*@li anchorIndex: An integer, used to check if the node is from the saved anchor. \n + +*@par Outputs: +*y: The created placeholder tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator PlaceHolder. 
+*/ +REG_OP(PlaceHolder) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(peerIndex, Int, 0) // the index of the corresponding 'end' node it's connected to + .ATTR(parentId, String, "") // check if these node are from save parent node + .ATTR(parentOpType, String, "") // op type of original node + .ATTR(anchorIndex, Int, 0) // check if these node are from save anchor + .OP_END_FACTORY_REG(PlaceHolder) + +/** +*@brief Inserts a placeholder with default value for a tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*@li dtype: data type of tensor. +*@li shape: tensor shape. \n + +*@par Outputs: +*y: The created placeholder tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator PlaceholderWithDefault. +*/ +REG_OP(PlaceholderWithDefault) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .REQUIRED_ATTR(shape, ListInt) + .OP_END_FACTORY_REG(PlaceholderWithDefault) + +/** +*@brief Reads and returns the value of the input variable tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator ReadVariableOp. +*/ +REG_OP(ReadVariableOp) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(dtype, Int, DT_INT32) + .OP_END_FACTORY_REG(ReadVariableOp) + +/** +*@brief Mark outputs of one sub graph which partitioned by engine type. + +*@par Inputs: +*x: A tensor. \n + +*@par Outputs: +*y: A tensor. \n + +*@par Attributes: +*@li peerIndex: The index of the corresponding 'placeholder' node it's connected to. 
+*@li parentOpType: Op type of original node. + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(End) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(peerIndex, Int, 0) + .ATTR(parentOpType, String, "") + .OP_END_FACTORY_REG(End) + +/** +*@brief Operations for writing summary data, for use in analysis and visualization. + +*@par Inputs: +* One input: +*x: Collections of summary data. + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(Summary) + .INPUT(x, TensorType::ALL()) + .OP_END_FACTORY_REG(Summary) + +/** +*@brief Returns the shape of a tensor. \n + +*@par Inputs: +*x: A tensor. \n + +*@par Attributes: +*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n + +*@par Outputs: +*y: A tensor. The shape of the input tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Size. +*/ +REG_OP(Shape) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .ATTR(dtype, Int, DT_INT32) + .OP_END_FACTORY_REG(Shape) + +/** +*@brief Returns shape of tensors. \n + +*@par Inputs: +*x: A list of input tensors. It's a dynamic input. \n + +*@par Attributes: +*dtype: An optional int32 or int64. The output data type. Defaults to "int32". \n + +*@par Outputs: +*y: A list of tensors with the same length as the input list of tensors. +It's a dynamic output. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator ShapeN. 
+*/ +REG_OP(ShapeN) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .DYNAMIC_OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .ATTR(dtype, Int, DT_INT32) + .OP_END_FACTORY_REG(ShapeN) + +/** +*@brief Creates a tensor with the given "shape" and "dtype". \n + +*@par Inputs: +*shape: The shape of the output tensor. \n + +*@par Attributes: +*@li dtype: Optional. The data type of the output tensor. Defaults to "int32". +*@li init: An optional bool. If true, initializes the returned tensor with the default value of "dtype". Defaults to "false". \n + +*@par Outputs: +*y: A tensor. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Empty. +*/ +REG_OP(Empty) + .INPUT(shape, TensorType({DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(dtype, Int, DT_INT32) + .ATTR(init, Bool, 0) + .OP_END_FACTORY_REG(Empty) + + +/** +*@brief Returns locations of nonzero / true values in a tensor. \n + +*@par Inputs: +*Including: +*x: A Tensor. Must be one of the following types: +DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, +DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL. \n + +*@par Outputs: +*y: A Tensor of type DT_INT64. \n + +*@attention Constraints: +*Where runs on the Ascend AI CPU, which delivers poor performance.\n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Where. +*/ + +REG_OP(Where) + .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ + DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL})) + .OUTPUT(y, TensorType({DT_INT64})) + .OP_END_FACTORY_REG(Where) + +/** +*@brief Change the shape of output according to the attr outShape +* + +*@par Inputs: +*x: A Tensor. 
\n + +*@par Outputs: +*y: A Tensor. Has the same type as "x".It's required and the value should equal to output_num. \n + +*@par Attributes: +*outShape: The shape of output will be inferred according to the attribute +*/ +REG_OP(TransShape) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(outShape,ListInt ,{}) + .OP_END_FACTORY_REG(TransShape); + +/** +* @brief sort_v2. + +* @par Inputs: +* @li x: An ND tensor of type float16. + +* @par Attributes: + +* @li axis: An optional int. The dimension to sort along. This value defaults to -1. +* @li descending: An optional bool. Controls the sorting order (ascending or descending). This value defaults to False. + +* @par Outputs: +* @li y: An ND tensor of type float16. + +* @attention Constraints: +* @li Axis should select the last dim. +* @li When the sorting data is less than 150K, it is recommended to use this tbe ops, + and the descending performance is better than the ascending. +* @li The upper limit of data on Ascend910 is 2000K. +*/ +REG_OP(SortV2) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .ATTR(axis, Int, -1) + .ATTR(descending, Bool, false) + .OP_END_FACTORY_REG(SortV2) + +/** +* @brief Expand the input tensor to a compatible shape. \n + +* @par Inputs: +* One inputs, including: +* @li x: A Tensor. Must be one of the following types: +* float16, float32, int32, int8 ,uint8. \n +* @li shape: A Tensor to specify the shape that the input tensor expanded to. \n + +* @par Outputs: +* @li y: A Tensor. Has the same type as "x", and the shape specified by input and attr shape \n + +* @par Third-party framework compatibility +* Compatible with the ONNX operator Expand. 
+*/ + +REG_OP(Expand) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) + .INPUT(shape, TensorType({DT_INT16, DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) + .OP_END_FACTORY_REG(Expand) + +/** +* @brief Expand the input tensor to a compatible shape. \n + +* @par Inputs: +* One inputs, including: +* @li x: A Tensor. Must be one of the following types: +* float16, float32, int32, int8 ,uint8. \n + +* @par Attributes: +* @li shape: A required listInt to specify the shape that the input tensor expanded to. \n + + +* @par Outputs: +* @li y: A Tensor. Has the same type as "x", and the shape specified by input and attr shape \n + +* @par Third-party framework compatibility +* Compatible with the ONNX operator Expand. +*/ + +REG_OP(ExpandD) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) + .REQUIRED_ATTR(shape, ListInt) + .OP_END_FACTORY_REG(ExpandD) +} // namespace ge + +#endif // OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ diff --git a/tests/st/framework/stub_op_proto/control_flow_ops.cc b/tests/st/framework/stub_op_proto/control_flow_ops.cc new file mode 100644 index 00000000..93873afe --- /dev/null +++ b/tests/st/framework/stub_op_proto/control_flow_ops.cc @@ -0,0 +1,392 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*! + * \file control_flow_ops.cpp + * \brief + */ +#include "control_flow_ops.h" +#include "./util/common_shape_fns.h" +#include "./util/error_util.h" +#include "util/util.h" + +namespace ge { + +namespace { +graphStatus MergeInferImpl(Operator& op) { + TensorDesc td = op.GetOutputDesc("value_index"); + TensorDesc td_y = op.GetOutputDesc("y"); + td.SetShape(ge::Shape()); + td.SetDataType(DT_INT32); + auto ret = op.UpdateOutputDesc("value_index", td); + if (ret != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + + // check N of "x" >= 1 + size_t in_num = op.GetInputsSize(); + if (in_num < 1) { + string reason = "inputs size[" + std::to_string(in_num) + "] must be greater than or equal to 1"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "input", reason); + return GRAPH_FAILED; + } else if (in_num == 2) { + // Check is loop_merge, order of InferShape: Enter->Merge->NextIteration + // So when processing InferShape on Merge op, shape & datatype of NextIteration op is set as default. + // Therefore, shape & datatype of Merge op should be set as the Enter op. 
+ auto x0_type = op.GetDynamicInputDesc("x", 0).GetDataType(); + auto x0_dims = op.GetDynamicInputDesc("x", 0).GetShape().GetDims(); + bool not_handle_flag0 = (x0_type == DT_FLOAT) && (x0_dims.size() == 0); + auto x1_type = op.GetDynamicInputDesc("x", 1).GetDataType(); + auto x1_dims = op.GetDynamicInputDesc("x", 1).GetShape().GetDims(); + bool not_handle_flag1 = (x1_type == DT_FLOAT) && (x1_dims.size() == 0); + if ((x0_type != x1_type) && (not_handle_flag0 || not_handle_flag1)) { + if (not_handle_flag0) { + td_y.SetShape(ge::Shape(x1_dims)); + td_y.SetDataType(x1_type); + } else { + td_y.SetShape(ge::Shape(x0_dims)); + td_y.SetDataType(x0_type); + } + (void)op.UpdateOutputDesc("y", td_y); + return GRAPH_SUCCESS; + } + } + + // check "x" be same type + auto x0_type = op.GetDynamicInputDesc("x", 0).GetDataType(); + for (size_t i = 1; i < op.GetInputsSize(); i++) { + auto xi_type = op.GetDynamicInputDesc("x", i).GetDataType(); + if (xi_type != x0_type) { + string reason = "x[0]'s dtype[" + std::to_string(x0_type) + "] must be equal to x[" + std::to_string(i) + + "]'s dtype[" + std::to_string(xi_type) + "]"; + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", reason); + return GRAPH_FAILED; + } + } + + // infer "y" be unknown shape + auto x0_dims = op.GetDynamicInputDesc("x", 0).GetShape().GetDims(); + bool x0_unknown = (x0_dims.size() == 1) && (x0_dims[0] == 0); + if (x0_unknown) { + Shape unknown_shape(ge::UNKNOWN_SHAPE); + td_y.SetShape(unknown_shape); + td_y.SetDataType(x0_type); + (void)op.UpdateOutputDesc("y", td_y); + return GRAPH_SUCCESS; + } + + // find the input with the max size from all inputs, and set it's data type/shape to the output + std::map size_to_index; + for (size_t i = 0; i < op.GetInputsSize(); i++) { + auto xi_dims = op.GetDynamicInputDesc("x", i).GetShape().GetDims(); + bool xi_unknown = (xi_dims.size() == 1) && (xi_dims[0] == 0); + if (xi_unknown) { + continue; + } + int64_t size = 
static_cast(GetSizeByDataType(op.GetDynamicInputDesc("x", i).GetDataType())); + if (size < 0) { + continue; + } + + if (!xi_dims.empty()) { + for (auto& dim : xi_dims) { + if (dim <= 0) { + size = -1; + break; + } + if (size != 0 && INT64_MAX / size < dim) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dim", "the dim size is overflow"); + return GRAPH_FAILED; + } + size *= dim; + } + if (size < 0) { + continue; + } + } + + if (size_to_index.count(size) == 0) { + size_to_index[size] = i; + } + } + if (size_to_index.empty()) { + return GRAPH_FAILED; + } + auto index = size_to_index.rbegin()->second; + td_y.SetShape(ge::Shape(op.GetDynamicInputDesc("x", index).GetShape().GetDims())); + td_y.SetDataType(op.GetDynamicInputDesc("x", index).GetDataType()); + (void)op.UpdateOutputDesc("y", td_y); + return GRAPH_SUCCESS; +} + +graphStatus SwitchInferImpl(Operator& op) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + + auto data_desc = op_desc->MutableInputDesc("data"); + auto pred_desc = op_desc->MutableInputDesc("pred"); + auto output_false_desc = op_desc->MutableOutputDesc("output_false"); + auto output_true_desc = op_desc->MutableOutputDesc("output_true"); + + std::vector> data_range; + data_desc->GetShapeRange(data_range); + // check "pred" scalar type be bool + auto pred_dims = pred_desc->GetShape().GetDims(); + if (pred_dims.size() != 0) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "pred dims", "pred should be a scalar"); + return GRAPH_FAILED; + } + DataType pred_type = pred_desc->GetDataType(); + if (pred_type != DT_BOOL) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "pred should be bool type"); + return GRAPH_FAILED; + } + + DataType data_type = data_desc->GetDataType(); + auto data_dims = data_desc->GetShape().GetDims(); + + output_false_desc->SetShapeRange(data_range); + output_true_desc->SetShapeRange(data_range); + output_false_desc->SetShape(GeShape(data_dims)); + 
output_false_desc->SetOriginShape(GeShape(data_dims)); + output_true_desc->SetShape(GeShape(data_dims)); + output_true_desc->SetOriginShape(GeShape(data_dims)); + output_false_desc->SetDataType(data_type); + output_true_desc->SetDataType(data_type); + + auto context = op.GetInferenceContext(); + std::vector> in_shapes_and_types = context->GetInputHandleShapesAndTypes(); + if ((!in_shapes_and_types.empty()) && (!in_shapes_and_types.at(0).empty())) { + ShapeAndType shape_and_type = in_shapes_and_types.at(0).at(0); + std::vector grad_handle_shape_and_type; + grad_handle_shape_and_type.reserve(1); + grad_handle_shape_and_type.emplace_back(shape_and_type); + + std::vector> shapes_and_types(2); + shapes_and_types[0] = grad_handle_shape_and_type; + shapes_and_types[1] = grad_handle_shape_and_type; + context->SetOutputHandleShapesAndTypes(shapes_and_types); + } + + return GRAPH_SUCCESS; +} + +graphStatus EnterInferImpl(Operator& op) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto input_desc_x = op_desc->MutableInputDesc("x"); + auto output_desc_y = op_desc->MutableOutputDesc("y"); + std::vector> x_range; + std::vector> y_range; + input_desc_x->GetShapeRange(x_range); + + auto input_dims = input_desc_x->MutableShape().GetDims(); + DataType input_type = input_desc_x->GetDataType(); + output_desc_y->SetShape(ge::GeShape(input_dims)); + output_desc_y->SetOriginShape(ge::GeShape(input_dims)); + output_desc_y->SetDataType(input_type); + + if (!x_range.empty()) { + output_desc_y->SetShapeRange(x_range); + } + + auto context = op.GetInferenceContext(); + std::vector> in_shapes_and_types = context->GetInputHandleShapesAndTypes(); + if ((!in_shapes_and_types.empty()) && (!in_shapes_and_types.at(0).empty())) { + ShapeAndType shape_and_type = in_shapes_and_types.at(0).at(0); + std::vector grad_handle_shape_and_type; + grad_handle_shape_and_type.reserve(1); + grad_handle_shape_and_type.emplace_back(shape_and_type); + + std::vector> shapes_and_types(1); + 
shapes_and_types[0] = grad_handle_shape_and_type; + context->SetOutputHandleShapesAndTypes(shapes_and_types); + } + + return GRAPH_SUCCESS; +} + +graphStatus PassThroughInferImpl(Operator& op, const std::string& in_name, const std::string& out_name) { + auto input_dims = op.GetInputDesc(in_name).GetShape().GetDims(); + DataType input_type = op.GetInputDesc(in_name).GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc(out_name); + tensordesc_output.SetShape(ge::Shape(input_dims)); + tensordesc_output.SetDataType(input_type); + (void)op.UpdateOutputDesc(out_name, tensordesc_output); + + return GRAPH_SUCCESS; +} + +graphStatus LoopCondInferImpl(Operator& op) { + auto input_dims = op.GetInputDesc("x").GetShape().GetDims(); + if (input_dims.size() != 0) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "x dims", "x should be a scalar"); + return GRAPH_FAILED; + } + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + tensordesc_output.SetShape(ge::Shape(input_dims)); + DataType input_type = op.GetInputDesc("x").GetDataType(); + if (input_type != DT_BOOL) { + GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "x should be bool type"); + return GRAPH_FAILED; + } + tensordesc_output.SetDataType(input_type); + (void)op.UpdateOutputDesc("y", tensordesc_output); + + return GRAPH_SUCCESS; +} + +} // namespace + +IMPLEMT_INFERFUNC(Merge, MergeInfer) { + return MergeInferImpl(op); +} + +INFER_FUNC_REG(Merge, MergeInfer); + +IMPLEMT_INFERFUNC(RefMerge, RefMergeInfer) { + return MergeInferImpl(op); +} + +INFER_FUNC_REG(RefMerge, RefMergeInfer); + +IMPLEMT_INFERFUNC(Switch, SwitchInfer) { + return SwitchInferImpl(op); +} + +INFER_FUNC_REG(Switch, SwitchInfer); + +IMPLEMT_INFERFUNC(RefSwitch, RefSwitchInfer) { + return SwitchInferImpl(op); +} + +INFER_FUNC_REG(RefSwitch, RefSwitchInfer); + +IMPLEMT_INFERFUNC(SwitchN, SwitchNInfer) { + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(SwitchN, SwitchNInfer); + +IMPLEMT_INFERFUNC(Enter, EnterInfer) { + return 
EnterInferImpl(op); +} + +INFER_FUNC_REG(Enter, EnterInfer); + +IMPLEMT_INFERFUNC(RefEnter, RefEnterInfer) { + return PassThroughInferImpl(op, "x", "y"); +} + +INFER_FUNC_REG(RefEnter, RefEnterInfer); + +IMPLEMT_INFERFUNC(LoopCond, LoopCondInfer) { + return LoopCondInferImpl(op); +} + +INFER_FUNC_REG(LoopCond, LoopCondInfer); + +IMPLEMT_INFERFUNC(NextIteration, NextIterationInfer) { + return PassThroughInferImpl(op, "x", "y"); +} + +INFER_FUNC_REG(NextIteration, NextIterationInfer); + +IMPLEMT_INFERFUNC(RefNextIteration, RefNextIterationInfer) { + return PassThroughInferImpl(op, "x", "y"); +} + +INFER_FUNC_REG(RefNextIteration, RefNextIterationInfer); + +IMPLEMT_INFERFUNC(Exit, ExitInfer) { + return PassThroughInferImpl(op, "x", "y"); +} + +INFER_FUNC_REG(Exit, ExitInfer); + +IMPLEMT_INFERFUNC(RefExit, RefExitInfer) { + return PassThroughInferImpl(op, "x", "y"); +} + +INFER_FUNC_REG(RefExit, RefExitInfer); + +// ----------------MapIndex------------------- +IMPLEMT_VERIFIER(MapIndex, MapIndexVerify) { + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(MapIndexInferShape) { + OP_LOGI("MapIndex", "infer shape begin---"); + auto x_shape = op.GetInputDesc("x").GetShape().GetDims(); + if (x_shape.empty()) { + OP_LOGE(op.GetName().c_str(), "x_shape is empty"); + OpsOneInputShapeErrReport(op.GetName().c_str(), "x", "x_shape is empty"); + return GRAPH_FAILED; + } + int64_t x_length = x_shape[0]; + + auto data_seq_shape = op.GetInputDesc("data_seq").GetShape().GetDims(); + if (data_seq_shape.empty()) { + OP_LOGE(op.GetName().c_str(), "data_seq_shape is empty"); + OpsOneInputShapeErrReport(op.GetName().c_str(), "data_seq", "data_seq_shape is empty"); + return GRAPH_FAILED; + } + int64_t data_seq_length = data_seq_shape[0]; + + if (x_length > 8 || x_length == 0) { + OP_LOGE(op.GetName().c_str(), "the length of x should be less than or equal to 8"); + OpsOneInputShapeErrReport(op.GetName().c_str(), "x", "the length of x should be less than or equal to 8 and not 0"); + 
return GRAPH_FAILED; + } + + if (data_seq_length % x_length != 0) { + OP_LOGE(op.GetName().c_str(), "the length of data_seq must be multiple of the length of x"); + OpsTwoInputShapeErrReport(op.GetName().c_str(), "data_seq", "x", + "the length of data_seq must be multiple of the length of x"); + return GRAPH_FAILED; + } + + if (data_seq_length / x_length > 100) { + OP_LOGE(op.GetName().c_str(), "data_seq_length / x_length should be be less than or equal to 100"); + OpsTwoInputShapeErrReport(op.GetName().c_str(), "data_seq", "x", + "data_seq_length / x_length should be be less than or equal to 100"); + return GRAPH_FAILED; + } + + auto level_index_shape = op.GetInputDesc("level_index").GetShape().GetDims(); + if (!level_index_shape.empty()) { + int64_t level_index_length = level_index_shape[0]; + if (level_index_length != (data_seq_length / x_length)) { + OP_LOGE(op.GetName().c_str(), + "the length of level_index must be equal to " + "the length of data_seq divided by the length of x"); + OpsOneInputShapeErrReport(op.GetName().c_str(), "level_index", + "the length of level_index must be equal to " + "the length of data_seq divided by the length of x"); + return GRAPH_FAILED; + } + } + + TensorDesc y_desc = op.GetOutputDesc("y"); + y_desc.SetShape(ge::Shape()); + y_desc.SetDataType(ge::DT_INT32); + + (void)op.UpdateOutputDesc("y", y_desc); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(MapIndex, MapIndexInferShape); +VERIFY_FUNC_REG(MapIndex, MapIndexVerify); + +} // namespace ge diff --git a/tests/st/framework/stub_op_proto/control_flow_ops.h b/tests/st/framework/stub_op_proto/control_flow_ops.h new file mode 100644 index 00000000..e57932c8 --- /dev/null +++ b/tests/st/framework/stub_op_proto/control_flow_ops.h @@ -0,0 +1,407 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file control_flow_ops.h
 * \brief Stub op-proto declarations for control-flow operators (Merge/Switch/Enter/
 *        Exit/NextIteration/LoopCond and friends) used by the ST framework.
 */
#ifndef OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_

#include "graph/operator_reg.h"
#include "graph/operator.h"

namespace ge {

/**
 *@brief Forwards the value of an available tensor from input "x" to output "y".
 * Merge waits for at least one of the input tensors to become available.
 * It is usually combined with Switch to implement branching.
 * Merge forwards the first tensor to become available to output "y",
 * and sets "value_index" the index of the tensor in inputs . \n

 *@par Inputs:
 *x: The input tensors, one of which will become available.
 * Must be one of the following types: float16, float32, float64, int8,
 * int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n

 *@par Outputs:
 *@li y: The available tensor. Has the same type as "x".
 *@li value_index: A scalar of type int32, for the index of the chosen input
 * tensor . \n

 *@see Switch()

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator Merge.
 */
REG_OP(Merge)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(value_index, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(Merge)

/**
 *@brief Forwards the value of an available tensor from input "x" to output "y".
 * Merge waits for at least one of the input tensors to become available.
 * It is usually combined with Switch to implement branching.
 * Merge forwards the first tensor to become available to output "y",
 * and sets "value_index" the index of the tensor in inputs . \n

 *@par Inputs:
 *x: The input tensors, one of which will become available.
 * Must be one of the following types: float16, float32, float64, int8,
 * int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n

 *@par Outputs:
 *@li y: The available tensor. Has the same type as "x".
 *@li value_index: A scalar of type int32, for the index of the chosen input
 * tensor . \n

 *@see Switch() | Merge()

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator RefMerge.
 */
REG_OP(RefMerge)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(value_index, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(RefMerge)

/**
 *@brief Forwards "data" to the output port determined by "pred".
 * If "pred" is "true", the data input is forwarded to "output_true".
 * Otherwise, the data is forwarded to "output_false" . \n

 *@par Inputs:
 *@li data: The tensor to be forwarded. \n
 * Must be one of the following types: float16, float32, float64,
 * int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
 *@li pred: A boolean scalar. The output port that will receive data . \n

 *@par Outputs:
 *@li output_false: If "pred" is "false", data will be forwarded to this output.
 * Has the same type as "data".
 *@li output_true: If "pred" is "true", data will be forwarded to this output.
 * Has the same type as "data" . \n

 *@see Merge()

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator Switch.
 */
REG_OP(Switch)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred, TensorType({DT_BOOL}))
    .OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(Switch)

/**
 *@brief Forwards "data" to the output port determined by "pred".
 * If "pred" is "true", the data input is forwarded to "output_true".
 * Otherwise, the data is forwarded to "output_false" . \n

 *@par Inputs:
 *@li data: The ref tensor to be forwarded.
 * Must be one of the following types: float16, float32, float64,
 * int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
 *@li pred: A boolean scalar. The output port that will receive data . \n

 *@par Outputs:
 *@li output_false: If "pred" is "false", data will be forwarded to this output.
 * Has the same type as "data".
 *@li output_true: If "pred" is "true", data will be forwarded to this output.
 * Has the same type as "data" . \n

 *@see Merge() | Switch()

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator RefSwitch.
 */
REG_OP(RefSwitch)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred, TensorType({DT_BOOL}))
    .OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefSwitch)

/**
 *@brief Forwards "data" to the output port determined by "pred_value" . \n

 *@par Inputs:
 *@li data: The tensor to be forwarded. \n
 * Must be one of the following types: float16, float32, float64,
 * int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
 *@li pred_value: A int64 tensor which determines the output port that will receive data . \n

 *@par Outputs:
 *output: The output tensors, one of which will become available.
 * Has the same type as "data".
 */
REG_OP(SwitchN)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred_value, TensorType({DT_INT64}))
    .DYNAMIC_OUTPUT(output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(SwitchN)

/**
 *@brief Creates or finds a child frame, and makes "x" available to the child
 * frame. This op is used together with Exit to create loops in the graph.
 * The Executor uses the unique "frame_name" to identify frames.
 * If "is_constant" is "true", output "y" is a constant in the child
 * frame; otherwise it may be changed in the child frame . \n

 *@par Inputs:
 *x: The tensor to be made available to the child frame.
 * Must be one of the following types: float16, float32, float64, int8,
 * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n

 *@par Attributes:
 *@li frame_name: A required string. The name of the child frame.
 *@li is_constant: A required bool. If true, the output is constant in
 * the child frame . \n

 *@par Outputs:
 *y: A Tensor. Has the same type as "x" . \n

 *@see Exit()

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator Enter.
 */
REG_OP(Enter)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .REQUIRED_ATTR(frame_name, String)
    .REQUIRED_ATTR(is_constant, Bool)
    .OP_END_FACTORY_REG(Enter)

/**
 *@brief Creates or finds a child frame, and makes "x" available to the child
 * frame. This op is used together with Exit to create loops in the graph.
 * The Executor uses the unique "frame_name" to identify frames.
 * If "is_constant" is "true", output "y" is a constant in the child
 * frame; otherwise it may be changed in the child frame . \n

 *@par Inputs:
 *x: The tensor to be made available to the child frame.
 * Must be one of the following types: float16, float32, float64, int8,
 * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n

 *@par Attributes:
 *@li frame_name: A required string. The name of the child frame.
 *@li is_constant: A required bool. If true, the output is constant in
 * the child frame . \n

 *@par Outputs:
 *y: A tensor. Has the same type as "x" . \n

 *@see Exit() | Enter()

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator RefEnter.
 */
REG_OP(RefEnter)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .REQUIRED_ATTR(frame_name, String)
    .REQUIRED_ATTR(is_constant, Bool)
    .OP_END_FACTORY_REG(RefEnter)

/**
 *@brief Forwards the input to the output. This op represents the loop
 * termination condition . \n

 *@par Inputs:
 *x: A boolean scalar. The condition of the Switch op . \n

 *@par Outputs:
 *y: The tensor "x" . \n

 *@see Switch()

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator LoopCond.
 */
REG_OP(LoopCond)
    .INPUT(x, TensorType({DT_BOOL}))
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(LoopCond)

/**
 *@brief Makes the input available to the next iteration . \n

 *@par Inputs:
 *x: The tensor to be made available to the next iteration.
 * Must be one of the following types: float16, float32, float64, int8,
 * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n

 *@par Outputs:
 *y: A Tensor. Has the same type as "x" . \n

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator NextIteration.
 */
REG_OP(NextIteration)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(NextIteration)

/**
 *@brief Makes the input available to the next iteration . \n

 *@par Inputs:
 *x: The tensor to be made available to the next iteration.
 * Must be one of the following types: float16, float32, float64, int8,
 * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n

 *@par Outputs:
 *y: A tensor. Has the same type as "x" . \n

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator RefNextIteration.
 */
REG_OP(RefNextIteration)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefNextIteration)

/**
 *@brief Exits the current frame to its parent frame . \n

 *@par Inputs:
 *x: The tensor to be made available to the parent frame.
 * Must be one of the following types: float16, float32, float64, int8,
 * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n

 *@par Outputs:
 *y: A Tensor. Has the same type as "x" . \n

 *@see Enter()

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator Exit.
 */
REG_OP(Exit)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(Exit)

/**
 *@brief Exits the current frame to its parent frame . \n

 *@par Inputs:
 *x: The tensor to be made available to the parent frame.
 * Must be one of the following types: float16, float32, float64, int8,
 * int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n

 *@par Outputs:
 *y: A tensor. Has the same type as "x" . \n

 *@see Enter() | Exit()

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator RefExit.
 */
REG_OP(RefExit)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefExit)

/**
 *@brief Only useful as a placeholder for control edges.
 * It is similar to a no-op that always produces a live control output
 * even when some control inputs are dead . \n

 *@par Third-party framework compatibility
 *@Compatible with the TensorFlow operator ControlTrigger.
 */
REG_OP(ControlTrigger)
    .OP_END_FACTORY_REG(ControlTrigger)

/**
*@brief Returns index of shape in the map.

*@par Inputs:
* Three inputs, including:
*@li x: One dimensional tensor of type int32, specifying queried shape, max size is 8.
*@li data_seq: One dimensional tensor of type int32, specifying the mapped table is queried.
*@li level_index: One dimensional tensor of type int32, specifying secondary index. \n

*@par Outputs:
*@li y: A Tensor with shape [batch, 8], of type int32, specifying index of shape in the map.
*@par Third-party framework compatibility
* It is a custom operator. It has no corresponding operator in Caffe.
*/
REG_OP(MapIndex)
    .INPUT(x, TensorType({DT_INT32}))
    .INPUT(data_seq, TensorType({DT_INT32}))
    .OPTIONAL_INPUT(level_index, TensorType({DT_INT32}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(MapIndex)
}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_
diff --git a/tests/st/framework/stub_op_proto/elewise_calculation_ops.cc b/tests/st/framework/stub_op_proto/elewise_calculation_ops.cc
new file mode 100644
index 00000000..e3be7c69
--- /dev/null
+++ b/tests/st/framework/stub_op_proto/elewise_calculation_ops.cc
@@ -0,0 +1,4633 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
+ * \file elewise_calculation_ops.cpp + * \brief + */ +#include "elewise_calculation_ops.h" +#include +#include +#include "util/util.h" +#include "util/op_log.h" +#include "./util/error_util.h" +#include "graph/utils/node_utils.h" + +namespace ge { +bool BroadCastTwoShape(const Operator& op, const ge::Shape& shape_x, const ge::Shape& shape_y, + std::vector& dim_out) { + std::vector dim_x = shape_x.GetDims(); + std::vector dim_y = shape_y.GetDims(); + // exchange them + if (dim_x.size() < dim_y.size()) { + std::vector dim_tmp = dim_x; + dim_x = dim_y; + dim_y = dim_tmp; + } + + // expand smalll shape + if (dim_x.size() != dim_y.size()) { + int dec = dim_x.size() - dim_y.size(); + for (int i = 0; i < dec; i++) { + dim_y.insert(dim_y.begin(), (int64_t)1); + } + } + + // set out dims + for (size_t i = 0; i < dim_x.size(); i++) { + if ((dim_x[i] != dim_y[i]) && (dim_x[i] != 1) && (dim_y[i] != 1)) { + OP_LOGE(op.GetName().c_str(), "The %s's dimensions does not match the broadcast rule(%lu %lu).", + op.GetName().c_str(), dim_x[i], dim_y[i]); + return false; + } + + int64_t dim = dim_x[i] > dim_y[i] ? 
dim_x[i] : dim_y[i]; + dim_out.push_back(dim); + } + return true; +} + +bool InferShapeForMaximumAndMinimum(Operator& op) { + auto attr_grad_x = false; + auto attr_grad_y = false; + if (op.GetAttr("grad_x", attr_grad_x) == GRAPH_FAILED) { + OP_LOGE(op.GetName().c_str(), "get attr grad_x failed"); + } + if (op.GetAttr("grad_y", attr_grad_y) == GRAPH_FAILED) { + OP_LOGE(op.GetName().c_str(), "get attr grad_y failed"); + } + if (attr_grad_x == false && attr_grad_y == false) { + OP_LOGE(op.GetName().c_str(), "the grad_x and grad_y is not support all false"); + return false; + } + if (attr_grad_x) { + if(!OneInOneOutDynamicInfer(op,"x1",{"y1"})){ + return false; + } + } + if (attr_grad_y) { + if(!OneInOneOutDynamicInfer(op,"x2",{"y2"})){ + return false; + } + } + + return true; +} + +IMPLEMT_COMMON_INFERFUNC(TwoInOneOutCommonInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(OneInOneOutCommonInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +// ----------------MaximumGrad------------------- +IMPLEMT_COMMON_INFERFUNC(MaximumGradInferShape) { + if (InferShapeForMaximumAndMinimum(op)) { + return GRAPH_SUCCESS; + } + + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(MaximumGrad, MaximumGradInferShape); +// ----------------MaximumGrad End------------------- + +// ----------------MinimumGrad------------------- +IMPLEMT_COMMON_INFERFUNC(MinimumGradInferShape) { + if (InferShapeForMaximumAndMinimum(op)) { + return GRAPH_SUCCESS; + } + + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(MinimumGrad, MinimumGradInferShape); +// ----------------MinimumGrad End------------------- + +// ----------------------Add-------------------------- +IMPLEMT_VERIFIER(Add, AddVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + 
return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(AddInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Add, AddInferShape); +VERIFY_FUNC_REG(Add, AddVerify); +// ---------------------Add END------------------------ + +// ----------------------FusedMulAdd-------------------------- +IMPLEMT_VERIFIER(FusedMulAdd, FusedMulAddVerify) { + DataType input_type_x1 = op.GetInputDesc("x1").GetDataType(); + DataType input_type_x2 = op.GetInputDesc("x2").GetDataType(); + DataType input_type_x3 = op.GetInputDesc("x3").GetDataType(); + if (input_type_x1 != input_type_x2) { + OpsTwoInputDtypeErrReport(op.GetName(), "x1", "x2", ConcatString(input_type_x1), ConcatString(input_type_x2)); + OP_LOGE(op.GetName().c_str(), "The %s op dtype is not same, type1:%d, type2:%d", op.GetName().c_str(), + input_type_x1, input_type_x2); + return false; + } + + if (input_type_x2 != input_type_x3) { + OpsTwoInputDtypeErrReport(op.GetName(), "x2", "x3", ConcatString(input_type_x2), ConcatString(input_type_x3)); + OP_LOGE(op.GetName().c_str(), "The %s op dtype is not same, type2:%d, type3:%d", op.GetName().c_str(), + input_type_x2, input_type_x3); + return false; + } + + return true; +} + +IMPLEMT_COMMON_INFERFUNC(FusedMulAddInferShape) { + if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_SUCCESS; + } + + ge::Shape shape1 = op.GetInputDesc("x1").GetShape(); + ge::Shape shape2 = op.GetInputDesc("x2").GetShape(); + std::vector vec_mul_out; + if (!BroadCastTwoShape(op, shape1, shape2, vec_mul_out)) { + return GRAPH_FAILED; + } + + ge::Shape shape_mul_out = ge::Shape(vec_mul_out); + ge::Shape shape3 = op.GetInputDesc("x3").GetShape(); + std::vector vec_add_out; + if (!BroadCastTwoShape(op, shape_mul_out, shape3, vec_add_out)) { + return GRAPH_FAILED; + } + + ge::Shape shape_add_out = ge::Shape(vec_add_out); 
+ TensorDesc y_desc = op.GetOutputDesc("y"); + y_desc.SetShape(shape_add_out); + DataType dtype_input = op.GetInputDesc("x1").GetDataType(); + y_desc.SetDataType(dtype_input); + (void)op.UpdateOutputDesc("y", y_desc); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FusedMulAdd, FusedMulAddInferShape); +VERIFY_FUNC_REG(FusedMulAdd, FusedMulAddVerify); +// ---------------------FusedMulAdd END------------------------ + +// ---------------------AddV2-------------------------- +IMPLEMT_VERIFIER(AddV2, AddV2Verify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(AddV2, AddInferShape); +VERIFY_FUNC_REG(AddV2, AddV2Verify); +// -------------------AddV2 END---------------------- + +// ----------------Cast------------------- +IMPLEMT_COMMON_INFERFUNC(CastInferShape) { + // get input desc + auto op_info = OpDescUtils::GetOpDescFromOperator(op); + auto input_desc = op_info->MutableInputDesc("x"); + vector input_shape = input_desc->MutableShape().GetDims(); + + auto output_desc = op_info->MutableOutputDesc("y"); + if (IsUnknown(input_shape)) { + std::vector> input_range; + input_desc->GetShapeRange(input_range); + MakeUpShapeRange(input_shape, input_range); + + output_desc->SetShape(GeShape(input_shape)); + output_desc->SetOriginShape(GeShape(input_shape)); + output_desc->SetShapeRange(input_range); + } else { + output_desc->SetShape(GeShape(input_shape)); + } + int type; + if (op.GetAttr("dst_type", type) == GRAPH_SUCCESS) { + output_desc->SetDataType((ge::DataType)type); + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Cast, CastInferShape); +// --------------Cast END----------------- + +// ---------------------GreaterEqual----------------------- +IMPLEMT_VERIFIER(GreaterEqual, GreaterEqualVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(GreaterEqualInferShape) { + if 
(!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto vec_y = op_desc->MutableOutputDesc("y")->MutableShape().GetDims(); + if (IsUnknownRankShape(vec_y) || IsUnknownVec(vec_y)) { + if (!InferShapeRangeTwoInOneOutBroadcase(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + } + + op_desc->MutableOutputDesc("y")->SetDataType(DT_BOOL); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(GreaterEqual, GreaterEqualInferShape); +VERIFY_FUNC_REG(GreaterEqual, GreaterEqualVerify); +// ------------------GreaterEqual END------------------- + +// --------------------Less-------------------- +IMPLEMT_VERIFIER(Less, LessVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(LessInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto vec_y = op_desc->MutableOutputDesc("y")->MutableShape().GetDims(); + if (IsUnknownRankShape(vec_y) || IsUnknownVec(vec_y)) { + if (!InferShapeRangeTwoInOneOutBroadcase(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + } + + op_desc->MutableOutputDesc("y")->SetDataType(DT_BOOL); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Less, LessInferShape); +VERIFY_FUNC_REG(Less, LessVerify); +// -----------------Less END----------------------- + +// ------------------RealDiv--------------------- +IMPLEMT_VERIFIER(RealDiv, RealDivVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + + +COMMON_INFER_FUNC_REG(RealDiv, TwoInOneOutCommonInferShape); +VERIFY_FUNC_REG(RealDiv, RealDivVerify); +// ----------------RealDiv END------------------ + +// ----------------Sqrt Op Begin------------ +COMMON_INFER_FUNC_REG(Sqrt, OneInOneOutCommonInferShape); +// ----------------Sqrt Op End--------------- + +// 
----------------Maximum-------------------- +IMPLEMT_VERIFIER(Maximum, MaximumVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Maximum, TwoInOneOutCommonInferShape); +VERIFY_FUNC_REG(Maximum, MaximumVerify); +// --------------Maximum END------------------ + +// ----------------Minimum-------------------- +IMPLEMT_VERIFIER(Minimum, MinimumVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Minimum, TwoInOneOutCommonInferShape); +VERIFY_FUNC_REG(Minimum, MinimumVerify); +// -----------------Minimum END----------------- + +// ----------------Reciprocal------------------- +COMMON_INFER_FUNC_REG(Reciprocal, OneInOneOutCommonInferShape); +// ---------------Reciprocal END----------------- + +// -------------------Sub---------------------- +IMPLEMT_COMMON_INFERFUNC(SubInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Sub, SubInferShape); +// -----------------Sub END----------------- + +// ----------------Abs------------------- +IMPLEMT_COMMON_INFERFUNC(AbsInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Abs, AbsInferShape); +// --------------Abs END----------------- + +// ----------------Sign------------------- +IMPLEMT_COMMON_INFERFUNC(SignInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Sign, SignInferShape); +// ---------------Sign END----------------- + +// ----------------SquaredDifference------------------- +IMPLEMT_COMMON_INFERFUNC(SquaredDifferenceInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_FAILED; 
+ } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(SquaredDifference, SquaredDifferenceInferShape); +// ----------------SquaredDifference END--------------- + +// ------------------Div--------------------- +IMPLEMT_VERIFIER(Div, DivVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Div, TwoInOneOutCommonInferShape); +VERIFY_FUNC_REG(Div, DivVerify); +// -----------------Div END------------------ + +// -------------------Equal-------------------- +IMPLEMT_VERIFIER(Equal, EqualVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(EqualInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto vec_y = op_desc->MutableOutputDesc("y")->MutableShape().GetDims(); + if (IsUnknownRankShape(vec_y) || IsUnknownVec(vec_y)) { + if (!InferShapeRangeTwoInOneOutBroadcase(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + } + + op_desc->MutableOutputDesc("y")->SetDataType(DT_BOOL); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Equal, EqualInferShape); +VERIFY_FUNC_REG(Equal, EqualVerify); +// ------------------Equal END-------------------- + +// ----------------Exp------------------- +COMMON_INFER_FUNC_REG(Exp, OneInOneOutCommonInferShape); +// ----------------Exp END------------------- + +// ----------------------Inv---------------------- +IMPLEMT_COMMON_INFERFUNC(InvInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(Inv, InvInferShape); +// ----------------------Inv END---------------------- + +// ----------------------InvGrad---------------------- +IMPLEMT_VERIFIER(InvGrad, InvGradVerify) { + if (!CheckTwoInputDtypeSame(op, "x", "grad")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; 
+} + +IMPLEMT_COMMON_INFERFUNC(InvGradInferShape) { + bool is_dynamic_output = true; + if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x", "grad", "y", is_dynamic_output)) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(InvGrad, InvGradInferShape); +VERIFY_FUNC_REG(InvGrad, InvGradVerify); +// ----------------------InvGrad END---------------------- + +// -------------------LessEqual--------------------- +IMPLEMT_VERIFIER(LessEqual, LessEqualVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(LessEqualInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto vec_y = op_desc->MutableOutputDesc("y")->MutableShape().GetDims(); + if (IsUnknownRankShape(vec_y) || IsUnknownVec(vec_y)) { + if (!InferShapeRangeTwoInOneOutBroadcase(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + } + + op_desc->MutableOutputDesc("y")->SetDataType(DT_BOOL); + return GRAPH_SUCCESS; +} + + +COMMON_INFER_FUNC_REG(LessEqual, LessEqualInferShape); +VERIFY_FUNC_REG(LessEqual, LessEqualVerify); +// --------------------LessEqual END----------------------- + +// ----------------Log1p------------------- +COMMON_INFER_FUNC_REG(Log1p, OneInOneOutCommonInferShape); +// --------------Log1p END----------------- + +// -------------------NotEqual-------------------- +IMPLEMT_VERIFIER(NotEqual, NotEqualVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(NotEqualInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto vec_y = op_desc->MutableOutputDesc("y")->MutableShape().GetDims(); + if (IsUnknownRankShape(vec_y) || IsUnknownVec(vec_y)) { + if 
(!InferShapeRangeTwoInOneOutBroadcase(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + } + + op_desc->MutableOutputDesc("y")->SetDataType(DT_BOOL); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(NotEqual, NotEqualInferShape); +VERIFY_FUNC_REG(NotEqual, NotEqualVerify); +// ------------------NotEqual END-------------------- + +// ----------------Neg------------------- +COMMON_INFER_FUNC_REG(Neg, OneInOneOutCommonInferShape); +// ---------------Neg EDN----------------- + +// ------------------DivNoNan----------------------- +IMPLEMT_VERIFIER(DivNoNan, DivNoNanVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(DivNoNanInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y",is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(DivNoNan, DivNoNanInferShape); +VERIFY_FUNC_REG(DivNoNan, DivNoNanVerify); +// --------------DivNoNan END---------------------- + +// ----------------Invert------------------- +IMPLEMT_COMMON_INFERFUNC(InvertInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Invert, InvertInferShape); +// ----------------Invert END------------------- + +// ---------------OnesLike----------------- +IMPLEMT_COMMON_INFERFUNC(OnesLikeInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})){ + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(OnesLike, OnesLikeInferShape); +// ----------------OnesLike END----------------- + +// ----------------ReciprocalGrad------------------- +IMPLEMT_VERIFIER(ReciprocalGrad, ReciprocalGradVerify) { + if (!CheckTwoInputDtypeSame(op, "y", "dy")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(ReciprocalGradInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", 
"z")) { + return GRAPH_FAILED; + } + + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto vec_y = op_desc->MutableOutputDesc("z")->MutableShape().GetDims(); + if (IsUnknownRankShape(vec_y) || IsUnknownVec(vec_y)) { + if (!InferShapeRangeTwoInOneOutBroadcase(op, "y", "dy", "z")) { + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(ReciprocalGrad, ReciprocalGradInferShape); +VERIFY_FUNC_REG(ReciprocalGrad, ReciprocalGradVerify); +// --------------ReciprocalGrad END----------------- + +// ----------------Square Op Begin----------------- +COMMON_INFER_FUNC_REG(Square, OneInOneOutCommonInferShape); +// ----------------Square Op End------------------- + +// ----------------RsqrtGrad---------------------- +IMPLEMT_VERIFIER(RsqrtGrad, RsqrtGradVerify) { + if (!CheckTwoInputDtypeSame(op, "y", "dy")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(RsqrtGradInferShape) { + Shape y_shape = op.GetInputDesc("y").GetShape(); + DataType input_dtype = op.GetInputDesc("y").GetDataType(); + std::vector> shape_range_y; + auto status = op.GetInputDesc("y").GetShapeRange(shape_range_y); + if (status != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + TensorDesc td = op.GetOutputDesc("z"); + td.SetShape(ge::Shape(y_shape)); + td.SetDataType(input_dtype); + td.SetShapeRange(shape_range_y); + (void)op.UpdateOutputDesc("z", td); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(RsqrtGrad, RsqrtGradInferShape); +VERIFY_FUNC_REG(RsqrtGrad, RsqrtGradVerify); +// ----------------RsqrtGrad END---------------------- + +// --------------------ClipByValue----------------------- +IMPLEMT_VERIFIER(ClipByValue, ClipByValueVerify) { + if (!CheckTwoInputDtypeSame(op, "x", "clip_value_min") || !CheckTwoInputDtypeSame(op, "x", "clip_value_max")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(ClipByValueInferShape) { + Shape input_shape = op.GetInputDesc("x").GetShape(); + Shape 
input_ori_shape = op.GetInputDesc("x").GetOriginShape(); + DataType input_dtype = op.GetInputDesc("x").GetDataType(); + TensorDesc td = op.GetOutputDesc("y"); + td.SetShape(Shape(input_shape)); + td.SetOriginShape(Shape(input_ori_shape)); + td.SetDataType(input_dtype); + (void)op.UpdateOutputDesc("y", td); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(ClipByValue, ClipByValueInferShape); +VERIFY_FUNC_REG(ClipByValue, ClipByValueVerify); +// -------------------ClipByValue END------------------- + +// -------------------LogicalOr-------------------- +IMPLEMT_VERIFIER(LogicalOr, LogicalOrVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(LogicalOr, TwoInOneOutCommonInferShape); +VERIFY_FUNC_REG(LogicalOr, LogicalOrVerify); +// ----------------LogicalOr END-------------------- + +// ----------------Rsqrt------------------- +COMMON_INFER_FUNC_REG(Rsqrt, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); +// ----------------Rsqrt------------------- + +// ----------------Acos------------------- +COMMON_INFER_FUNC_REG(Acos, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); +// --------------Acos END----------------- + +// ----------------BesselI0e------------------- +COMMON_INFER_FUNC_REG(BesselI0e, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); +// --------------BesselI0e END----------------- + +// ----------------BesselI1e------------------- +COMMON_INFER_FUNC_REG(BesselI1e, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); +// --------------BesselI1e END----------------- + +// ------------------Mul -------------------- +IMPLEMT_VERIFIER(Mul, MulVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Mul, TwoInOneOutCommonInferShape); +VERIFY_FUNC_REG(Mul, MulVerify); +// ----------------Mul END-------------------- + +// ----------------SqrtGrad Op Begin----------------- +IMPLEMT_VERIFIER(SqrtGrad, SqrtGradVerify) { + if 
(!CheckTwoInputDtypeSame(op, "y", "dy")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(SqrtGradInferShape) { + Shape shape_x = op.GetInputDesc("y").GetShape(); + DataType input_dtype = op.GetInputDesc("y").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("z"); + std::vector> shape_range_x; + op.GetInputDesc("y").GetShapeRange(shape_range_x); + tensordesc_output.SetShape(shape_x); + tensordesc_output.SetDataType(input_dtype); + tensordesc_output.SetShapeRange(shape_range_x); + if (op.UpdateOutputDesc("z", tensordesc_output) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "UpdateOutputDesc run failed. Check whether the names of outputs are matched."); + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(SqrtGrad, SqrtGradInferShape); +VERIFY_FUNC_REG(SqrtGrad, SqrtGradVerify); +// ----------------SqrtGrad Op End------------------- + +// ----------------Log------------------- +IMPLEMT_COMMON_INFERFUNC(LogInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(Log, LogInferShape); +// ----------------Log END------------------- + +// ----------------Assign------------------- +IMPLEMT_VERIFIER(Assign, AssignVerify) { + if (!CheckTwoInputDtypeSame(op, "ref", "value")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(AssignInferShape) { + if (OneInOneOutDynamicInfer(op, "value", {"ref"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; + +} + +COMMON_INFER_FUNC_REG(Assign, AssignInferShape); +VERIFY_FUNC_REG(Assign, AssignVerify); +// ----------------Assign END------------------- + +// ----------------AddN------------------- +int64_t GetAddNConstValue(const ge::Operator& op) { + int64_t tensor_num; + if (ge::GRAPH_SUCCESS != op.GetAttr("N", tensor_num)) { + OpsGetAttrErrReport(op.GetName(), "N"); + OP_LOGE(op.GetName().c_str(), "The add_n op GetOpAttr failed!"); + } + return tensor_num; +} 
+
+// Classify the AddN inputs into one of four inference scenarios.
+// Returns 11 (all empty), 12 (only empty and -2), 2 (static, no -1), 3 (contains -1).
+int64_t AddNInferClassify(ge::Operator& op, int64_t &tensor_num) {
+  const int64_t infer_condition_one_one = 11;
+  const int64_t infer_condition_one_two = 12;
+  const int64_t infer_condition_two = 2;
+  const int64_t infer_condition_three = 3;
+
+  int64_t empty_num = 0;
+  int64_t static_num = 0;
+  int64_t dynamic_shape_num = 0;
+  int64_t dynamic_dim_num = 0;
+
+  // Count how many inputs are empty / contain -1 (unknown dim) / contain -2
+  // (unknown rank) / fully static.
+  for (int64_t i = 0; i < tensor_num; i++) {
+    vector<int64_t> tempVector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
+    if (tempVector.empty()) {
+      empty_num++;
+    } else if (std::find(tempVector.begin(), tempVector.end(), -1) != tempVector.end()) {
+      dynamic_shape_num++;
+    } else if (std::find(tempVector.begin(), tempVector.end(), -2) != tempVector.end()) {
+      dynamic_dim_num++;
+    } else {
+      static_num++;
+    }
+  }
+  if (tensor_num == empty_num + dynamic_dim_num) {
+    if (tensor_num == empty_num) {
+      return infer_condition_one_one;
+    } else {
+      return infer_condition_one_two;
+    }
+  } else if (tensor_num == static_num || tensor_num == empty_num + static_num ||
+             tensor_num == static_num + dynamic_dim_num ||
+             tensor_num == empty_num + static_num + dynamic_dim_num) {
+    return infer_condition_two;
+  } else {
+    return infer_condition_three;
+  }
+}
+
+IMPLEMT_COMMON_INFERFUNC(AddNInferShape) {
+  /*
+  add_n has four type inputs:
+  1.empty 2.static shape 3.-1 4.-2
+  The combinations bring 15 scenes, and the 15 scenes can be classify into 4 categories:
+  1.input with no range and output no need range, and it can be divided half:
+    1.1 all input is empty
+    1.2 input only contains empty and -2 shape
+  2.input contains static shape and with no -1 shape
+  3.input contains -1 shape
+  */
+  int64_t tensor_num = GetAddNConstValue(op);
+  int64_t infer_classify = AddNInferClassify(op, tensor_num);
+  // condition 1: all input shape is empty
+  if (infer_classify == 11) {
+    std::vector<int64_t> shape_vector = op.GetDynamicInputDesc("x", 0).GetShape().GetDims();
+    DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
+    TensorDesc y_desc = op.GetOutputDesc("y");
+    y_desc.SetShape(Shape(shape_vector));
+    y_desc.SetDataType(x_dtype);
+    (void)op.UpdateOutputDesc("y", y_desc);
+  // condition 2: all input is -2 or only empty and -2
+  } else if (infer_classify == 12) {
+    std::vector<int64_t> shape_vector = {-2};
+    DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
+    TensorDesc y_desc = op.GetOutputDesc("y");
+    y_desc.SetShape(Shape(shape_vector));
+    y_desc.SetDataType(x_dtype);
+    (void)op.UpdateOutputDesc("y", y_desc);
+  // condition 3: contains static shape and no -1 shape
+  } else if (infer_classify == 2) {
+    DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
+    std::vector<int64_t> shape_vector = op.GetDynamicInputDesc("x", 0).GetShape().GetDims();
+    for (int64_t i = 0; i < tensor_num; i++) {
+      std::vector<int64_t> temp_vector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
+      // Fix: test the candidate input's shape (temp_vector), not the running
+      // shape_vector, so the first static shape is found even when input 0 is
+      // empty or unknown-rank.
+      if (!temp_vector.empty() && !IsUnknownRankShape(temp_vector)) {
+        shape_vector = temp_vector;
+        break;
+      }
+    }
+    TensorDesc y_desc = op.GetOutputDesc("y");
+    y_desc.SetShape(ge::Shape(shape_vector));
+    y_desc.SetDataType(x_dtype);
+    std::vector<std::pair<int64_t, int64_t>> out_range;
+    MakeUpShapeRange(shape_vector, out_range);
+    y_desc.SetShapeRange(out_range);
+    (void)op.UpdateOutputDesc("y", y_desc);
+  // condition 4: contains -1 shape, range need to choose the intersection
+  } else {
+    Shape out_shape = op.GetDynamicInputDesc("x", 0).GetShape();
+    DataType x_dtype = op.GetDynamicInputDesc("x", 0).GetDataType();
+    std::vector<int64_t> out_vector;
+    std::vector<std::pair<int64_t, int64_t>> out_range;
+    // Init the output shape and range from the first ranked input.
+    for (int64_t i = 0; i < tensor_num; i++) {
+      std::vector<int64_t> temp_vector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
+      if (!temp_vector.empty() && !IsUnknownRankShape(temp_vector)) {
+        out_vector = temp_vector;
+        op.GetDynamicInputDesc("x", i).GetShapeRange(out_range);
+        MakeUpShapeRange(out_vector, out_range);
+        break;
+      }
+    }
+    // compute the shape dims and range intersection
+    for (int64_t i = 0; i < tensor_num; i++) {
+      std::vector<int64_t> temp_vector = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
+      if (temp_vector.empty() || IsUnknownRankShape(temp_vector)) {
+        continue;
+      }
+      std::vector<std::pair<int64_t, int64_t>> temp_range;
+      op.GetDynamicInputDesc("x", i).GetShapeRange(temp_range);
+      MakeUpShapeRange(temp_vector, temp_range);
+      for (size_t j = 0; j < temp_vector.size(); j++) {
+        // two condition: const == const; const > -1
+        if (temp_vector[j] >= out_vector[j]) {
+          out_vector[j] = temp_vector[j];
+          // update range: lower bound takes the larger value
+          if (temp_range[j].first >= out_range[j].first) {
+            out_range[j].first = temp_range[j].first;
+          }
+          // update range: upper bound takes the smaller value, but only when it is > 0
+          if ((temp_range[j].second <= out_range[j].second && temp_range[j].second > 0) ||
+              (out_range[j].second == -1 && temp_range[j].second != -1)) {
+            out_range[j].second = temp_range[j].second;
+          }
+        }
+      }
+    }
+    TensorDesc y_desc = op.GetOutputDesc("y");
+    out_shape = Shape(out_vector);
+    y_desc.SetShape(out_shape);
+    y_desc.SetDataType(x_dtype);
+    y_desc.SetShapeRange(out_range);
+    (void)op.UpdateOutputDesc("y", y_desc);
+  }
+  return GRAPH_SUCCESS;
+}
+
+COMMON_INFER_FUNC_REG(AddN, AddNInferShape);
+// ----------------AddN END-------------------
+
+// ----------------AssignAdd-------------------
+// ref and value must share a dtype; output aliases "ref" with no broadcast.
+IMPLEMT_VERIFIER(AssignAdd, AssignAddVerify) {
+  if (!CheckTwoInputDtypeSame(op, "ref", "value")) {
+    return GRAPH_FAILED;
+  }
+  return GRAPH_SUCCESS;
+}
+
+IMPLEMT_COMMON_INFERFUNC(AssignAddInferShape) {
+  if (TwoInOneOutDynamicInferNoBroadcast(op, "ref", "value", {"ref"})) {
+    return GRAPH_SUCCESS;
+  }
+  return GRAPH_FAILED;
+}
+
+COMMON_INFER_FUNC_REG(AssignAdd, AssignAddInferShape);
+VERIFY_FUNC_REG(AssignAdd, AssignAddVerify);
+// ----------------AssignAdd END-------------------
+
+// ----------------AssignSub-------------------
+// var and value must share a dtype.
+IMPLEMT_VERIFIER(AssignSub, AssignSubVerify) {
+  if (!CheckTwoInputDtypeSame(op, "var", "value")) {
+    return GRAPH_FAILED;
+  }
+  return GRAPH_SUCCESS;
+}
+
+IMPLEMT_COMMON_INFERFUNC(AssignSubInferShape) { + if (TwoInOneOutDynamicInferNoBroadcast(op, "var", "value", {"var"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(AssignSub, AssignSubInferShape); +VERIFY_FUNC_REG(AssignSub, AssignSubVerify); +// ----------------AssignSub END------------------- + +// ----------------Atanh------------------- +IMPLEMT_COMMON_INFERFUNC(AtanhInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Atanh, AtanhInferShape); +// --------------Atanh END----------------- + +// ----------------Atan-------------------- +IMPLEMT_COMMON_INFERFUNC(AtanInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Atan, AtanInferShape); +// --------------Atan END----------------- + +// ----------------Atan2------------------- +IMPLEMT_VERIFIER(Atan2, Atan2Verify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(Atan2InferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + + +COMMON_INFER_FUNC_REG(Atan2, Atan2InferShape); +VERIFY_FUNC_REG(Atan2, Atan2Verify); +// --------------Atan2 END----------------- + +// --------------AcosGrad---------------- +IMPLEMT_VERIFIER(AcosGrad, AcosGradVerify) { + if (!CheckTwoInputDtypeSame(op, "y", "dy")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +VERIFY_FUNC_REG(AcosGrad, AcosGradVerify); + +IMPLEMT_COMMON_INFERFUNC(AcosGradInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(AcosGrad, AcosGradInferShape); +// ------------AcosGrad 
END---------------- + +// ----------------AcoshGrad------------------- +IMPLEMT_VERIFIER(AcoshGrad, AcoshGradVerify) { + if (!CheckTwoInputDtypeSame(op, "y", "dy")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +VERIFY_FUNC_REG(AcoshGrad, AcoshGradVerify); + +IMPLEMT_COMMON_INFERFUNC(AcoshGradInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(AcoshGrad, AcoshGradInferShape); +// --------------AcoshGrad END----------------- + +// ----------------AtanGrad------------------- +IMPLEMT_COMMON_INFERFUNC(AtanGradInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(AtanGrad, AtanGradInferShape); +// --------------AtanGrad END----------------- + +// -------------------ApproximateEqual---------------------- +IMPLEMT_VERIFIER(ApproximateEqual, ApproximateEqualVerify) { + float tolerance_data; + if (ge::GRAPH_SUCCESS != op.GetAttr("tolerance", tolerance_data)) { + OpsGetAttrErrReport(op.GetName(), "tolerance"); + OP_LOGE(op.GetName().c_str(), "GetOpAttr failed of ApproximateEqual!"); + return GRAPH_FAILED; + } + if (tolerance_data < 0) { + OpsAttrValueErrReport(op.GetName(), "tolerance", ">= 0", ConcatString(tolerance_data)); + OP_LOGE(op.GetName().c_str(), "tolerance should >= 0!"); + return GRAPH_FAILED; + } + + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(ApproximateEqualInferShape) { + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + tensordesc_output.SetShape(op.GetInputDesc("x1").GetShape()); + tensordesc_output.SetDataType(DT_BOOL); + (void)op.UpdateOutputDesc("y", tensordesc_output); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(ApproximateEqual, 
ApproximateEqualInferShape); +VERIFY_FUNC_REG(ApproximateEqual, ApproximateEqualVerify); +// -------------------ApproximateEqual------------------------- + +// --------------------AccumulateNV2-------------------------- +bool CheckInputSize(const Operator& op) { + OP_LOGI(op.GetName().c_str(), "The op begin verify"); + auto input_size = op.GetInputsSize(); + if (input_size == 0) { + OpsMissInputErrReport(op.GetName(), "x"); + OP_LOGE(op.GetName().c_str(), "The op input size is zero"); + return false; + } + return true; +} + +bool CheckDynamicInputDtype(const Operator& op, const string& input_name1) { + DataType first_input_dtype = op.GetDynamicInputDesc(input_name1, 0).GetDataType(); + auto input_dynamic_size = op.GetInputsSize(); + for (size_t i = 0; i < input_dynamic_size; ++i) { + DataType input_dtype = op.GetDynamicInputDesc(input_name1, i).GetDataType(); + if (first_input_dtype != input_dtype) { + OpsInputDtypeErrReport(op.GetName(), "x", ConcatString(first_input_dtype), ConcatString(input_dtype)); + OP_LOGE(op.GetName().c_str(), + "the op type is not same," + "type1:%d,type2:%d", + input_dtype, first_input_dtype); + return false; + } + } + return true; +} + +IMPLEMT_VERIFIER(AccumulateNV2, AccumulateNV2Verify) { + if (CheckInputSize(op) == false) { + return GRAPH_FAILED; + } + if (CheckDynamicInputDtype(op, "x") == false) { + return GRAPH_FAILED; + } + int64_t num; + if (GRAPH_SUCCESS != op.GetAttr("N", num)) { + OpsGetAttrErrReport(op.GetName(), "N"); + OP_LOGE(op.GetName().c_str(), "GetAttr of N failed."); + return GRAPH_FAILED; + } else { + if (op.GetInputsSize() != static_cast(num)) { + OpsInputShapeErrReport(op.GetName(), "input size and N must be same.", "N", + ConcatString(static_cast(num))); + OP_LOGE(op.GetName().c_str(), "input size and N must be same."); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +int64_t GetAccumulateNV2ConstValue(const ge::Operator& op) { + int64_t tensor_num; + if (ge::GRAPH_SUCCESS != op.GetAttr("N", 
tensor_num)) { + OpsGetAttrErrReport(op.GetName(), "N"); + OP_LOGE(op.GetName().c_str(), "The add_n op GetOpAttr failed!"); + } + return tensor_num; +} + +IMPLEMT_COMMON_INFERFUNC(AccumulateNV2InferShape) { + /* + Accumulate_nv2 has four type inputs: + 1.empty 2.static shape 3.-1 4.-2 + The combinations bring 15 scenes, and the 15 scenes can be classify into 4 categories: + 1.input with no range and output no need range, and it can be divided half: + 1.1 all input is empty + 1.2 input only contains empty and -2 shape + 2.input contains static shape and with no -1 shape + 3.input contains -1 shape + */ + const int64_t infer_condition_one_one = 11; + const int64_t infer_condition_one_two = 12; + const int64_t infer_condition_two = 2; + const int64_t infer_condition_three = 3; + + int64_t empty_num = 0; + int64_t static_num = 0; + int64_t dynamic_shape_num = 0; + int64_t dynamic_dim_num = 0; + int64_t infer_classify = 0; + auto op_info = OpDescUtils::GetOpDescFromOperator(op); + int64_t tensor_num = GetAccumulateNV2ConstValue(op); + + for (uint32_t i = 0; i < tensor_num; i++) { + auto input_desc = op_info->MutableInputDesc(i); + vector tempVector = input_desc->MutableShape().GetDims(); + if (tempVector.empty()) { + empty_num++; + } else if (std::find(tempVector.begin(), tempVector.end(), -1) != tempVector.end()) { + dynamic_shape_num++; + } else if (std::find(tempVector.begin(), tempVector.end(), -2) != tempVector.end()) { + dynamic_dim_num++; + } else { + static_num++; + } + } + if (tensor_num == empty_num + dynamic_dim_num) { + if (tensor_num == empty_num) { + infer_classify = infer_condition_one_one; + } else { + infer_classify = infer_condition_one_two; + } + } else if (tensor_num == static_num || tensor_num == empty_num + static_num || tensor_num == static_num + + dynamic_dim_num || tensor_num == empty_num + static_num + dynamic_dim_num) { + infer_classify = infer_condition_two; + } else { + infer_classify = infer_condition_three; + } + + // condition 1: all 
input shape is empty + if (infer_classify == 11) { + auto input_desc = op_info->MutableInputDesc(0); + std::vector shape_vector = input_desc->MutableShape().GetDims(); + DataType x_dtype = input_desc->GetDataType(); + auto y_desc = op_info->MutableOutputDesc("y"); + y_desc->SetShape(GeShape(shape_vector)); + y_desc->SetDataType(x_dtype); + } else if (infer_classify == 12) { + auto input_desc = op_info->MutableInputDesc( 0); + std::vector shape_vector = {-2}; + DataType x_dtype = input_desc->GetDataType(); + auto y_desc = op_info->MutableOutputDesc("y"); + y_desc->SetShape(GeShape(shape_vector)); + y_desc->SetDataType(x_dtype); + } else if (infer_classify == 2) { + auto input_desc = op_info->MutableInputDesc(0); + std::vector shape_vector = input_desc->MutableShape().GetDims(); + DataType x_dtype = input_desc->GetDataType(); + for (int64_t i = 0; i < tensor_num; i++) { + auto input_desc = op_info->MutableInputDesc(i); + std::vector temp_vector = input_desc->MutableShape().GetDims(); + if (!shape_vector.empty() && !IsUnknownRankShape(shape_vector)) { + shape_vector = temp_vector; + break; + } + } + auto y_desc = op_info->MutableOutputDesc("y"); + y_desc->SetShape(GeShape(shape_vector)); + y_desc->SetDataType(x_dtype); + std::vector> out_range; + MakeUpShapeRange(shape_vector, out_range); + y_desc->SetShapeRange(out_range); + } else { + auto input_desc = op_info->MutableInputDesc(0); + std::vector out_shape = input_desc->MutableShape().GetDims(); + DataType x_dtype = input_desc->GetDataType(); + std::vector out_vector; + std::vector> out_range; + // Init the output shape and range + for (int64_t i = 0; i < tensor_num; i++) { + auto input_desc = op_info->MutableInputDesc(i); + + std::vector temp_vector = input_desc->MutableShape().GetDims(); + if (!temp_vector.empty() && !IsUnknownRankShape(temp_vector)) { + out_vector = temp_vector; + input_desc->GetShapeRange(out_range); + MakeUpShapeRange(out_vector, out_range); + break; + } + } + // compute the shape dims and range 
intersection + for (int64_t i = 0; i < tensor_num; i++) { + auto input_desc = op_info->MutableInputDesc(i); + std::vector temp_vector = input_desc->MutableShape().GetDims(); + if (temp_vector.empty() || IsUnknownRankShape(temp_vector)) { + continue; + } + std::vector> temp_range; + input_desc->GetShapeRange(temp_range); + MakeUpShapeRange(temp_vector, temp_range); + for (size_t j = 0; j < temp_vector.size(); j++) { + // two condition: const == const; const > -1 + if (temp_vector[j] >= out_vector[j]) { + out_vector[j] = temp_vector[j]; + // update range: left choose the max value + if (temp_range[j].first >= out_range[j].first) { + out_range[j].first = temp_range[j].first; + } + // update range: right choose the miner value but when it was > 0 + if ((temp_range[j].second <= out_range[j].second && temp_range[j].second > 0) || + (out_range[j].second == -1 && temp_range[j].second != -1)) { + out_range[j].second = temp_range[j].second; + } + } + } + } + auto y_desc = op_info->MutableOutputDesc("y"); + y_desc->SetShape(GeShape(out_shape)); + y_desc->SetShapeRange(out_range); + y_desc->SetDataType(x_dtype); + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(AccumulateNV2, AccumulateNV2InferShape); +VERIFY_FUNC_REG(AccumulateNV2, AccumulateNV2Verify); +// --------------------AccumulateNV2 END----------------------- + +// -------------------Greater------------------- +IMPLEMT_VERIFIER(Greater, GreaterVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(GreaterInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto vec_y = op_desc->MutableOutputDesc("y")->MutableShape().GetDims(); + op_desc->MutableOutputDesc("y")->SetDataType(DT_BOOL); + if (IsUnknownRankShape(vec_y) || IsUnknownVec(vec_y)) { + if (!InferShapeRangeTwoInOneOutBroadcase(op, "x1", "x2", "y")) { + 
return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Greater, GreaterInferShape); +VERIFY_FUNC_REG(Greater, GreaterVerify); +// --------------------Greater END--------------------- + +// --------------------ZerosLike---------------- +COMMON_INFER_FUNC_REG(ZerosLike, OneInOneOutCommonInferShape); +// ----------------ZerosLike END----------------- + +// ----------------LogicalNot------------------- +IMPLEMT_COMMON_INFERFUNC(LogicalNotInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(LogicalNot, LogicalNotInferShape); +// --------------LogicalNot END----------------- + +// ----------------------LogicalAnd-------------------------- +IMPLEMT_COMMON_INFERFUNC(LogicalAndInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + if (!InferShapeRangeTwoInOneOutBroadcase(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(LogicalAnd, LogicalAndInferShape); +// ---------------------LogicalAnd END--------------------- + +// ----------------FakeQuantWithMinMaxVarsPerChannel---------------------------- +IMPLEMT_VERIFIER(FakeQuantWithMinMaxVarsPerChannel, FakeQuantWithMinMaxVarsPerChannelVerify) { + int64_t num_bits; + if (ge::GRAPH_SUCCESS != op.GetAttr("num_bits", num_bits)) { + OpsGetAttrErrReport(op.GetName(), "num_bits"); + LOG_ERROR("[ERROR]op [%s] Attr num_bits is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + bool narrow_range; + if (ge::GRAPH_SUCCESS != op.GetAttr("narrow_range", narrow_range)) { + OpsGetAttrErrReport(op.GetName(), "narrow_range"); + LOG_ERROR("[ERROR]op [%s] Attr narrow_range is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + if (num_bits < 2 || num_bits > 16) { + OpsAttrValueErrReport(op.GetName(), "num_bits", "between 2 and 16", ConcatString(num_bits)); + LOG_ERROR("[ERROR]op [%s] num_bits is 
between 2 and 16\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + Shape shape_x = op.GetInputDesc("x").GetShape(); + Shape shape_min = op.GetInputDesc("min").GetShape(); + Shape shape_max = op.GetInputDesc("max").GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_min = shape_min.GetDims(); + std::vector dims_max = shape_max.GetDims(); + if (dims_x.size() < 1) { + OpsAttrValueErrReport(op.GetName(), "x'shape", "equal or greater than 1", ConcatString(dims_x.size())); + OP_LOGE(op.GetName().c_str(), "shape of x must greater 1"); + return GRAPH_FAILED; + } + if ((dims_min.size() != 1) || (dims_max.size() != 1)) { + string input_value = ConcatString("[", dims_min.size(), "] and [", dims_max.size(), "]"); + OpsAttrValueErrReport(op.GetName(), "min's and max's shape", "rank 1", input_value); + OP_LOGE(op.GetName().c_str(), "shape of min and max must be rank 1"); + return GRAPH_FAILED; + } + if (dims_min[0] != dims_max[0]) { + string excepted_value = ConcatString("same as max[", dims_max[0], "]"); + OpsAttrValueErrReport(op.GetName(), "min'shape", excepted_value, ConcatString(dims_min[0])); + OP_LOGE(op.GetName().c_str(), "shape of min and max must be same"); + return GRAPH_FAILED; + } + if (dims_x[dims_x.size() - 1] != dims_min[0]) { + string excepted_value = ConcatString("same as min[", dims_min[0], "]"); + OpsAttrValueErrReport(op.GetName(), "x'last dimension", excepted_value, ConcatString(dims_x[dims_x.size() - 1])); + OP_LOGE(op.GetName().c_str(), + "The last dimension of x must" + " be the same as min"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(FakeQuantWithMinMaxVarsPerChannelInferShape) { + Shape shape_input = op.GetInputDesc("x").GetShape(); + DataType dtype_input = op.GetInputDesc("x").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + tensordesc_output.SetShape(shape_input); + tensordesc_output.SetDataType(dtype_input); + (void)op.UpdateOutputDesc("y", tensordesc_output); 
+ return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FakeQuantWithMinMaxVarsPerChannel, FakeQuantWithMinMaxVarsPerChannelInferShape); +VERIFY_FUNC_REG(FakeQuantWithMinMaxVarsPerChannel, FakeQuantWithMinMaxVarsPerChannelVerify); +// ----------------FakeQuantWithMinMaxVarsPerChannel---------------------------- + +// ----------------Rint----------------------------- +COMMON_INFER_FUNC_REG(Rint, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); +// ----------------Rint END------------------------- + +// --------------------------------BiasAdd------------------------------------- +IMPLEMT_VERIFIER(BiasAdd, BiasAddVerify) { + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(BiasAddInferShape) { + if (!OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_FAILED; + } + std::string data_format; + if (op.GetAttr("data_format", data_format) == GRAPH_FAILED) { + OpsGetAttrErrReport(op.GetName(), "data_format"); + OP_LOGE(op.GetName().c_str(), "get attr N failed"); + } + if (data_format != "NHWC" && data_format != "NCHW" && data_format != "NDHWC" && data_format != "NCDHW") { + string expected_format_list = ConcatString("NHWC, NCHW, NDHWC, NCDHW"); + OpsInputFormatErrReport(op.GetName(), "data_format", expected_format_list, data_format); + OP_LOGE(op.GetName().c_str(), + "data_format only " + "support 'NHWC', 'NCHW', 'NDHWC' and 'NCDHW'."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(BiasAdd, BiasAddInferShape); +VERIFY_FUNC_REG(BiasAdd, BiasAddVerify); +// ----------------------------------BiasAdd END----------------------------- + +// -------------------BitwiseAnd---------------------------- +IMPLEMT_COMMON_INFERFUNC(BitwiseAndInferShape) { + if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y") == false) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(BitwiseAnd, BitwiseAndInferShape); +// ----------------BitwiseAnd END-------------------------- + +// 
---------------------BitwiseOr---------------------------- +IMPLEMT_COMMON_INFERFUNC(BitwiseOrInferShape) { + if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y") == false) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(BitwiseOr, BitwiseOrInferShape); +// --------------------BitwiseOr END------------------------ + +// -----------------------BitwiseXor------------------------- +IMPLEMT_COMMON_INFERFUNC(BitwiseXorInferShape) { + if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y") == false) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(BitwiseXor, BitwiseXorInferShape); +// ------------------BitwiseXor END------------------------- + +// ----------------FakeQuantWithMinMaxArgs------------------ +IMPLEMT_VERIFIER(FakeQuantWithMinMaxArgs, FakeQuantWithMinMaxArgsVerify) { + float min; + if (GetConstValue(op, "min", min) == false) { + LOG_ERROR("[ERROR]op [%s] Attr min is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + float max; + if (GetConstValue(op, "max", max) == false) { + LOG_ERROR("[ERROR]op [%s] Attr max is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + int64_t numBits; + if (GetConstValue(op, "num_bits", numBits) == false) { + LOG_ERROR("[ERROR]op [%s] Attr num_bits is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + bool narrow_range; + if (GetConstValue(op, "narrow_range", narrow_range) == false) { + LOG_ERROR("[ERROR]op [%s] Attr narrow_range is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + if (min >= max) { + string excepted_value = ConcatString("less than max[", max, "]"); + OpsAttrValueErrReport(op.GetName(), "min", excepted_value, ConcatString(min)); + LOG_ERROR("[ERROR]op [%s] min must be less than max !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + if (numBits < 2 || numBits > 16) { + OpsAttrValueErrReport(op.GetName(), "numBits", "between 2 and 16", ConcatString(numBits)); + 
LOG_ERROR("[ERROR]op [%s] numBits is between 2 and 16\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(FakeQuantWithMinMaxArgsInferShape) { + Shape shape = op.GetInputDesc("x").GetShape(); + DataType input_dtype = op.GetInputDesc("x").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + tensordesc_output.SetShape(shape); + tensordesc_output.SetDataType(input_dtype); + (void)op.UpdateOutputDesc("y", tensordesc_output); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FakeQuantWithMinMaxArgs, FakeQuantWithMinMaxArgsInferShape); +VERIFY_FUNC_REG(FakeQuantWithMinMaxArgs, FakeQuantWithMinMaxArgsVerify); +// ----------------FakeQuantWithMinMaxArgs END---------------------- + +// ----------------FakeQuantWithMinMaxArgsGradient----------------- +IMPLEMT_VERIFIER(FakeQuantWithMinMaxArgsGradient, FakeQuantWithMinMaxArgsGradientVerify) { + float min; + if (GetConstValue(op, "min", min) == false) { + LOG_ERROR("[ERROR]op [%s] Attr min is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + float max; + if (GetConstValue(op, "max", max) == false) { + LOG_ERROR("[ERROR]op [%s] Attr max is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + int64_t num_bits; + if (GetConstValue(op, "num_bits", num_bits) == false) { + LOG_ERROR("[ERROR]op [%s] Attr num_bits is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + bool narrow_range; + if (GetConstValue(op, "narrow_range", narrow_range) == false) { + LOG_ERROR("[ERROR]op [%s] Attr narrow_range is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + if (min >= max) { + string excepted_value = ConcatString("less than max[", max, "]"); + OpsAttrValueErrReport(op.GetName(), "min", excepted_value, ConcatString(min)); + LOG_ERROR("[ERROR]op [%s] min must be less than max !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + if (num_bits < 2 || num_bits > 16) { + OpsAttrValueErrReport(op.GetName(), 
"num_bits", "between 2 and 16", ConcatString(num_bits)); + LOG_ERROR("[ERROR]op [%s] num_bits is between 2 and 16\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + if (!CheckTwoInputDtypeSame(op, "x", "gradients")) { + return GRAPH_FAILED; + } + Shape shape_x = op.GetInputDesc("x").GetShape(); + Shape shape_y = op.GetInputDesc("gradients").GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_y = shape_y.GetDims(); + if (dims_x.size() != dims_y.size()) { + string excepted_value = ConcatString("same as gradients[", dims_y.size(), "]"); + OpsAttrValueErrReport(op.GetName(), "x'shape", excepted_value, ConcatString(dims_x.size())); + OP_LOGE(op.GetName().c_str(), "two input shape not same"); + return GRAPH_FAILED; + } else { + for (size_t i = 0; i < dims_x.size(); i++) { + if (dims_x[i] != dims_y[i]) { + string excepted_value = ConcatString("same as gradients[", dims_y[i], "]"); + OpsAttrValueErrReport(op.GetName(), "x'shape", excepted_value, ConcatString(dims_x[i])); + OP_LOGE(op.GetName().c_str(), "two input shape not same"); + return GRAPH_FAILED; + } + } + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(FakeQuantWithMinMaxArgsGradientInferShape) { + Shape shape_x = op.GetInputDesc("x").GetShape(); + DataType input_dtype = op.GetInputDesc("x").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + tensordesc_output.SetShape(shape_x); + tensordesc_output.SetDataType(input_dtype); + (void)op.UpdateOutputDesc("y", tensordesc_output); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FakeQuantWithMinMaxArgsGradient, FakeQuantWithMinMaxArgsGradientInferShape); +VERIFY_FUNC_REG(FakeQuantWithMinMaxArgsGradient, FakeQuantWithMinMaxArgsGradientVerify); +// ----------------FakeQuantWithMinMaxArgsGradient------------------- + +// ----------------FakeQuantWithMinMaxVars--------------------------- +IMPLEMT_VERIFIER(FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsVerify) { + int64_t num_bits; + if (GetConstValue(op, 
"num_bits", num_bits) == false) { + LOG_ERROR("[ERROR]op [%s] Attr num_bits is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + bool narrow_range; + if (GetConstValue(op, "narrow_range", narrow_range) == false) { + LOG_ERROR("[ERROR]op [%s] Attr narrow_range is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + if (!CheckTwoInputDtypeSame(op, "x", "min")) { + return GRAPH_FAILED; + } + if (!CheckTwoInputDtypeSame(op, "min", "max")) { + return GRAPH_FAILED; + } + if (num_bits < 2 || num_bits > 16) { + OpsAttrValueErrReport(op.GetName(), "num_bits", "between 2 and 16", ConcatString(num_bits)); + LOG_ERROR("[ERROR]op [%s] num_bits is between 2 and 16\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(FakeQuantWithMinMaxVarsInferShape) { + Shape shape = op.GetInputDesc("x").GetShape(); + DataType input_dtype = op.GetInputDesc("x").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + tensordesc_output.SetShape(shape); + tensordesc_output.SetDataType(input_dtype); + (void)op.UpdateOutputDesc("y", tensordesc_output); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsInferShape); +VERIFY_FUNC_REG(FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsVerify); +// ----------------FakeQuantWithMinMaxVars-------------------------------------- + +// ----------------FakeQuantWithMinMaxVarsGradient------------------------------ +IMPLEMT_VERIFIER(FakeQuantWithMinMaxVarsGradient, FakeQuantWithMinMaxVarsGradientVerify) { + int64_t num_bits; + if (GetConstValue(op, "num_bits", num_bits) == false) { + LOG_ERROR("[ERROR]op [%s] Attr num_bits is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + bool narrow_range; + if (GetConstValue(op, "narrow_range", narrow_range) == false) { + LOG_ERROR("[ERROR]op [%s] Attr narrow_range is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + if (!CheckTwoInputDtypeSame(op, "x", 
"gradients")) { + return GRAPH_FAILED; + } + if (!CheckTwoInputDtypeSame(op, "min", "max")) { + return GRAPH_FAILED; + } + if (!CheckTwoInputDtypeSame(op, "max", "gradients")) { + return GRAPH_FAILED; + } + if (num_bits < 2 || num_bits > 16) { + OpsAttrValueErrReport(op.GetName(), "num_bits", "between 2 and 16", ConcatString(num_bits)); + LOG_ERROR("[ERROR]op [%s] num_bits is between 2 and 16\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + Shape shape_x = op.GetInputDesc("x").GetShape(); + Shape shape_y = op.GetInputDesc("gradients").GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_y = shape_y.GetDims(); + if (dims_x.size() != dims_y.size()) { + string excepted_value = ConcatString("same as gradients[", dims_y.size(), "]"); + OpsAttrValueErrReport(op.GetName(), "x'shape", excepted_value, ConcatString(dims_x.size())); + OP_LOGE(op.GetName().c_str(), "two input shape not same"); + return GRAPH_FAILED; + } else { + for (size_t i = 0; i < dims_x.size(); i++) { + if (dims_x[i] != dims_y[i]) { + string excepted_value = ConcatString("same as gradients[", dims_y[i], "]"); + OpsAttrValueErrReport(op.GetName(), "x'shape", excepted_value, ConcatString(dims_x[i])); + OP_LOGE(op.GetName().c_str(), "two input shape not same"); + return GRAPH_FAILED; + } + } + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(FakeQuantWithMinMaxVarsGradientInferShape) { + Shape shape_input_gradients = op.GetInputDesc("gradients").GetShape(); + Shape shape_input_min = op.GetInputDesc("min").GetShape(); + Shape shape_input_max = op.GetInputDesc("max").GetShape(); + DataType dtype_input_gradients = op.GetInputDesc("gradients").GetDataType(); + DataType dtype_input_min = op.GetInputDesc("min").GetDataType(); + DataType dtype_input_max = op.GetInputDesc("max").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("backprops_wrt_x"); + TensorDesc tensordesc_output_min = op.GetOutputDesc("backprops_wrt_min"); + TensorDesc tensordesc_output_max = 
op.GetOutputDesc("backprops_wrt_max"); + tensordesc_output.SetShape(shape_input_gradients); + tensordesc_output_min.SetShape(shape_input_min); + tensordesc_output_max.SetShape(shape_input_max); + tensordesc_output.SetDataType(dtype_input_gradients); + tensordesc_output_min.SetDataType(dtype_input_min); + tensordesc_output_max.SetDataType(dtype_input_max); + (void)op.UpdateOutputDesc("backprops_wrt_x", tensordesc_output); + (void)op.UpdateOutputDesc("backprops_wrt_min", tensordesc_output_min); + (void)op.UpdateOutputDesc("backprops_wrt_max", tensordesc_output_max); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FakeQuantWithMinMaxVarsGradient, FakeQuantWithMinMaxVarsGradientInferShape); +VERIFY_FUNC_REG(FakeQuantWithMinMaxVarsGradient, FakeQuantWithMinMaxVarsGradientVerify); +// ----------------FakeQuantWithMinMaxVarsGradient END--------------------- + +// ----------------FakeQuantWithMinMaxVarsPerChannelGradient--------------- +IMPLEMT_VERIFIER(FakeQuantWithMinMaxVarsPerChannelGradient, FakeQuantWithMinMaxVarsPerChannelGradientVerify) { + int64_t num_bits; + if (GetConstValue(op, "num_bits", num_bits) == false) { + LOG_ERROR("[ERROR]op [%s] Attr num_bits is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + bool narrow_range; + if (GetConstValue(op, "narrow_range", narrow_range) == false) { + LOG_ERROR("[ERROR]op [%s] Attr narrow_range is empty !\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + if (!CheckTwoInputDtypeSame(op, "x", "gradients")) { + return GRAPH_FAILED; + } + if (!CheckTwoInputDtypeSame(op, "min", "max")) { + return GRAPH_FAILED; + } + if (!CheckTwoInputDtypeSame(op, "max", "gradients")) { + return GRAPH_FAILED; + } + if (num_bits < 2 || num_bits > 16) { + OpsAttrValueErrReport(op.GetName(), "num_bits", "between 2 and 16", ConcatString(num_bits)); + LOG_ERROR("[ERROR]op [%s] num_bits is between 2 and 16\n", op.GetName().c_str()); + return GRAPH_FAILED; + } + Shape shape_x = op.GetInputDesc("x").GetShape(); + Shape shape_min 
= op.GetInputDesc("min").GetShape(); + Shape shape_max = op.GetInputDesc("max").GetShape(); + Shape shape_y = op.GetInputDesc("gradients").GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_min = shape_min.GetDims(); + std::vector dims_max = shape_max.GetDims(); + std::vector dims_y = shape_y.GetDims(); + if (dims_x.size() != dims_y.size()) { + string excepted_value = ConcatString("same as gradients[", dims_y.size(), "]"); + OpsAttrValueErrReport(op.GetName(), "x'shape", excepted_value, ConcatString(dims_x.size())); + OP_LOGE(op.GetName().c_str(), "two input shape not same"); + return GRAPH_FAILED; + } else { + for (size_t i = 0; i < dims_x.size(); i++) { + if (dims_x[i] != dims_y[i]) { + string excepted_value = ConcatString("same as gradients[", dims_y[i], "]"); + OpsAttrValueErrReport(op.GetName(), "x'shape", excepted_value, ConcatString(dims_x[i])); + OP_LOGE(op.GetName().c_str(), "two input shape not same"); + return GRAPH_FAILED; + } + } + } + if ((dims_min.size() != 1) || (dims_max.size() != 1)) { + string input_value = ConcatString("[", dims_min.size(), "] and [", dims_max.size(), "]"); + OpsAttrValueErrReport(op.GetName(), "min's and max's shape", "rank 1", input_value); + OP_LOGE(op.GetName().c_str(), "shape of min and max must be rank 1"); + return GRAPH_FAILED; + } + if (dims_min[0] != dims_max[0]) { + string excepted_value = ConcatString("same as max[", dims_max[0], "]"); + OpsAttrValueErrReport(op.GetName(), "min'shape", excepted_value, ConcatString(dims_min[0])); + OP_LOGE(op.GetName().c_str(), "shape of min and max must be same"); + return GRAPH_FAILED; + } + if (dims_x[dims_x.size() - 1] != dims_min[0]) { + string excepted_value = ConcatString("same as min[", dims_min[0], "]"); + OpsAttrValueErrReport(op.GetName(), "x'last dimension", excepted_value, ConcatString(dims_x[dims_x.size() - 1])); + OP_LOGE(op.GetName().c_str(), + "The last dimension of x " + "must be the same as min"); + return GRAPH_FAILED; + } + return 
GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(FakeQuantWithMinMaxVarsPerChannelGradientInferShape) { + Shape shape_input_gradients = op.GetInputDesc("gradients").GetShape(); + Shape shape_input_min = op.GetInputDesc("min").GetShape(); + Shape shape_input_max = op.GetInputDesc("max").GetShape(); + DataType dtype_input_gradients = op.GetInputDesc("gradients").GetDataType(); + DataType dtype_input_min = op.GetInputDesc("min").GetDataType(); + DataType dtype_input_max = op.GetInputDesc("max").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("backprops_wrt_x"); + TensorDesc tensordesc_output_min = op.GetOutputDesc("backprops_wrt_min"); + TensorDesc tensordesc_output_max = op.GetOutputDesc("backprops_wrt_max"); + tensordesc_output.SetShape(shape_input_gradients); + tensordesc_output_min.SetShape(shape_input_min); + tensordesc_output_max.SetShape(shape_input_max); + tensordesc_output.SetDataType(dtype_input_gradients); + tensordesc_output_min.SetDataType(dtype_input_min); + tensordesc_output_max.SetDataType(dtype_input_max); + (void)op.UpdateOutputDesc("backprops_wrt_x", tensordesc_output); + (void)op.UpdateOutputDesc("backprops_wrt_min", tensordesc_output_min); + (void)op.UpdateOutputDesc("backprops_wrt_max", tensordesc_output_max); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FakeQuantWithMinMaxVarsPerChannelGradient, FakeQuantWithMinMaxVarsPerChannelGradientInferShape); +VERIFY_FUNC_REG(FakeQuantWithMinMaxVarsPerChannelGradient, FakeQuantWithMinMaxVarsPerChannelGradientVerify); +// ----------------FakeQuantWithMinMaxVarsPerChannelGradient-------------------- + +// -------------------FloorDiv----------------------- +IMPLEMT_VERIFIER(FloorDiv, FloorDivVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FloorDiv, TwoInOneOutCommonInferShape); +VERIFY_FUNC_REG(FloorDiv, FloorDivVerify); +// ----------------FloorDiv END------------------------ + +// 
------------------FloorMod-------------------------- +IMPLEMT_VERIFIER(FloorMod, FloorModVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FloorMod, TwoInOneOutCommonInferShape); +VERIFY_FUNC_REG(FloorMod, FloorModVerify); +// ----------------FloorMod END--------------------- + +// ---------------------Pow------------------------- +IMPLEMT_VERIFIER(Pow, PowVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(PowInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Pow, PowInferShape); +VERIFY_FUNC_REG(Pow, PowVerify); +// -------------------Pow END------------------------ + +// ----------------Round------------------------------------- +COMMON_INFER_FUNC_REG(Round, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); +// ----------------Round END--------------------------------- + +// ---------------------------------ArgMin-------------------------------------- +IMPLEMT_COMMON_INFERFUNC(ArgMinInferShape) { + // get all input desc + const vector depend_names = {"dimension"}; + PREPARE_DYNAMIC_SHAPE(depend_names); + auto node = NodeUtils::GetNodeFromOperator(op); + auto op_info = OpDescUtils::GetOpDescFromOperator(op); + auto input_desc = op_info->MutableInputDesc("x"); + auto const_desc = op_info->MutableInputDesc("dimension"); + auto y_desc = op_info->MutableOutputDesc("y"); + // get x shape + auto x_shape = input_desc->MutableShape().GetDims(); + + // get and set output dtype + ge::DataType dtype; + if (op.GetAttr("dtype", dtype) == GRAPH_SUCCESS) { + y_desc->SetDataType(dtype); + } else { + OP_LOGE(op.GetName().c_str(), "get attr dtype failed."); + return GRAPH_FAILED; + } + + // if x_shape == -2, set output -2 + if (IsUnknownRankShape(x_shape)) { + y_desc->SetShape(GeShape(x_shape)); + return 
GRAPH_SUCCESS; + } + + // if x_shape.size() < 2, set output scalar + if (x_shape.size() < 2) { + vector output_shape; + y_desc->SetShape(GeShape(output_shape)); + return GRAPH_SUCCESS; + } + + // read dimension const value + GeTensorPtr dimension_tensor = nullptr; + vector dimension_value; + if (GRAPH_SUCCESS == NodeUtils::GetInputConstData(node, "dimension", dimension_tensor)) { + auto const_dtype = const_desc->GetDataType(); + GetConstValue(op, dimension_tensor, const_dtype, dimension_value); + // verify dimension_value + if (dimension_value.size() != 1) { + OP_LOGE(op.GetName().c_str(), "The length of dimension value must be equal to 1, but got %d.", + dimension_value.size()); + return GRAPH_FAILED; + } + int64_t dimension = dimension_value[0] < 0 ? dimension_value[0] + x_shape.size() : dimension_value[0]; + if (dimension >= x_shape.size()) { + OP_LOGE(op.GetName().c_str(), + "The dimension value must be range at input shape size, but got dimension value %d, input shape size %d.", + dimension_value[0], x_shape.size()); + return GRAPH_FAILED; + } + + vector output_shape(x_shape); + output_shape.erase(output_shape.begin() + dimension); + y_desc->SetShape(GeShape(output_shape)); + + // when output is dynamic will update range + if (IsUnknown(output_shape)) { + std::vector> input_range; + input_desc->GetShapeRange(input_range); + MakeUpShapeRange(x_shape, input_range); + input_range.erase(input_range.begin() + dimension); + y_desc->SetShapeRange(input_range); + } + return GRAPH_SUCCESS; + } + + // dimension is not const, set all output is -1, range is [1, -1] + vector output_shape; + std::vector> output_range; + for (int64_t item = 0; item < (x_shape.size() - 1); ++item) { + output_shape.push_back(-1); + } + MakeUpShapeRange(output_shape, output_range); + y_desc->SetShape(GeShape(output_shape)); + y_desc->SetShapeRange(output_range); + + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(ArgMin, ArgMinInferShape); +// 
-------------------------------ArgMin---------------------------------------- + +// --------------------------------ArgMinD-------------------------------------- + +IMPLEMT_COMMON_INFERFUNC(ArgMinDInferShape) { + // get all input desc + auto node = NodeUtils::GetNodeFromOperator(op); + auto op_info = OpDescUtils::GetOpDescFromOperator(op); + auto input_desc = op_info->MutableInputDesc("x"); + auto y_desc = op_info->MutableOutputDesc("y"); + // get x shape + auto x_shape = input_desc->MutableShape().GetDims(); + + // set output dtype + y_desc->SetDataType(DT_INT32); + + // if x_shape == -2, set output -2 + if (IsUnknownRankShape(x_shape)) { + y_desc->SetShape(GeShape(x_shape)); + return GRAPH_SUCCESS; + } + + // if x_shape.size() < 2, set output scalar + if (x_shape.size() < 2) { + vector output_shape; + y_desc->SetShape(GeShape(output_shape)); + return GRAPH_SUCCESS; + } + + int64_t dimension; + if (GRAPH_SUCCESS != op.GetAttr("dimension", dimension)) { + OpsGetAttrErrReport(op.GetName(), "dimension"); + OP_LOGE(op.GetName().c_str(), "GetAttr dimension failed."); + return GRAPH_FAILED; + } + if (dimension < 0) { + dimension += x_shape.size(); + } + + vector output_shape(x_shape); + output_shape.erase(output_shape.begin() + dimension); + y_desc->SetShape(GeShape(output_shape)); + + // when output is dynamic will update range + if (IsUnknown(output_shape)) { + std::vector> input_range; + input_desc->GetShapeRange(input_range); + MakeUpShapeRange(x_shape, input_range); + input_range.erase(input_range.begin() + dimension); + y_desc->SetShapeRange(input_range); + } + + return GRAPH_SUCCESS; +} +COMMON_INFER_FUNC_REG(ArgMinD, ArgMinDInferShape); +// ------------------------------ArgMinD---------------------------------------- + +// -----------------------------ArgMax------------------------------------------ +IMPLEMT_COMMON_INFERFUNC(ArgMaxInferShape) { + // get all input desc + const vector depend_names = {"dimension"}; + PREPARE_DYNAMIC_SHAPE(depend_names); + auto node 
= NodeUtils::GetNodeFromOperator(op); + auto op_info = OpDescUtils::GetOpDescFromOperator(op); + auto input_desc = op_info->MutableInputDesc("x"); + auto const_desc = op_info->MutableInputDesc("dimension"); + auto y_desc = op_info->MutableOutputDesc("y"); + // get x shape + auto x_shape = input_desc->MutableShape().GetDims(); + + // get and set output dtype + ge::DataType dtype; + if (op.GetAttr("dtype", dtype) == GRAPH_SUCCESS) { + y_desc->SetDataType(dtype); + } else { + OP_LOGE(op.GetName().c_str(), "get attr dtype failed."); + return GRAPH_FAILED; + } + + // if x_shape == -2, set output -2 + if (IsUnknownRankShape(x_shape)) { + y_desc->SetShape(GeShape(x_shape)); + return GRAPH_SUCCESS; + } + + // if x_shape.size() < 2, set output scalar + if (x_shape.size() < 2) { + vector output_shape; + y_desc->SetShape(GeShape(output_shape)); + return GRAPH_SUCCESS; + } + + // read dimension const value + GeTensorPtr dimension_tensor = nullptr; + vector dimension_value; + if (GRAPH_SUCCESS == NodeUtils::GetInputConstData(node, "dimension", dimension_tensor)) { + auto const_dtype = const_desc->GetDataType(); + GetConstValue(op, dimension_tensor, const_dtype, dimension_value); + // verify dimension_value + if (dimension_value.size() != 1) { + OP_LOGE(op.GetName().c_str(), "The length of dimension value must be equal to 1, but got %d.", + dimension_value.size()); + return GRAPH_FAILED; + } + int64_t dimension = dimension_value[0] < 0 ? 
dimension_value[0] + x_shape.size() : dimension_value[0]; + if (dimension >= x_shape.size()) { + OP_LOGE(op.GetName().c_str(), + "The dimension value must be range at input shape size, but got dimension value %d, input shape size %d.", + dimension_value[0], x_shape.size()); + return GRAPH_FAILED; + } + + vector output_shape(x_shape); + output_shape.erase(output_shape.begin() + dimension); + y_desc->SetShape(GeShape(output_shape)); + + // when output is dynamic will update range + if (IsUnknown(output_shape)) { + std::vector> input_range; + input_desc->GetShapeRange(input_range); + MakeUpShapeRange(x_shape, input_range); + input_range.erase(input_range.begin() + dimension); + y_desc->SetShapeRange(input_range); + } + return GRAPH_SUCCESS; + } + + // dimension is not const, set all output is -1, range is [1, -1] + vector output_shape; + std::vector> output_range; + for (int64_t item = 0; item < (x_shape.size() - 1); ++item) { + output_shape.push_back(-1); + } + MakeUpShapeRange(output_shape, output_range); + y_desc->SetShape(GeShape(output_shape)); + y_desc->SetShapeRange(output_range); + + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(ArgMaxV2, ArgMaxInferShape); +// --------------------------ArgMax--------------------------------------------- + +// --------------------------------ArgMaxD-------------------------------------- +IMPLEMT_COMMON_INFERFUNC(ArgMaxDInferShape) { + // get all input desc + auto node = NodeUtils::GetNodeFromOperator(op); + auto op_info = OpDescUtils::GetOpDescFromOperator(op); + auto input_desc = op_info->MutableInputDesc("x"); + auto y_desc = op_info->MutableOutputDesc("y"); + // get x shape + auto x_shape = input_desc->MutableShape().GetDims(); + + // set output dtype + y_desc->SetDataType(DT_INT32); + + // if x_shape == -2, set output -2 + if (IsUnknownRankShape(x_shape)) { + y_desc->SetShape(GeShape(x_shape)); + return GRAPH_SUCCESS; + } + + // if x_shape.size() < 2, set output scalar + if (x_shape.size() < 2) { + vector output_shape; 
+ y_desc->SetShape(GeShape(output_shape)); + return GRAPH_SUCCESS; + } + + int64_t dimension; + if (GRAPH_SUCCESS != op.GetAttr("dimension", dimension)) { + OpsGetAttrErrReport(op.GetName(), "dimension"); + OP_LOGE(op.GetName().c_str(), "GetAttr dimension failed."); + return GRAPH_FAILED; + } + if (dimension < 0) { + dimension += x_shape.size(); + } + + vector output_shape(x_shape); + output_shape.erase(output_shape.begin() + dimension); + y_desc->SetShape(GeShape(output_shape)); + + // when output is dynamic will update range + if (IsUnknown(output_shape)) { + std::vector> input_range; + input_desc->GetShapeRange(input_range); + MakeUpShapeRange(x_shape, input_range); + input_range.erase(input_range.begin() + dimension); + y_desc->SetShapeRange(input_range); + } + + return GRAPH_SUCCESS; +} +COMMON_INFER_FUNC_REG(ArgMaxD, ArgMaxDInferShape); +// ------------------------------ArgMaxD---------------------------------------- + +// ----------------------------ArgMaxWithValue---------------------------------- +IMPLEMT_COMMON_INFERFUNC(ArgMaxWithValueInferShape) { + auto tensordesc = op.GetInputDesc("x"); + auto shape_x = tensordesc.GetShape(); + int64_t dimension; + if (GRAPH_SUCCESS != op.GetAttr("dimension", dimension)) { + OpsGetAttrErrReport(op.GetName(), "dimension"); + OP_LOGE(op.GetName().c_str(), "GetAttr dimension failed."); + return GRAPH_FAILED; + } + if (dimension < 0) { + dimension += shape_x.GetDimNum(); + } + auto dim_num = shape_x.GetDimNum(); + vector y_shape; + for (size_t i = 0; i < dim_num; ++i) { + y_shape.push_back(shape_x.GetDim(i)); + } + int64_t max_size = y_shape.size(); + if (max_size != 0) { + dimension = dimension % max_size; + } + OP_LOGI(op.GetName().c_str(), "the dimension is %d.", (int)dimension); + + bool keep_dims; + if (GRAPH_SUCCESS != op.GetAttr("keep_dims", keep_dims)) { + OpsGetAttrErrReport(op.GetName(), "keep_dims"); + OP_LOGE(op.GetName().c_str(), "GetAttr of keep_dims failed."); + return GRAPH_FAILED; + } + if (keep_dims) { + 
// If keepDims is true, current dimesion set to 1 + y_shape[dimension] = 1; + } else { + y_shape.erase(y_shape.begin() + dimension); + } + + Shape outputShape(y_shape); + DataType input_dtype = tensordesc.GetDataType(); + TensorDesc td = op.GetOutputDesc("indice"); + TensorDesc td2 = op.GetOutputDesc("values"); + td.SetShape(outputShape); + td2.SetShape(outputShape); + td.SetDataType(DT_INT32); + td2.SetDataType(input_dtype); + (void)op.UpdateOutputDesc("indice", td); + (void)op.UpdateOutputDesc("values", td2); + + return GRAPH_SUCCESS; +} +COMMON_INFER_FUNC_REG(ArgMaxWithValue, ArgMaxWithValueInferShape); +// -----------------------------ArgMaxWithValue--------------------------------- + +// ---------------------------ArgMinWithValue----------------------------------- +IMPLEMT_COMMON_INFERFUNC(ArgMinWithValueInferShape) { + auto tensordesc = op.GetInputDesc("x"); + auto shape_x = tensordesc.GetShape(); + int64_t dimension; + if (GRAPH_SUCCESS != op.GetAttr("dimension", dimension)) { + OpsGetAttrErrReport(op.GetName(), "dimension"); + OP_LOGE(op.GetName().c_str(), "GetAttr dimension failed."); + return GRAPH_FAILED; + } + if (dimension < 0) { + dimension += shape_x.GetDimNum(); + } + auto dim_num = shape_x.GetDimNum(); + vector y_shape; + for (size_t i = 0; i < dim_num; ++i) { + y_shape.push_back(shape_x.GetDim(i)); + } + int64_t max_size = y_shape.size(); + dimension = dimension % max_size; + OP_LOGI(op.GetName().c_str(), "the dimension is %d.", (int)dimension); + + bool keep_dims; + if (GRAPH_SUCCESS != op.GetAttr("keep_dims", keep_dims)) { + OpsGetAttrErrReport(op.GetName(), "keep_dims"); + OP_LOGE(op.GetName().c_str(), "GetAttr of keep_dims failed."); + return GRAPH_FAILED; + } + if (keep_dims) { + // If keepDims is true, current dimesion set to 1 + y_shape[dimension] = 1; + } else { + y_shape.erase(y_shape.begin() + dimension); + } + + Shape outputShape(y_shape); + DataType input_dtype = tensordesc.GetDataType(); + TensorDesc td = op.GetOutputDesc("indice"); + 
TensorDesc td2 = op.GetOutputDesc("values"); + td.SetShape(outputShape); + td2.SetShape(outputShape); + td.SetDataType(DT_INT32); + td2.SetDataType(input_dtype); + (void)op.UpdateOutputDesc("indice", td); + (void)op.UpdateOutputDesc("values", td2); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(ArgMinWithValue, ArgMinWithValueInferShape); +// ---------------------------ArgMinWithValue----------------------------------- + +// ----------------Eltwise------------------- +IMPLEMT_COMMON_INFERFUNC(EltwiseInferShape) { + uint32_t first_input_index = 0; + TensorDesc td = op.GetDynamicInputDesc("x", first_input_index); + auto x_shape = td.GetShape().GetDims(); + auto x_dtype = td.GetDataType(); + TensorDesc td1 = op.GetOutputDesc("y"); + td1.SetShape(ge::Shape(x_shape)); + td1.SetDataType((DataType)x_dtype); + (void)op.UpdateOutputDesc("y", td1); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Eltwise, EltwiseInferShape); +// ----------------Eltwise END------------------- + +// ------------PopulationCount---------------- +IMPLEMT_COMMON_INFERFUNC(PopulationCountInferShape) { + Shape shape = op.GetInputDesc("x").GetShape(); + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + tensordesc_output.SetShape(shape); + tensordesc_output.SetDataType(DT_UINT8); + (void)op.UpdateOutputDesc("y", tensordesc_output); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(PopulationCount, PopulationCountInferShape); +// ------------PopulationCount END----------------- + +// ------------LambNextMVWithDecay---------------- +IMPLEMT_COMMON_INFERFUNC(LambNextMVWithDecayInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input_mul3", "input_mul2", "y1")) { + return GRAPH_FAILED; + } + + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input_mul2", "input_realdiv1", "y3")) { + return GRAPH_FAILED; + } + + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input_mul2", "input_mul1", "y2")) { + return GRAPH_FAILED; + } + + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, 
"input_mul3", "input_mul0", "y4")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(LambNextMVWithDecay, LambNextMVWithDecayInferShape); +// ------------LambNextMVWithDecay END---------------- + +// ------------LambNextMV---------------- +IMPLEMT_COMMON_INFERFUNC(LambNextMVInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input_mul3", "input_mul2", "y1")) { + return GRAPH_FAILED; + } + + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input_mul2", "input_realdiv1", "y3")) { + return GRAPH_FAILED; + } + + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input_mul2", "input_mul1", "y2")) { + return GRAPH_FAILED; + } + + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input_mul3", "input_mul0", "y4")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(LambNextMV, LambNextMVInferShape); +// ------------LambNextMV END---------------- + +// ------------LambNextRight---------------- +IMPLEMT_COMMON_INFERFUNC(LambNextRightInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input_square", "input_mul2", "y1")) { + return GRAPH_FAILED; + } + + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input_square", "input_mul2", "y2")) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(LambNextRight, LambNextRightVerify) { + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(LambNextRight, LambNextRightInferShape); +VERIFY_FUNC_REG(LambNextRight, LambNextRightVerify); +// ------------LambNextRight---------------- + +// ------------LambUpdateWithLr---------------- +IMPLEMT_COMMON_INFERFUNC(LambUpdateWithLrInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input_greater1", "input_sub", "y")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(LambUpdateWithLr, LambUpdateWithLrInferShape); +// ------------LambUpdateWithLr END---------------- + +// ------------LambUpdateWithLrV2---------------- 
+IMPLEMT_COMMON_INFERFUNC(LambUpdateWithLrV2InferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x4", "output_y")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(LambUpdateWithLrV2, LambUpdateWithLrV2Verify) { + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(LambUpdateWithLrV2, LambUpdateWithLrV2InferShape); +VERIFY_FUNC_REG(LambUpdateWithLrV2, LambUpdateWithLrV2Verify); +// ------------LambUpdateWithLrV2---------------- + +// ----------------AdamApplyOneWithDecay------------------- +IMPLEMT_COMMON_INFERFUNC(AdamApplyOneWithDecayInferShape) { + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input0", "input1", "output0")) { + return GRAPH_FAILED; + } + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input0", "input2", "output1")) { + return GRAPH_FAILED; + } + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input2", "input3", "output2")) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(AdamApplyOneWithDecay, AdamApplyOneWithDecayVerify) { + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(AdamApplyOneWithDecay, AdamApplyOneWithDecayInferShape); +VERIFY_FUNC_REG(AdamApplyOneWithDecay, AdamApplyOneWithDecayVerify); +// ----------------AdamApplyOneWithDecay------------------- + +// ----------------AdamApplyOne------------------- +IMPLEMT_COMMON_INFERFUNC(AdamApplyOneInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input0", "input1", "output0", is_dynamic_output)) { + return GRAPH_FAILED; + } + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input0", "input2", "output1", is_dynamic_output)) { + return GRAPH_FAILED; + } + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input2", "input3", "output2", is_dynamic_output)) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(AdamApplyOne, AdamApplyOneVerify) { + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(AdamApplyOne, AdamApplyOneInferShape); +VERIFY_FUNC_REG(AdamApplyOne, 
AdamApplyOneVerify); +// ----------------AdamApplyOne------------------- + +// ----------------AdamApplyOneWithDecayAssign------------------- +IMPLEMT_COMMON_INFERFUNC(AdamApplyOneWithDecayAssignInferShape) { + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(AdamApplyOneWithDecayAssign, AdamApplyOneWithDecayAssignVerify) { + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(AdamApplyOneWithDecayAssign, AdamApplyOneWithDecayAssignInferShape); +VERIFY_FUNC_REG(AdamApplyOneWithDecayAssign, AdamApplyOneWithDecayAssignVerify); +// ----------------AdamApplyOneWithDecayAssign------------------- + +// ----------------AdamApplyOneAssign------------------- +IMPLEMT_COMMON_INFERFUNC(AdamApplyOneAssignInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input0", "input1", "output0", is_dynamic_output)) { + return GRAPH_FAILED; + } + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input0", "input2", "output1", is_dynamic_output)) { + return GRAPH_FAILED; + } + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "input2", "input3", "output2", is_dynamic_output)) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(AdamApplyOneAssign, AdamApplyOneAssignVerify) { + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(AdamApplyOneAssign, AdamApplyOneAssignInferShape); +VERIFY_FUNC_REG(AdamApplyOneAssign, AdamApplyOneAssignVerify); +// ----------------AdamApplyOneAssign------------------- + +// ----------------LambApplyOptimizerAssign------------------- +IMPLEMT_COMMON_INFERFUNC(LambApplyOptimizerAssignInferShape) { + Shape x_shape = op.GetInputDesc("grad").GetShape(); + DataType input_dtype = op.GetInputDesc("grad").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("output0"); + tensordesc_output.SetShape(x_shape); + tensordesc_output.SetDataType(input_dtype); + (void)op.UpdateOutputDesc("output0", tensordesc_output); + + Shape v_shape = op.GetInputDesc("inputv").GetShape(); + DataType inputv_dtype = 
op.GetInputDesc("inputv").GetDataType(); + TensorDesc tensordesc_outputv = op.GetOutputDesc("inputv"); + tensordesc_outputv.SetShape(v_shape); + tensordesc_outputv.SetDataType(inputv_dtype); + (void)op.UpdateOutputDesc("inputv", tensordesc_outputv); + + Shape m_shape = op.GetInputDesc("inputm").GetShape(); + DataType inputm_dtype = op.GetInputDesc("inputm").GetDataType(); + TensorDesc tensordesc_outputm = op.GetOutputDesc("inputm"); + tensordesc_outputm.SetShape(m_shape); + tensordesc_outputm.SetDataType(inputm_dtype); + (void)op.UpdateOutputDesc("inputm", tensordesc_outputm); + + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(LambApplyOptimizerAssign, LambApplyOptimizerAssignVerify) { + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(LambApplyOptimizerAssign, LambApplyOptimizerAssignInferShape); +VERIFY_FUNC_REG(LambApplyOptimizerAssign, LambApplyOptimizerAssignVerify); +// ----------------LambApplyOptimizerAssign------------------- + +// ----------------LambApplyWeightAssign------------------- +IMPLEMT_COMMON_INFERFUNC(LambApplyWeightAssignInferShape) { + Shape x_shape = op.GetInputDesc("input_param").GetShape(); + DataType input_dtype = op.GetInputDesc("input_param").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("input_param"); + tensordesc_output.SetShape(x_shape); + tensordesc_output.SetDataType(input_dtype); + (void)op.UpdateOutputDesc("input_param", tensordesc_output); + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(LambApplyWeightAssign, LambApplyWeightAssignVerify) { + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(LambApplyWeightAssign, LambApplyWeightAssignInferShape); +VERIFY_FUNC_REG(LambApplyWeightAssign, LambApplyWeightAssignVerify); +// ----------------LambApplyWeightAssign------------------- + +// ------------SquareSumV2 Op Begin---------------- +IMPLEMT_COMMON_INFERFUNC(SquareSumV2InferShape) { + auto shape = op.GetInputDesc("input_x").GetShape(); + DataType input_dtype = op.GetInputDesc("input_x").GetDataType(); + TensorDesc 
tensordesc_output = op.GetOutputDesc("output1"); + std::vector shapeVector = shape.GetDims(); + int64_t dimNum = shape.GetDimNum(); + std::vector axis; + if (ge::GRAPH_SUCCESS != op.GetAttr("axis", axis)) { + OpsGetAttrErrReport(op.GetName(), "axis"); + OP_LOGE(op.GetName().c_str(), + "The input_size op GetOpAttr" + "ConstValue failed!"); + return GRAPH_FAILED; + } + + bool keep_dims; + if (ge::GRAPH_SUCCESS != op.GetAttr("keep_dims", keep_dims)) { + OpsGetAttrErrReport(op.GetName(), "keep_dims"); + OP_LOGE(op.GetName().c_str(), "get keep_dims op GetOpAttr failed!"); + return GRAPH_FAILED; + } + + if (axis.size() == 0) { + for (size_t i = 0; i < shapeVector.size(); ++i) { + axis.push_back(i); + } + } + + for (size_t i = 0; i < axis.size(); ++i) { + if (axis[i] < 0) { + axis[i] = dimNum + axis[i]; + } + } + + std::vector oShapeVector; + std::vector::iterator tmp; + for (int64_t item = 0; item < dimNum; ++item) { + tmp = std::find(axis.begin(), axis.end(), item); + if (tmp != axis.end()) { + // item in axis + // If keepDims is true, current dimesion set to 1 + if (keep_dims == true) { + oShapeVector.push_back(1); + } + } else { + // item is not in ConstValueAxis + oShapeVector.push_back(shapeVector[item]); + } + } + + Shape oShape(oShapeVector); + tensordesc_output.SetShape(oShape); + tensordesc_output.SetDataType(input_dtype); + TensorDesc tensordesc_output1 = op.GetOutputDesc("output2"); + tensordesc_output1.SetShape(shape); + tensordesc_output1.SetDataType(input_dtype); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(SquareSumV2, SquareSumV2InferShape); +// ------------SquareSumV2 Op End---------------- + +// ------------ClipByNormNoDivSum---------------- +IMPLEMT_COMMON_INFERFUNC(ClipByNormNoDivSumInferShape) { + auto shape = op.GetInputDesc("x").GetShape(); + DataType input_dtype = op.GetInputDesc("x").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + tensordesc_output.SetShape(shape); + tensordesc_output.SetDataType(input_dtype); + 
return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(ClipByNormNoDivSum, ClipByNormNoDivSumInferShape); +// ------------ClipByNormNoDivSum---------------- + +// ------------SquareSumV1 Op Begin---------------- +IMPLEMT_COMMON_INFERFUNC(SquareSumV1InferShape) { + auto shape = op.GetInputDesc("input_x").GetShape(); + DataType input_dtype = op.GetInputDesc("input_x").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("output1"); + std::vector shapeVector = shape.GetDims(); + int64_t dimNum = shape.GetDimNum(); + std::vector axis; + if (ge::GRAPH_SUCCESS != op.GetAttr("axis", axis)) { + OpsGetAttrErrReport(op.GetName(), "axis"); + OP_LOGE(op.GetName().c_str(), + "The input_size op GetOpAttr" + "ConstValue failed!"); + return GRAPH_FAILED; + } + + bool keep_dims; + if (ge::GRAPH_SUCCESS != op.GetAttr("keep_dims", keep_dims)) { + OpsGetAttrErrReport(op.GetName(), "keep_dims"); + OP_LOGE(op.GetName().c_str(), "get keep_dims op GetOpAttr failed!"); + return GRAPH_FAILED; + } + + if (axis.size() == 0) { + for (size_t i = 0; i < shapeVector.size(); ++i) { + axis.push_back(i); + } + } + + for (size_t i = 0; i < axis.size(); ++i) { + if (axis[i] < 0) { + axis[i] = dimNum + axis[i]; + } + } + + std::vector oShapeVector; + std::vector::iterator tmp; + for (int64_t item = 0; item < dimNum; ++item) { + tmp = std::find(axis.begin(), axis.end(), item); + if (tmp != axis.end()) { + // item in axis + // If keepDims is true, current dimesion set to 1 + if (keep_dims == true) { + oShapeVector.push_back(1); + } + } else { + // item is not in ConstValueAxis + oShapeVector.push_back(shapeVector[item]); + } + } + + Shape oShape(oShapeVector); + tensordesc_output.SetShape(oShape); + tensordesc_output.SetDataType(input_dtype); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(SquareSumV1, SquareSumV1InferShape); +// ------------SquareSumV1 Op End---------------- + +// ----------------SquareSumAll Op Begin------------------- +IMPLEMT_COMMON_INFERFUNC(SquareSumALlInferShape) { + 
std::vector o_shape_vector; + Shape o_shape(o_shape_vector); + DataType input_x1_dtype = op.GetInputDesc("x1").GetDataType(); + DataType input_x2_dtype = op.GetInputDesc("x2").GetDataType(); + TensorDesc tensor_desc_y1 = op.GetOutputDesc("y1"); + TensorDesc tensor_desc_y2 = op.GetOutputDesc("y2"); + tensor_desc_y1.SetShape(o_shape); + tensor_desc_y1.SetDataType(input_x1_dtype); + tensor_desc_y2.SetShape(Shape(o_shape)); + tensor_desc_y2.SetDataType(input_x2_dtype); + if (op.UpdateOutputDesc("y1", tensor_desc_y1) != GRAPH_SUCCESS || + op.UpdateOutputDesc("y2", tensor_desc_y2) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "UpdateOutputDesc run failed. Check whether the names of outputs are matched."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(SquareSumAll, SquareSumALlInferShape); +// ----------------SquareSumAll Op End------------------- + +// ----------------FusedMulAddN------------------- +// Check the dtype and attr of the input tensor description. +IMPLEMT_VERIFIER(FusedMulAddN, FusedMulAddNVerify) { + const std::map> kInputTensorMap = { + {"x1", {DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT16}}, {"x2", {DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT16}}}; + const std::vector kSupportList = {DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT16}; + if (!CheckInputDataType(op, "x3", kSupportList)) { + return GRAPH_FAILED; + } + + // input tensor params, must have same shape and dtype + if (!CheckInputDtypeAndShape(op, kInputTensorMap)) { + return GRAPH_FAILED; + } + + OP_LOGI(op.GetName().c_str(), "The op verify end"); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FusedMulAddN, ELMTWISE_INFER_SHAPEANDTYPE("x1", "y")); +VERIFY_FUNC_REG(FusedMulAddN, FusedMulAddNVerify); +// ----------------FusedMulAddN END------------------ +// ----------------FusedMulAddNL2loss------------------- +// Check the dtype and attr of the input tensor description. 
+IMPLEMT_VERIFIER(FusedMulAddNL2loss, FusedMulAddNL2lossVerify) { + const std::map> kInputTensorMap = {{"x1", {DT_FLOAT}}, {"x2", {DT_FLOAT}}}; + const std::vector kSupportList = {DT_FLOAT}; + if (!CheckInputDataType(op, "x3", kSupportList)) { + return GRAPH_FAILED; + } + + // input tensor params, must have same shape and dtype + if (!CheckInputDtypeAndShape(op, kInputTensorMap)) { + return GRAPH_FAILED; + } + + OP_LOGI(op.GetName().c_str(), "The op verify end"); + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(FusedMulAddNL2lossInferShape) { + std::vector o_shape_vector; + Shape o_shape(o_shape_vector); + auto shape_x = op.GetInputDesc("x1").GetShape(); + DataType input_dtype = op.GetInputDesc("x1").GetDataType(); + TensorDesc tensordesc_output1 = op.GetOutputDesc("y1"); + TensorDesc tensordesc_output2 = op.GetOutputDesc("y2"); + tensordesc_output1.SetShape(shape_x); + tensordesc_output1.SetDataType(input_dtype); + tensordesc_output2.SetShape(ge::Shape(o_shape)); + tensordesc_output2.SetDataType(input_dtype); + if (op.UpdateOutputDesc("y1", tensordesc_output1) != GRAPH_SUCCESS || + op.UpdateOutputDesc("y2", tensordesc_output2) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "UpdateOutputDesc run failed. 
Check whether the names of outputs are matched."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(FusedMulAddNL2loss, FusedMulAddNL2lossInferShape); +VERIFY_FUNC_REG(FusedMulAddNL2loss, FusedMulAddNL2lossVerify); +// ----------------FusedMulAddNL2loss end------------------- +// ---------------------------------Bias---------------------------------- +IMPLEMT_INFERFUNC(Bias, BiasInferShape) { + OP_LOGI("Bias", "bias infer shape begin---%d", op.GetInputDesc("bias").GetShape().GetDims().size()); + DataType dtype_x = op.GetInputDesc("x").GetDataType(); + ge::Shape shape_x = op.GetInputDesc("x").GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector> input_range; + op.GetInputDesc("x").GetShapeRange(input_range); + // set output + TensorDesc output_desc = op.GetOutputDesc("y"); + output_desc.SetShape(shape_x); + output_desc.SetDataType(dtype_x); + output_desc.SetShapeRange(input_range); + (void)op.UpdateOutputDesc("y", output_desc); + + int64_t axis; + int64_t num_axes; + bool bias_from_blob; + if (GRAPH_SUCCESS != op.GetAttr("axis", axis)) { + OP_LOGE("[ERROR] GetOpAttr axis failed!"); + OpsGetAttrErrReport(op.GetName(), "axis"); + return GRAPH_FAILED; + } + if (GRAPH_SUCCESS != op.GetAttr("num_axes", num_axes)) { + OP_LOGE("[ERROR] GetOpAttr num_axes failed!"); + OpsGetAttrErrReport(op.GetName(), "num_axes"); + return GRAPH_FAILED; + } + if (GRAPH_SUCCESS != op.GetAttr("bias_from_blob", bias_from_blob)) { + OP_LOGE("[ERROR] GetOpAttr bias_from_blob failed!"); + OpsGetAttrErrReport(op.GetName(), "bias_from_blob"); + return GRAPH_FAILED; + } + + ge::Shape shape_bias = op.GetInputDesc("bias").GetShape(); + int64_t bias_dim_num = shape_bias.GetDimNum(); + + if (dims_x.size() == 4 && bias_dim_num != 0) { + int64_t length_x = dims_x.size(); + std::vector dims_bias = shape_bias.GetDims(); + int64_t length_bias = dims_bias.size(); + int64_t axis_; + if (axis < 0) { + axis_ = length_x + axis; + } else { + axis_ = axis; + } + + 
std::vector dims_bias_tmp = shape_bias.GetDims(); + std::vector> range_bias_new; + op.GetInputDesc("bias").GetShapeRange(range_bias_new); + if (bias_from_blob) { + if (num_axes == -1) { + for (int64_t i = 0; i < axis_; i++) { + dims_bias_tmp.insert(dims_bias_tmp.begin(), (int64_t)1); + range_bias_new.insert(range_bias_new.begin(), {1, 1}); + } + } else if (num_axes > 0) { + int64_t left_length = length_x - num_axes - axis_; + for (int64_t i = 0; i < axis_; i++) { + dims_bias_tmp.insert(dims_bias_tmp.begin(), (int64_t)1); + range_bias_new.insert(range_bias_new.begin(), {1, 1}); + } + for (int64_t i = 0; i < left_length; i++) { + dims_bias_tmp.push_back((int64_t)1); + range_bias_new.push_back({1, 1}); + } + } + } else { + int64_t left_length = length_x - length_bias - axis_; + for (int64_t i = 0; i < axis_; i++) { + dims_bias_tmp.insert(dims_bias_tmp.begin(), (int64_t)1); + range_bias_new.insert(range_bias_new.begin(), {1, 1}); + } + for (int64_t i = 0; i < left_length; i++) { + dims_bias_tmp.push_back((int64_t)1); + range_bias_new.push_back({1, 1}); + } + } + + // update bias shape + ge::Shape output_bias_shape = ge::Shape(dims_bias_tmp); + TensorDesc bias_desc = op.GetInputDesc("bias"); + + bias_desc.SetShape(output_bias_shape); + bias_desc.SetOriginShape(output_bias_shape); + bias_desc.SetShapeRange(range_bias_new); + (void)op.UpdateInputDesc("bias", bias_desc); + } + + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(Bias, BiasVerify) { + ge::Shape shape_x = op.GetInputDesc("x").GetShape(); + ge::Shape shape_bias = op.GetInputDesc("bias").GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_bias = shape_bias.GetDims(); + int64_t bias_dim_num = shape_bias.GetDimNum(); + + int64_t axis; + int64_t num_axes; + bool bias_from_blob; + if (GRAPH_SUCCESS != op.GetAttr("axis", axis)) { + OP_LOGE("[ERROR] GetOpAttr axis failed!"); + OpsGetAttrErrReport(op.GetName(), "axis"); + return GRAPH_FAILED; + } + if (GRAPH_SUCCESS != op.GetAttr("num_axes", num_axes)) 
{ + OP_LOGE("[ERROR] GetOpAttr num_axes failed!"); + OpsGetAttrErrReport(op.GetName(), "num_axes"); + return GRAPH_FAILED; + } + if (GRAPH_SUCCESS != op.GetAttr("bias_from_blob", bias_from_blob)) { + OP_LOGE("[ERROR] GetOpAttr bias_from_blob failed!"); + OpsGetAttrErrReport(op.GetName(), "bias_from_blob"); + return GRAPH_FAILED; + } + + int64_t length_x = dims_x.size(); + int64_t length_bias = dims_bias.size(); + + if ((axis >= length_x) || (axis < (-length_x))) { + OP_LOGE("[ERROR] axis out of range index"); + string minvalue = ConcatString(-length_x); + string maxvalue = ConcatString(length_x - 1); + string excepted_value = ConcatString("in the range of [", minvalue,",", maxvalue,"]"); + OpsAttrValueErrReport(op.GetName(), "axis", excepted_value, ConcatString(axis)); + return GRAPH_FAILED; + } + if (num_axes < -1) { + OP_LOGE("[ERROR] num_axes must be non-negative or -1"); + OpsAttrValueErrReport(op.GetName(), "num_axes", "non-negative or -1", ConcatString(num_axes)); + return GRAPH_FAILED; + } + + int64_t axis_; + if (axis < 0) { + axis_ = length_x + axis; + } else { + axis_ = axis; + } + + if (bias_from_blob) { + if (num_axes == -1) { + int64_t bias_num = length_x - axis_; + if (length_bias != bias_num) { + OP_LOGE("[ERROR] length_bias and bias_num must be equal"); + OpsInputShapeErrReport(op.GetName(), "length_bias and bias_num must be equal", + "length_bias", ConcatString(length_bias)); + return GRAPH_FAILED; + } + } else if (num_axes == 0) { + if (bias_dim_num != 0) { + OP_LOGE("[ERROR] bias must be a scalar "); + OpsAttrValueErrReport(op.GetName(), "bias", "scalar", ConcatString(bias_dim_num)); + return GRAPH_FAILED; + } + } else if (num_axes > 0) { + int64_t num_axis = axis_ + num_axes; + if (num_axis > length_x) { + OP_LOGE("[ERROR] bias shape extends x shape when applied"); + OpsOneInputShapeErrReport(op.GetName(), "bias", "Bias shape extends x_shape when applied."); + return GRAPH_FAILED; + } + if (length_bias != num_axes) { + OP_LOGE("[ERROR] 
length_bias and num_axes must be equal"); + OpsInputShapeErrReport(op.GetName(), "length_bias and bias_num must be equal", + "length_bias", ConcatString(length_bias)); + return GRAPH_FAILED; + } + } + } else { + if (bias_dim_num != 0) { + int64_t bias_num = axis_ + length_bias; + if (bias_num > length_x) { + OP_LOGE("[ERROR] bias shape extends x shape when applied"); + OpsOneInputShapeErrReport(op.GetName(), "bias", "Bias shape extends x_shape when applied"); + return GRAPH_FAILED; + } + } + } + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Bias, BiasInferShape); +VERIFY_FUNC_REG(Bias, BiasVerify); +// ---------------------------------------Bias----------------------------------------------- + +// ----------------------Threshold------------------------- +IMPLEMT_INFERFUNC(Threshold, ThresholdInferShape) { + TensorDesc tensordesc_input = op.GetInputDesc("x"); + Shape input_shape = tensordesc_input.GetShape(); + DataType input_dtype = tensordesc_input.GetDataType(); + Format input_format = tensordesc_input.GetFormat(); + + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + tensordesc_output.SetShape(input_shape); + tensordesc_output.SetDataType(input_dtype); + tensordesc_output.SetFormat(input_format); + + (void)op.UpdateOutputDesc("y", tensordesc_output); + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(Threshold, ThresholdVerify) { + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(Threshold, ThresholdInferShape); +VERIFY_FUNC_REG(Threshold, ThresholdVerify); +// ---------------------Threshold-------------------------- + +// ------------ConfusionMulGrad Op Begin---------------- +IMPLEMT_COMMON_INFERFUNC(ConfusionMulGradInferShape) { + auto shape = op.GetInputDesc("input0").GetShape(); + auto shape1 = op.GetInputDesc("input1").GetShape(); + DataType input_dtype = op.GetInputDesc("input0").GetDataType(); + DataType input_dtype1 = op.GetInputDesc("input1").GetDataType(); + TensorDesc tensordesc_output = op.GetOutputDesc("output0"); + TensorDesc tensordesc_output1 = 
op.GetOutputDesc("output1"); + std::vector shapeVector = shape1.GetDims(); + int64_t dimNum = shape1.GetDimNum(); + std::vector axis; + if (ge::GRAPH_SUCCESS != op.GetAttr("axes", axis)) { + OpsGetAttrErrReport(op.GetName(), "axes"); + OP_LOGE(op.GetName().c_str(), + "The input_size op GetOpAttr" + "ConstValue failed!"); + return GRAPH_FAILED; + } + + bool keep_dims; + if (ge::GRAPH_SUCCESS != op.GetAttr("keep_dims", keep_dims)) { + OpsGetAttrErrReport(op.GetName(), "keep_dims"); + OP_LOGE(op.GetName().c_str(), "get keep_dims op GetOpAttr failed!"); + return GRAPH_FAILED; + } + + if (axis.size() == 0) { + for (size_t i = 0; i < shapeVector.size(); ++i) { + axis.push_back(i); + } + } + + for (size_t i = 0; i < axis.size(); ++i) { + if (axis[i] < 0) { + axis[i] = dimNum + axis[i]; + } + } + + std::vector oShapeVector; + std::vector::iterator tmp; + for (int64_t item = 0; item < dimNum; ++item) { + tmp = std::find(axis.begin(), axis.end(), item); + if (tmp != axis.end()) { + // item in axis + // If keepDims is true, current dimesion set to 1 + if (keep_dims == true) { + oShapeVector.push_back(1); + } + } else { + // item is not in ConstValueAxis + oShapeVector.push_back(shapeVector[item]); + } + } + + Shape oShape(oShapeVector); + tensordesc_output1.SetShape(oShape); + tensordesc_output1.SetDataType(input_dtype); + tensordesc_output.SetShape(shape); + tensordesc_output.SetDataType(input_dtype1); + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(ConfusionMulGrad, ConfusionMulGradInferShape); +// ------------ConfusionMulGrad Op End---------------- + +// ------------ArgMaxWithK Op Begin---------------- +IMPLEMT_INFERFUNC(ArgMaxWithK, ArgMaxWithKInfer) { + auto input_dtype = op.get_input_desc_x().GetDataType(); + Shape input_shape = op.get_input_desc_x().GetShape(); + Shape origin_shape = op.get_input_desc_x().GetOriginShape(); + Format input_format = op.get_input_desc_x().GetFormat(); + Format origin_format = op.get_input_desc_x().GetOriginFormat(); + + int axis = 
op.get_attr_axis(); + int topk = op.get_attr_topk(); + bool out_max_val = op.get_attr_out_max_val(); + bool out_max_index = true; + if (out_max_val && axis != 10000) { + out_max_index = false; + } + + auto output_dtype = input_dtype; + auto output_shape = input_shape; + Format output_format = input_format; + + if (input_format == FORMAT_NC1HWC0) { + if (origin_shape.GetDimNum() == 4) { + if (origin_format == FORMAT_NCHW) { + if (axis < 0) { + axis = axis - 1; + } + } else if (origin_format == FORMAT_NHWC) { + if (axis == -4) { + axis = -5; + } else if (axis == -1) { + axis = -4; + } else if (axis == 1) { + axis = 2; + } else if (axis == 2) { + axis = 3; + } else if (axis == 3) { + axis = 1; + } + } else { + OP_LOGE(op.GetName().c_str(), "5D tensor's origin format should in [NCHW, NHWC]"); + return GRAPH_FAILED; + } + } else { + OP_LOGE(op.GetName().c_str(), "5D tensor's origin shape should be 4D tensor"); + return GRAPH_FAILED; + } + + if (axis < 0) { + axis = axis + 5; + } + if (axis == 10000 || axis == 1 || axis == 4) { + OP_LOGE(op.GetName().c_str(), "5D tensor's axis is invalid"); + return GRAPH_FAILED; + } + } else if (axis < 0) { + axis = axis + input_shape.GetDimNum(); + } + + if (axis == 10000) { + std::vector output_shape_vector; + output_shape_vector.push_back(input_shape.GetDim(0)); + output_shape_vector.push_back(topk); + output_shape = Shape(output_shape_vector); + } else { + output_shape.SetDim(axis, topk); + } + + TensorDesc indicesTensorDesc = TensorDesc(output_shape, output_format, DT_INT32); + indicesTensorDesc.SetRealDimCnt(output_shape.GetDimNum()); + indicesTensorDesc.SetOriginShape(output_shape); + if (!out_max_index) { + indicesTensorDesc.SetDataType(output_dtype); + } + + TensorDesc valuesTensorDesc = TensorDesc(output_shape, output_format, output_dtype); + valuesTensorDesc.SetRealDimCnt(output_shape.GetDimNum()); + valuesTensorDesc.SetOriginShape(output_shape); + + op.update_output_desc_indices(indicesTensorDesc); + 
op.update_output_desc_values(valuesTensorDesc); + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(ArgMaxWithK, ArgMaxWithKVerify) { + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(ArgMaxWithK, ArgMaxWithKInfer); +VERIFY_FUNC_REG(ArgMaxWithK, ArgMaxWithKVerify); +// ------------ArgMaxWithK Op End---------------- + +// ------------Muls Op Begin---------------- +IMPLEMT_VERIFIER(Muls, MulsVerify) { + return GRAPH_SUCCESS; +} +COMMON_INFER_FUNC_REG(Muls, OneInOneOutCommonInferShape); +VERIFY_FUNC_REG(Muls, MulsVerify); +// ------------Muls Op End---------------- + +// ------------fills Op Start---------------- +bool InferShapeAndTypeFills(Operator& op, const string& x, const string& y, const string& value) { + float value_num; + if (op.GetAttr(value, value_num) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + TensorDesc vOutputDesc = op.GetOutputDesc(y); + + DataType input_dtype = op.GetInputDesc(x).GetDataType(); + Format input_format = op.GetInputDesc(x).GetFormat(); + ge::Shape shapeX = op.GetInputDesc(x).GetShape(); + + vOutputDesc.SetShape(shapeX); + vOutputDesc.SetDataType(input_dtype); + vOutputDesc.SetFormat(input_format); + op.UpdateOutputDesc(y, vOutputDesc); + + return true; +} +// ----------------Add------------------- +IMPLEMT_VERIFIER(Fills, FillsVerify) { + return GRAPH_SUCCESS; +} +// Obtains the processing function of the output tensor description. 
+IMPLEMT_COMMON_INFERFUNC(FillsInferShape) { + if (InferShapeAndTypeFills(op, "x", "y", "value")) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +// Registered inferfunction +COMMON_INFER_FUNC_REG(Fills, FillsInferShape); +// Registered verify function +VERIFY_FUNC_REG(Fills, FillsVerify); +// -----------fills Op End---------------- + +// --------------MulNoNan +IMPLEMT_VERIFIER(MulNoNan, MulNoNanVerify) { + DataType input_type_x1 = op.GetInputDesc("x1").GetDataType(); + DataType input_type_x2 = op.GetInputDesc("x2").GetDataType(); + if (input_type_x1 != input_type_x2) { + OP_LOGE(op.GetName().c_str(), + "The %s op dtype is not same, type1:%d, type2:%d", + op.GetName().c_str(), input_type_x1, input_type_x2); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +VERIFY_FUNC_REG(MulNoNan, MulNoNanVerify); + +IMPLEMT_COMMON_INFERFUNC(MulNoNanInferShape) { + bool is_dynamic_output = true; + if(InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", + is_dynamic_output)) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(MulNoNan, MulNoNanInferShape); +// ------------MulNoNan END + +// ----------------------Axpy-------------------------- +IMPLEMT_VERIFIER(Axpy, AxpyVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(AxpyInferShape) { + bool is_dynamic_output = true; + if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(Axpy, AxpyInferShape); +VERIFY_FUNC_REG(Axpy, AxpyVerify); +// ---------------------Axpy END------------------------ + +// ------------CosineEmbeddingLoss Op Begin---------------- +IMPLEMT_VERIFIER(CosineEmbeddingLoss, CosineEmbeddingLossVerify) { + Shape shape_x1 = op.GetInputDesc("x1").GetShape(); + Shape shape_x2 = op.GetInputDesc("x2").GetShape(); + if ((shape_x1.GetDimNum() < 2) && (shape_x2.GetDimNum() < 2)) { + 
OP_LOGE(op.GetName().c_str(), "input x1 or x2 dims must bigger than 1"); + return GRAPH_FAILED; + } + + std::string reduction; + op.GetAttr("reduction", reduction); + if ((reduction != "mean") && (reduction != "sum") && (reduction != "none")) { + OP_LOGE(op.GetName().c_str(), "reduction only support \"mean\", \"sum\" and \"none\""); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(CosineEmbeddingLossInferShape) { + Shape shape_x1 = op.GetInputDesc("x1").GetShape(); + Shape shape_x2 = op.GetInputDesc("x2").GetShape(); + Shape shape_tgt = op.GetInputDesc("target").GetShape(); + + vector x_dims_broadcast; + vector tgt_dims_broadcast; + + if (!BroadCastTwoShape(op, shape_x1, shape_x2, x_dims_broadcast)) { + OP_LOGE(op.GetName().c_str(), "input x1 and x2 shape can't broadcast"); + return GRAPH_FAILED; + } + + // reduce aixs = 1 + x_dims_broadcast.erase(x_dims_broadcast.begin() + 1); + + Shape shape_x_broadcast(x_dims_broadcast); + if (!BroadCastTwoShape(op, shape_x_broadcast, shape_tgt, tgt_dims_broadcast)) { + OP_LOGE(op.GetName().c_str(), "input target shape can't broadcast to x shape"); + return GRAPH_FAILED; + } + + float margin = 0.0; + std::string reduction; + (void)op.GetAttr("margin", margin); + (void)op.GetAttr("reduction", reduction); + OP_LOGI(op.GetName().c_str(), "setting margin:%f, reduction:%s\n", margin, reduction.c_str()); + + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + Shape y_shape = Shape(tgt_dims_broadcast); + if ((reduction == "mean") || (reduction == "sum")) { + tensordesc_output.SetShape(Shape({1})); + } else if (reduction == "none") { + tensordesc_output.SetShape(y_shape); + } + + tensordesc_output.SetDataType(DT_FLOAT); + tensordesc_output.SetFormat(FORMAT_ND); + (void)op.UpdateOutputDesc("y", tensordesc_output); + + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(CosineEmbeddingLoss, CosineEmbeddingLossInferShape); +VERIFY_FUNC_REG(CosineEmbeddingLoss, CosineEmbeddingLossVerify); +// 
------------CosineEmbeddingLoss Op End---------------- + +// ----------------------KLDiv-------------------------- +IMPLEMT_VERIFIER(KLDiv, KLDivVerify) { + if (!CheckInputsShapeDtypeSame(op, {"x", "target"})) { + return GRAPH_FAILED; + } + std::vector const_attr; + if (!GetConstAttr(op, {"reduction"}, const_attr)) { + OP_LOGE(op.GetName().c_str(), "The GetOpAttr ConstValue failed!"); + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(KLDivInferShape) { + + // get input desc + auto op_info = OpDescUtils::GetOpDescFromOperator(op); + + auto x_desc = op_info->MutableInputDesc("x"); + auto x_dtype = x_desc->GetDataType(); + std::vector x_dims; + + auto y_desc = op_info->MutableOutputDesc("y"); + + y_desc->SetShape(GeShape(x_dims)); + y_desc->SetDataType(x_dtype); + + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(KLDiv, KLDivInferShape); +VERIFY_FUNC_REG(KLDiv, KLDivVerify); +// ---------------------KLDiv End------------------------ + +// ----------------TensorMove Begin------------------- +COMMON_INFER_FUNC_REG(TensorMove, ELMTWISE_INFER_SHAPEANDTYPE("x", "y")); +// ----------------TensorMove END--------------------- + +// ----------------TensorRedirect Begin------------------- +COMMON_INFER_FUNC_REG(TensorRedirect, ELMTWISE_INFER_SHAPEANDTYPE("x", "output_x")); +// --------------TensorRedirect END----------------------- + +// ----------------MaxN Begin------------------- +static bool MaxNCheckDtype(const ge::Operator& op) { + int32_t input_num = op.GetInputsSize(); + if (input_num <= 0) { + OP_LOGE("MaxNInferShape", "DynamicInputNum is le 0"); + return false; + } + ge::TensorDesc input_desc0 = op.GetDynamicInputDesc("x", 0); + DataType data_ty0 = input_desc0.GetDataType(); + for (int i = 1; i < input_num; ++i) { + ge::TensorDesc input_desc = op.GetDynamicInputDesc("x", i); + DataType data_ty = input_desc.GetDataType(); + if (data_ty0 != data_ty) { + OP_LOGE("MaxNInferShape", "DynamicInput DataType is not equal"); + return false; + } + } + return true; +} 
+ +static void MaxNUpdateInferShape(std::vector& dims, + const ge::Shape input_shape) { + int32_t dims_size = dims.size(); + std::vector input_dims = input_shape.GetDims(); + int32_t input_dims_size = input_dims.size(); + if (input_dims_size > dims_size) { + for (int i = 0; i < input_dims_size - dims_size; ++i) { + dims.insert(dims.begin(), 0); + } + dims_size = dims.size(); + } + int32_t i = dims_size - input_dims_size; + int32_t j = 0; + while (i < dims_size && j < input_dims_size) { + if (dims[i] < input_dims[j]) { + dims[i] = input_dims[j]; + } + i++; + j++; + } +} +IMPLEMT_COMMON_INFERFUNC(MaxNInferShape) { + std::vector dims(1, 0); + int32_t input_num = op.GetInputsSize(); + if (input_num <= 0) { + OP_LOGE("MaxNInferShape", "DynamicInputNum is le 0"); + return GRAPH_FAILED; + } + for (int i = 0; i < input_num; ++i) { + ge::TensorDesc input_desc = op.GetDynamicInputDesc("x", i); + ge::Shape input_shape = input_desc.GetShape(); + MaxNUpdateInferShape(dims, input_shape); + } + ge::TensorDesc input_desc0 = op.GetDynamicInputDesc("x", 0); + ge::TensorDesc output_desc = op.GetOutputDesc("y"); + ge::Shape inferShape(dims); + output_desc.SetShape(inferShape); + output_desc.SetDataType(input_desc0.GetDataType()); + output_desc.SetFormat(input_desc0.GetFormat()); + op.UpdateOutputDesc("y", output_desc); + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(MaxN, MaxNVerify) { + if (!MaxNCheckDtype(op)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +INFER_FUNC_REG(MaxN, MaxNInferShape); +VERIFY_FUNC_REG(MaxN, MaxNVerify); +// ----------------MaxN END--------------------- + +// ----------------TensorEqual Begin------------------- +bool InferShapeAndTypeTensorEqual(Operator &op, const string &input_name1, + const string &input_name2, + const string &output_name) { + TensorDesc v_output_desc = op.GetOutputDesc(output_name); + + DataType input_dtype = op.GetInputDesc(input_name1).GetDataType(); + Format input_format = op.GetInputDesc(input_name1).GetFormat(); + + 
ge::Shape shape_x = op.GetInputDesc(input_name1).GetShape(); + ge::Shape shape_y = op.GetInputDesc(input_name2).GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_y = shape_y.GetDims(); + + if (shape_x.GetShapeSize() != shape_y.GetShapeSize()) { + OP_LOGE("The ShapeSize of input_x does not match input_y."); + return false; + } + return true; + + std::vector dim_vec = {1}; + ge::Shape output_shape = ge::Shape(dim_vec); + v_output_desc.SetShape(output_shape); + v_output_desc.SetDataType(DT_BOOL); + v_output_desc.SetFormat(input_format); + op.UpdateOutputDesc(output_name, v_output_desc); + + return true; +} + +IMPLEMT_COMMON_INFERFUNC(TensorEqualInferShape) { + if (InferShapeAndTypeTensorEqual(op, "input_x", "input_y", "output_z")) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +IMPLEMT_VERIFIER(TensorEqual, TensorEqualVerify) { + // Check whether the data types of two input tensors are the same. + if (op.GetInputDesc("input_x").GetDataType() != + op.GetInputDesc("input_y").GetDataType()) { + OP_LOGE("input_x input_y tensor dtype does not match."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(TensorEqual, TensorEqualInferShape); +VERIFY_FUNC_REG(TensorEqual, TensorEqualVerify); +// ----------------TensorEqual END--------------------- + +void CompareBothShape(std::vector& dims_fst, + std::vector& dims_sec) { + if (dims_fst.size() < dims_sec.size()) { + std::vector dims_tmp = dims_fst; + dims_fst = dims_sec; + dims_sec = dims_tmp; + } + + if (dims_fst.size() > dims_sec.size()) { + int dec = dims_fst.size() - dims_sec.size(); + dims_sec.insert(dims_sec.begin(), dec, (int64_t)1); + } +} + +graphStatus ChangeShape(std::vector& dims_fst, + std::vector& dims_sec, + std::vector& dims_vec) { + CompareBothShape(dims_fst, dims_sec); + // calculate shape of output: shape[i] = max(dims_fst[i], dims_sec[i]) + for (size_t i = 0; i < dims_fst.size(); i++) { + if ((dims_fst[i] != dims_sec[i]) && (dims_fst[i] != 1) && + 
(dims_sec[i] != 1)) { + OP_LOGE("[ERROR] dims_fst and dims_sec can not be broadcast"); + return GRAPH_FAILED; + } + + int64_t dims = (dims_fst[i] > dims_sec[i]) ? dims_fst[i] : dims_sec[i]; + dims_vec.push_back(dims); + } + return GRAPH_SUCCESS; +} + +graphStatus ReplenishShape(std::vector& dims_x, + std::vector& dims_y, + std::vector& dims_z, + std::vector& dims_vec) { + std::vector dims_vec1; + if (ChangeShape(dims_x, dims_y, dims_vec1) == GRAPH_FAILED) { + return GRAPH_FAILED; + } + + if (ChangeShape(dims_vec1, dims_z, dims_vec) == GRAPH_FAILED) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus InferShapeAndTypeAddcdivAndAddcmul(Operator& op, + const string& input_name1, + const string& input_name2, + const string& input_name3, + const string& output_name) { + TensorDesc v_output_desc = op.GetOutputDesc(output_name); + + DataType input_dtype = op.GetInputDesc(input_name1).GetDataType(); + Format input_format = op.GetInputDesc(input_name1).GetFormat(); + ge::Shape shape_x = op.GetInputDesc(input_name3).GetShape(); + ge::Shape shape_y = op.GetInputDesc(input_name2).GetShape(); + ge::Shape shape_z = op.GetInputDesc(input_name1).GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_y = shape_y.GetDims(); + std::vector dims_z = shape_z.GetDims(); + if (dims_x.size() < dims_y.size()) { + std::vector dims_tmp = dims_x; + dims_x = dims_y; + dims_y = dims_tmp; + } + + std::vector dims_vec; + if (ReplenishShape(dims_x, dims_y, dims_z, dims_vec) == GRAPH_FAILED) { + OP_LOGE(op.GetName().c_str(), "ReplenishShape run failed"); + return GRAPH_FAILED; + } + + ge::Shape output_shape = ge::Shape(dims_vec); + + v_output_desc.SetShape(output_shape); + v_output_desc.SetDataType(input_dtype); + v_output_desc.SetFormat(input_format); + op.UpdateOutputDesc(output_name, v_output_desc); + + return GRAPH_SUCCESS; +} + +// ----------------Addcdiv begin------------------- +IMPLEMT_VERIFIER(Addcdiv, AddcdivVerify) { + // the data type of input_data, 
x1 and x2 should be same. + if (op.GetInputDesc("x2").GetDataType() != + op.GetInputDesc("input_data").GetDataType() || + op.GetInputDesc("x1").GetDataType() != + op.GetInputDesc("input_data").GetDataType()) { + OP_LOGE(op.GetName().c_str(), + "input_data data type and x1, x2 match failed"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +// Obtains the processing function of the output tensor description. +IMPLEMT_COMMON_INFERFUNC(AddcdivInferShape) { + if (InferShapeAndTypeAddcdivAndAddcmul(op, "input_data", "x1", "x2", "y") == + GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +// Registered inferfunction +COMMON_INFER_FUNC_REG(Addcdiv, AddcdivInferShape); +// Registered verify function +VERIFY_FUNC_REG(Addcdiv, AddcdivVerify); +// ----------------Addcdiv end------------------- + +// ----------------Addcmul begin------------------- +IMPLEMT_VERIFIER(Addcmul, AddcmulVerify) { + // the data type of input_data,x1 and x2 should be same. + if (op.GetInputDesc("input_data").GetDataType() != + op.GetInputDesc("x1").GetDataType() || + op.GetInputDesc("input_data").GetDataType() != + op.GetInputDesc("x2").GetDataType()) { + OP_LOGE(op.GetName().c_str(), + "input_data data type and x1,x2 match failed"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +// Obtains the processing function of the output tensor description. 
+IMPLEMT_COMMON_INFERFUNC(AddcmulInferShape) { + if (InferShapeAndTypeAddcdivAndAddcmul(op, "input_data", "x1", "x2", "y") == + GRAPH_SUCCESS) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +// Registered inferfunction +COMMON_INFER_FUNC_REG(Addcmul, AddcmulInferShape); +// Registered verify function +VERIFY_FUNC_REG(Addcmul, AddcmulVerify); +// ----------------Addcmul end------------------- + +// ----------------AxpyV2 Begin------------------- +IMPLEMT_VERIFIER(AxpyV2, AxpyV2Verify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(AxpyV2InferShape) { + if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(AxpyV2, AxpyV2InferShape); +VERIFY_FUNC_REG(AxpyV2, AxpyV2Verify); +// ----------------AxpyV2 END--------------------- + +// ----------------PtAdd Begin------------------- +bool InferShapeAndTypePtAdd(Operator& op, const string& input_name1, + const string& input_name2, + const string& output_name) { + TensorDesc output_desc = op.GetOutputDesc(output_name); + DataType input_dtype = op.GetInputDesc(input_name1).GetDataType(); + Format input_format = op.GetInputDesc(input_name1).GetFormat(); + // The size of the shape dimension is exchanged. + // Each dimension of dims_x uses the larger value of the corresponding + // dimension in two tensors. + ge::Shape shape_x = op.GetInputDesc(input_name1).GetShape(); + ge::Shape shape_y = op.GetInputDesc(input_name2).GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_y = shape_y.GetDims(); + if (dims_x.size() < dims_y.size()) { + dims_x.swap(dims_y); + } + + // The small shape is padded with 1. 
+ if (dims_x.size() != dims_y.size()) { + int dec = dims_x.size() - dims_y.size(); + for (int i = 0; i < dec; i++) { + dims_y.insert(dims_y.begin(), (int64_t)1); + } + } + + // The value of each dimension in the shape of the output tensor is the + // larger value of the corresponding dimension in the two inputs. + std::vector dim_vec; + for (size_t i = 0; i < dims_x.size(); i++) { + if ((dims_x[i] != dims_y[i]) && (dims_x[i] != 1) && (dims_y[i] != 1)) { + OP_LOGE("The shape of x1 and x2 can not broadcast."); + return false; + } + + int64_t dims = (dims_x[i] > dims_y[i]) ? dims_x[i] : dims_y[i]; + dim_vec.push_back(dims); + } + ge::Shape output_shape = ge::Shape(dim_vec); + + output_desc.SetShape(output_shape); + output_desc.SetDataType(input_dtype); + output_desc.SetFormat(input_format); + op.UpdateOutputDesc(output_name, output_desc); + + return true; +} + +IMPLEMT_VERIFIER(PtAdd, PtAddVerify) { + // Check whether the data types of two input tensors are the same. + if (op.GetInputDesc("x1").GetDataType() != + op.GetInputDesc("x2").GetDataType()) { + OP_LOGE("x1 x2 tensor dtype does not match."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +// Obtains the processing function of the output tensor description. +IMPLEMT_COMMON_INFERFUNC(PtAddInferShape) { + // Check whether the data shape of two input tensors are the same. 
+ if (InferShapeAndTypePtAdd(op, "x1", "x2", "y")) { + return GRAPH_SUCCESS; + } + OP_LOGE("The shape of output y does not match that of x1 x2."); + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(PtAdd, PtAddInferShape); +VERIFY_FUNC_REG(PtAdd, PtAddVerify); +// ----------------PtAdd END--------------------- + +// ----------------PtMuls Begin------------------- +bool InferShapeAndTypePtMuls(Operator &op, const string &input_name1, + const string &input_name2, + const string &output_name) { + TensorDesc v_output_desc = op.GetOutputDesc(output_name); + + DataType input_dtype = op.GetInputDesc(input_name1).GetDataType(); + Format input_format = op.GetInputDesc(input_name1).GetFormat(); + // The size of the shape dimension is exchanged. + // Each dimension of dims_x uses the larger value of the corresponding + // dimension in two tensors. + ge::Shape shape_x = op.GetInputDesc(input_name1).GetShape(); + ge::Shape shape_y = op.GetInputDesc(input_name2).GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_y = shape_y.GetDims(); + if (dims_x.size() < dims_y.size()) { + dims_x.swap(dims_y); + } + + // The small shape is padded with 1. + if (dims_x.size() != dims_y.size()) { + int dec = dims_x.size() - dims_y.size(); + for (int i = 0; i < dec; i++) { + dims_y.insert(dims_y.begin(), (int64_t)1); + } + } + + // The value of each dimension in the shape of the output tensor is the + // larger value of the corresponding dimension in the two inputs. + std::vector dim_vec; + for (size_t i = 0; i < dims_x.size(); i++) { + if ((dims_x[i] != dims_y[i]) && (dims_x[i] != 1) && (dims_y[i] != 1)) { + return false; + } + + int64_t dims = (dims_x[i] > dims_y[i]) ? 
dims_x[i] : dims_y[i]; + dim_vec.push_back(dims); + } + ge::Shape output_shape = ge::Shape(dim_vec); + + v_output_desc.SetShape(output_shape); + v_output_desc.SetDataType(input_dtype); + v_output_desc.SetFormat(input_format); + op.UpdateOutputDesc(output_name, v_output_desc); + + return true; +} + +IMPLEMT_COMMON_INFERFUNC(PtMulsInferShape) { + if (InferShapeAndTypePtMuls(op, "x1", "x2", "y")) { + return GRAPH_SUCCESS; + } + OP_LOGE("The shape of output y does not match that of x1 x2."); + return GRAPH_FAILED; +} + +IMPLEMT_VERIFIER(PtMuls, PtMulsVerify) { + // Check whether the data types of two input tensors are the same. + if (op.GetInputDesc("x1").GetDataType() != + op.GetInputDesc("x2").GetDataType()) { + OP_LOGE("x1 x2 tensor dtype does not match."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(PtMuls, PtMulsInferShape); +VERIFY_FUNC_REG(PtMuls, PtMulsVerify); +// ----------------PtMuls END--------------------- + +// ----------------PtSub Begin------------------- +bool InferShapeAndTypePtSub(Operator& op, const string& input_name1, + const string& input_name2, + const string& output_name) { + TensorDesc output_desc = op.GetOutputDesc(output_name); + DataType input_dtype = op.GetInputDesc(input_name1).GetDataType(); + Format input_format = op.GetInputDesc(input_name1).GetFormat(); + // The size of the shape dimension is exchanged. + // Each dimension of dims_x uses the larger value of the corresponding + // dimension in two tensors. + ge::Shape shape_x = op.GetInputDesc(input_name1).GetShape(); + ge::Shape shape_y = op.GetInputDesc(input_name2).GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_y = shape_y.GetDims(); + if (dims_x.size() < dims_y.size()) { + dims_x.swap(dims_y); + } + + // The small shape is padded with 1. 
+ if (dims_x.size() != dims_y.size()) { + int dec = dims_x.size() - dims_y.size(); + for (int i = 0; i < dec; i++) { + dims_y.insert(dims_y.begin(), (int64_t)1); + } + } + + // The value of each dimension in the shape of the output tensor is the + // larger value of the corresponding dimension in the two inputs. + std::vector dim_vec; + for (size_t i = 0; i < dims_x.size(); i++) { + if ((dims_x[i] != dims_y[i]) && (dims_x[i] != 1) && (dims_y[i] != 1)) { + OP_LOGE("The shape of x1 and x2 can not broadcast."); + return false; + } + + int64_t dims = (dims_x[i] > dims_y[i]) ? dims_x[i] : dims_y[i]; + dim_vec.push_back(dims); + } + ge::Shape output_shape = ge::Shape(dim_vec); + + output_desc.SetShape(output_shape); + output_desc.SetDataType(input_dtype); + output_desc.SetFormat(input_format); + op.UpdateOutputDesc(output_name, output_desc); + + return true; +} + +IMPLEMT_VERIFIER(PtSub, PtSubVerify) { + // Check whether the data types of two input tensors are the same. + if (op.GetInputDesc("x1").GetDataType() != + op.GetInputDesc("x2").GetDataType()) { + OP_LOGE("x1 x2 tensor dtype does not match."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +// Obtains the processing function of the output tensor description. +IMPLEMT_COMMON_INFERFUNC(PtSubInferShape) { + // Check whether the data shape of two input tensors are the same. 
+ if (InferShapeAndTypePtSub(op, "x1", "x2", "y")) { + return GRAPH_SUCCESS; + } + OP_LOGE("The shape of output y does not match that of x1 x2."); + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(PtSub, PtSubInferShape); +VERIFY_FUNC_REG(PtSub, PtSubVerify); +// ----------------PtSub END--------------------- + +// ----------------StrideAdd Begin------------------- +bool InferShapeAndTypeStrideAdd(Operator &op, const string &input_name1, + const string &input_name2, + const string &outputName) { + TensorDesc output_desc = op.GetOutputDesc(outputName); + DataType input_dtype = op.GetInputDesc(input_name1).GetDataType(); + Format input_format = op.GetInputDesc(input_name1).GetFormat(); + + ge::Shape shape_x = op.GetInputDesc(input_name1).GetShape(); + ge::Shape shape_y = op.GetInputDesc(input_name2).GetShape(); + std::vector dims_x = shape_x.GetDims(); // (N, x1_C1, H, W, C0) + + int64_t c1_len = 0; + + op.GetAttr("c1_len", c1_len); // (N, c1_len, H, W, C0) + dims_x[1] = c1_len; + ge::Shape output_shape = ge::Shape(dims_x); + + output_desc.SetShape(output_shape); + output_desc.SetDataType(input_dtype); + output_desc.SetFormat(input_format); + op.UpdateOutputDesc(outputName, output_desc); + + return true; +} + +IMPLEMT_VERIFIER(StrideAdd, StrideAddVerify) { return GRAPH_SUCCESS; } + +// Obtains the processing function of the output tensor description +IMPLEMT_COMMON_INFERFUNC(StrideAddInferShape) { + if (InferShapeAndTypeStrideAdd(op, "x1", "x2", "y")) { + return GRAPH_SUCCESS; + } + OP_LOGE(op.GetName().c_str(), "IMPLEMT_COMMON_INFERFUNC FAILED."); + return GRAPH_FAILED; +} + +// Registered inferfunction +COMMON_INFER_FUNC_REG(StrideAdd, StrideAddInferShape); +// Registered verify function +VERIFY_FUNC_REG(StrideAdd, StrideAddVerify); +// ----------------StrideAdd END--------------------- + +// ----------------MaskedScale Begin------------------- +bool VerifyMaskedScaleShapeAndType(Operator &op, DataType x_dtype, DataType mask_dtype) +{ + if ((x_dtype != DT_FLOAT) 
&& (x_dtype != DT_FLOAT16)) { + OP_LOGE(op.GetName().c_str(), "The input dtype of x is invalid, please check!"); + return false; + } + + if ((mask_dtype != DT_INT8) && (mask_dtype != DT_FLOAT) && (mask_dtype != DT_FLOAT16)) { + OP_LOGE(op.GetName().c_str(), "The input dtype of mask is invalid, please check!"); + return false; + } + + return true; +} + +IMPLEMT_VERIFIER(MaskedScale, MaskedScaleVerify) { + TensorDesc x_tensordesc = op.GetInputDesc("x"); + DataType x_dtype = x_tensordesc.GetDataType(); + TensorDesc mask_tensordesc = op.GetInputDesc("mask"); + DataType mask_dtype = mask_tensordesc.GetDataType(); + + if (false == VerifyMaskedScaleShapeAndType(op, x_dtype, mask_dtype)) { + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(MaskedScaleInferShape) { + OP_LOGI("MaskedScale", "infer shape begin"); + TensorDesc tensordesc_input = op.GetInputDesc("x"); + ge::Shape input_shape = tensordesc_input.GetShape(); + DataType input_dtype = tensordesc_input.GetDataType(); + + TensorDesc mask_tensordesc = op.GetInputDesc("mask"); + ge::Shape mask_shape = mask_tensordesc.GetShape(); + DataType mask_dtype = mask_tensordesc.GetDataType(); + + if (input_shape.GetShapeSize() != mask_shape.GetShapeSize()) { + OP_LOGE(op.GetName().c_str(), "shapesize of x not match mask"); + return GRAPH_FAILED; + } + + TensorDesc tensordesc_output = op.GetOutputDesc("y"); + tensordesc_output.SetShape(input_shape); + tensordesc_output.SetDataType(input_dtype); + (void)op.UpdateOutputDesc("y", tensordesc_output); + + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(MaskedScale, MaskedScaleInferShape); +VERIFY_FUNC_REG(MaskedScale, MaskedScaleVerify); +// ----------------MaskedScale END----------- + +// ----------------AbsGrad------------------- +IMPLEMT_VERIFIER(AbsGrad, AbsGradVerify) { + if (!CheckTwoInputDtypeSame(op, "y", "dy")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +VERIFY_FUNC_REG(AbsGrad, AbsGradVerify); + 
+IMPLEMT_COMMON_INFERFUNC(AbsGradInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +COMMON_INFER_FUNC_REG(AbsGrad, AbsGradInferShape); +// --------------AbsGrad END---------------- + +// ----------------Acosh-------------------- +IMPLEMT_COMMON_INFERFUNC(AcoshInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Acosh, AcoshInferShape); +// --------------Acosh END----------------- + +// ------------Adds------------------------ +IMPLEMT_VERIFIER(Adds, AddsVerify) { + return GRAPH_SUCCESS; +} +VERIFY_FUNC_REG(Adds, AddsVerify); + +IMPLEMT_COMMON_INFERFUNC(AddsInferShape) { + OP_LOGI(op.GetName().c_str(), "Enter AddsInferShape"); + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Adds, AddsInferShape); +// ------------Adds Op End----------------- + +// ----------------Asin-------------------- +IMPLEMT_COMMON_INFERFUNC(AsinInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Asin, AsinInferShape); +// --------------Asin END----------------- + +// ----------------AsinGrad--------------- +IMPLEMT_VERIFIER(AsinGrad, AsinGradVerify) { + if (!CheckTwoInputDtypeSame(op, "y", "dy")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +VERIFY_FUNC_REG(AsinGrad, AsinGradVerify); + +IMPLEMT_COMMON_INFERFUNC(AsinGradInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +COMMON_INFER_FUNC_REG(AsinGrad, AsinGradInferShape); +// --------------AsinGrad END------------- + +// ----------------Ceil------------------- +IMPLEMT_COMMON_INFERFUNC(CeilInferShape) { + 
OP_LOGI(op.GetName().c_str(), "Enter CeilInferShape"); + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(Ceil, CeilInferShape); +// --------------Ceil END---------------- + +// ----------------Cos------------------- +IMPLEMT_COMMON_INFERFUNC(CosInferShape) { + OP_LOGI(op.GetName().c_str(), "Enter CosInferShape"); + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Cos, CosInferShape); +// --------------Cos END------------------ + +// ----------------Cosh------------------- +IMPLEMT_COMMON_INFERFUNC(CoshInferShape) { + OP_LOGI(op.GetName().c_str(), "Enter CoshInferShape"); + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Cosh, CoshInferShape); +// ---------------Cosh END---------------- + +// ----------------Sin-------------------- +IMPLEMT_COMMON_INFERFUNC(SinInferShape) { + OP_LOGI(op.GetName().c_str(), "Enter SinInferShape"); + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Sin, SinInferShape); +// ---------------Sin END----------------- + +// ----------------Sinh------------------- +IMPLEMT_COMMON_INFERFUNC(SinhInferShape) { + OP_LOGI(op.GetName().c_str(), "Enter SinhInferShape"); + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Sinh, SinhInferShape); +// ---------------Sinh END---------------- + +// ---------------Tan--------------------- +IMPLEMT_COMMON_INFERFUNC(TanInferShape) { + OP_LOGI(op.GetName().c_str(), "Enter TanInferShape"); + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Tan, TanInferShape); +// --------------------Tan Op End---------- + +// ----------------Lerp Begin------------------- +bool 
InferShapeAndTypeLerp(Operator& op, + const string& input_name1, const string& input_name2, + const string& input_name3, const string& output_name) { + TensorDesc v_output_desc = op.GetOutputDesc(output_name); + + DataType input_dtype = op.GetInputDesc(input_name1).GetDataType(); + Format input_format = op.GetInputDesc(input_name1).GetFormat(); + + ge::Shape shape_x = op.GetInputDesc(input_name1).GetShape(); + ge::Shape shape_y = op.GetInputDesc(input_name2).GetShape(); + ge::Shape shape_z = op.GetInputDesc(input_name3).GetShape(); + std::vector dims_x = shape_x.GetDims(); + std::vector dims_y = shape_y.GetDims(); + std::vector dims_z = shape_z.GetDims(); + if (dims_x.size() < dims_y.size()) { + std::vector dims_tmp = dims_x; + dims_x = dims_y; + dims_y = dims_tmp; + } + if (dims_x.size() < dims_z.size()) { + std::vector dims_tmp = dims_x; + dims_x = dims_z; + dims_z = dims_tmp; + } + + if (dims_x.size() != dims_y.size()) { + int dec = dims_x.size() - dims_y.size(); + for (int i = 0; i < dec; i++) { + dims_y.insert(dims_y.begin(), (int64_t)1); + } + } + if (dims_x.size() != dims_z.size()) { + int dec = dims_x.size() - dims_z.size(); + for (int i = 0; i < dec; i++) { + dims_z.insert(dims_z.begin(), (int64_t)1); + } + } + + std::vector dim_vec; + for (size_t i = 0; i < dims_x.size(); i++) { + if ((dims_x[i] != dims_y[i]) && (dims_x[i] != 1) && (dims_y[i] != 1)) { + OP_LOGE(op.GetName().c_str(), "Input shapes are not compatible."); + return false; + } + if ((dims_x[i] != dims_z[i]) && (dims_x[i] != 1) && (dims_z[i] != 1)) { + OP_LOGE(op.GetName().c_str(), "Input shapes are not compatible."); + return false; + } + int64_t dims_tmp = dims_x[i] > dims_y[i] ? dims_x[i] : dims_y[i]; + int64_t dims = dims_tmp > dims_z[i] ? 
dims_tmp : dims_z[i]; + dim_vec.push_back(dims); + } + ge::Shape output_shape = ge::Shape(dim_vec); + + v_output_desc.SetShape(output_shape); + v_output_desc.SetDataType(input_dtype); + v_output_desc.SetFormat(input_format); + op.UpdateOutputDesc(output_name, v_output_desc); + + return true; +} + +IMPLEMT_VERIFIER(Lerp, LerpVerify) { + DataType start_type = op.GetInputDesc("start").GetDataType(); + DataType end_type = op.GetInputDesc("end").GetDataType(); + DataType weight_type = op.GetInputDesc("weight").GetDataType(); + if (start_type != end_type) { + OP_LOGE(op.GetName().c_str(), "Input dtypes are not the same."); + return GRAPH_FAILED; + } + if (start_type != weight_type) { + OP_LOGE(op.GetName().c_str(), "Input dtypes are not the same."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(LerpInferShape) { + if (InferShapeAndTypeLerp(op, "start", "end", "weight", "y")) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(Lerp, LerpInferShape); +VERIFY_FUNC_REG(Lerp, LerpVerify); +// ----------------Lerp END--------------------- +// ----------------Asinh------------------- +IMPLEMT_COMMON_INFERFUNC(AsinhInferShape) { + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Asinh, AsinhInferShape); +// --------------Asinh END----------------- + +// ------------------Mod-------------------- +IMPLEMT_VERIFIER(Mod, ModVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +VERIFY_FUNC_REG(Mod, ModVerify); + +IMPLEMT_COMMON_INFERFUNC(ModInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Mod, ModInferShape); +// ----------------Mod END--------------- + +// --------------Xdivy------------------- +IMPLEMT_VERIFIER(Xdivy, 
XdivyVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(XdivyInferShape) { + bool is_dynamic_output = true; + if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +COMMON_INFER_FUNC_REG(Xdivy, XdivyInferShape); +VERIFY_FUNC_REG(Xdivy, XdivyVerify); +// ------------Xdivy END----------------- + +// ------------Xlogy--------------------- +IMPLEMT_VERIFIER(Xlogy, XlogyVerify) { + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(XlogyInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Xlogy, XlogyInferShape); +VERIFY_FUNC_REG(Xlogy, XlogyVerify); +// ------------Xlogy END------------------ + +// ----------------AsinhGrad------------------- +IMPLEMT_VERIFIER(AsinhGrad, AsinhGradVerify) { + if (!CheckTwoInputDtypeSame(op, "y", "dy")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +VERIFY_FUNC_REG(AsinhGrad, AsinhGradVerify); +IMPLEMT_COMMON_INFERFUNC(AsinhGradInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "y", "dy", "z", is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(AsinhGrad, AsinhGradInferShape); +// --------------AsinhGrad END----------------- + +// -----------------TruncateDiv-------------------- +IMPLEMT_COMMON_INFERFUNC(TruncateDivInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(TruncateDiv, TruncateDivInferShape); +// -----------------TruncateDiv END---------------- + +// 
----------------TruncateMod--------------------- +IMPLEMT_COMMON_INFERFUNC(TruncateModInferShape) { + bool is_dynamic_output = true; + if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y", is_dynamic_output)) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +COMMON_INFER_FUNC_REG(TruncateMod, TruncateModInferShape); +// ----------------TruncateMod END----------------- + +// ----------------Floor--------------------- +IMPLEMT_COMMON_INFERFUNC(FloorInferShape) { + OP_LOGI(op.GetName().c_str(), "Enter FloorInferShape"); + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Floor, FloorInferShape); +// ----------------Floor END----------------- + +// ----------------Expm1--------------------- +IMPLEMT_COMMON_INFERFUNC(Expm1InferShape) { + OP_LOGI(op.GetName().c_str(), "Enter Expm1InferShape"); + if (OneInOneOutDynamicInfer(op, "x", {"y"})) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} +COMMON_INFER_FUNC_REG(Expm1, Expm1InferShape); +// ----------------Expm1 END----------------- + + +// -------------------DataCompare---------------------- +IMPLEMT_VERIFIER(DataCompare, DataCompareVerify) { + float atol_data; + if (ge::GRAPH_SUCCESS != op.GetAttr("atol", atol_data)) { + OpsGetAttrErrReport(op.GetName(), "atol"); + OP_LOGE(op.GetName().c_str(), "GetOpAttr failed of DataCompare!"); + return GRAPH_FAILED; + } + if (atol_data < 0) { + OpsAttrValueErrReport(op.GetName(), "atol", ">= 0", ConcatString(atol_data)); + OP_LOGE(op.GetName().c_str(), "atol should >= 0!"); + return GRAPH_FAILED; + } + + float rtol_data; + if (ge::GRAPH_SUCCESS != op.GetAttr("rtol", rtol_data)) { + OpsGetAttrErrReport(op.GetName(), "rtol"); + OP_LOGE(op.GetName().c_str(), "GetOpAttr failed of DataCompare!"); + return GRAPH_FAILED; + } + if (rtol_data < 0) { + OpsAttrValueErrReport(op.GetName(), "rtol", ">= 0", ConcatString(rtol_data)); + OP_LOGE(op.GetName().c_str(), "rtol should >= 0!"); + return 
GRAPH_FAILED; + } + + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(DataCompareInferShape) { + TensorDesc tensordesc_num = op.GetOutputDesc("num"); + TensorDesc tensordesc_diff = op.GetOutputDesc("diff"); + + std::vector oShapeVector; + Shape oShape(oShapeVector); + + tensordesc_num.SetShape(ge::Shape(oShape)); + tensordesc_num.SetDataType(DT_INT32); + (void)op.UpdateOutputDesc("num", tensordesc_num); + + tensordesc_diff.SetShape(ge::Shape(oShape)); + tensordesc_diff.SetDataType(DT_FLOAT16); + (void)op.UpdateOutputDesc("diff", tensordesc_diff); + + + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(DataCompare, DataCompareInferShape); +VERIFY_FUNC_REG(DataCompare, DataCompareVerify); +// -------------------DataCompare------------------------- + +// ---------------HardMax Begin----------------- +IMPLEMT_COMMON_INFERFUNC(HardMaxInferShape) +{ + ge::TensorDesc input_desc = op.GetInputDesc(0); + ge::TensorDesc output_desc = op.GetOutputDesc(0); + output_desc.SetShape(input_desc.GetShape()); + output_desc.SetFormat(input_desc.GetFormat()); + output_desc.SetDataType(input_desc.GetDataType()); + op.UpdateOutputDesc("y", output_desc); + return GRAPH_SUCCESS; +} + +IMPLEMT_VERIFIER(HardMax, HardMaxVerify) +{ + int dimension = -1; + auto ret = op.GetAttr("axis", dimension); + if (ret != ge::GRAPH_SUCCESS) { + OP_LOGE("HardMaxVerify", "OP GetAttr axis fail."); + return GRAPH_FAILED; + } + ge::TensorDesc input_desc = op.GetInputDesc(0); + ge::DataType data_type = input_desc.GetDataType(); + if (data_type != DT_FLOAT16 && data_type != DT_FLOAT) { + OP_LOGE("HardMaxVerify", "Input DataType is not fp16 or fp32"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} +INFER_FUNC_REG(HardMax, HardMaxInferShape); +VERIFY_FUNC_REG(HardMax, HardMaxVerify); +// ---------------HardMax END------------------- + +// ---------------Dot Begin----------------- +bool InferShapeDot(Operator& op, + const string& 
input1_name, + const string input2_name, + const string output_name) { + TensorDesc output_desc = op.GetOutputDesc(output_name); + TensorDesc input1_desc = op.GetInputDesc(input1_name); + TensorDesc input2_desc = op.GetInputDesc(input2_name); + + //input dim + ge::Shape shape_input1 = input1_desc.GetShape(); + ge::Shape shape_input2 = input2_desc.GetShape(); + + std::vector dims_input1 = shape_input1.GetDims(); + std::vector dims_input2 = shape_input2.GetDims(); + + if(dims_input1.size() != dims_input2.size()) { + OP_LOGE("The dim of input_x and input_y not match."); + return false; + } + + if(dims_input1.size() != 1) { + OP_LOGE("The dim of input must be 1"); + return false; + } + + if(dims_input1[0] != dims_input2[0]) { + OP_LOGE("The 0-dim of input_x and input_y not match."); + return false; + } + + std::vector dim_output; + dim_output.push_back(1); + + ge::Shape output_shape = ge::Shape(dim_output); + + output_desc.SetShape(output_shape); + output_desc.SetDataType(input1_desc.GetDataType()); + output_desc.SetFormat(Format::FORMAT_ND); + op.UpdateOutputDesc(output_name, output_desc); + return true; +} + + +IMPLEMT_COMMON_INFERFUNC(DotInferShape) { + if(InferShapeDot(op, "input_x", "input_y", "output")) { + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + + +IMPLEMT_VERIFIER(Dot, DotVerify) { + if (op.GetInputDesc("input_x").GetDataType() != op.GetInputDesc("input_y").GetDataType()) { + OP_LOGE("The dataType of input_x and input_y not match."); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +COMMON_INFER_FUNC_REG(Dot, DotInferShape); +VERIFY_FUNC_REG(Dot, DotVerify); +// ---------------Dot END------------------- + +// ---------------IsClose Begin----------------- +IMPLEMT_VERIFIER(IsClose, IsCloseVerify) +{ + if (!CheckTwoInputDtypeSame(op, "x1", "x2")) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +IMPLEMT_COMMON_INFERFUNC(IsCloseInferShape) +{ + Format input_format = op.GetInputDesc("x1").GetFormat(); + Shape x1_shape = 
op.GetInputDesc("x1").GetShape(); + TensorDesc td = op.GetOutputDesc("y"); + td.SetShape(ge::Shape(x1_shape)); + td.SetDataType(DT_BOOL); + td.SetFormat(input_format); + (void)op.UpdateOutputDesc("y", td); + return GRAPH_SUCCESS; +} +COMMON_INFER_FUNC_REG(IsClose, IsCloseInferShape); +VERIFY_FUNC_REG(IsClose, IsCloseVerify); +// ---------------IsClose END----------------- + +// ----------------ArgMaxGrad-------------------- +IMPLEMT_COMMON_INFERFUNC(ArgMaxGradInferShape) { + Shape shape = op.GetInputDesc("var").GetShape(); + DataType input_dtype = op.GetInputDesc("var").GetDataType(); + Format input_format = op.GetInputDesc("var").GetFormat(); + TensorDesc td = op.GetOutputDesc("y"); + + td.SetShape(shape); + td.SetDataType(input_dtype); + td.SetFormat(input_format); + (void)op.UpdateOutputDesc("y", td); + return GRAPH_SUCCESS; +} + +bool IsArgMaxGradCheckPass(Operator& op, + const string& var_name, + const string& indices_name, + const string& updates_name, + const string& dimmension_name) { + TensorDesc input_var_desc = op.GetInputDesc(var_name); + TensorDesc input_indices_desc = op.GetInputDesc(indices_name); + TensorDesc input_updates_desc = op.GetInputDesc(updates_name); + + ge::Shape shape_indices = input_indices_desc.GetShape(); + ge::Shape shape_updates = input_updates_desc.GetShape(); + ge::Shape shape_var = input_var_desc.GetShape(); + + std::vector shape_indices_list = shape_indices.GetDims(); + std::vector shape_updates_list = shape_updates.GetDims(); + std::vector shape_var_list = shape_var.GetDims(); + + auto dim = 0; + if (op.GetAttr(dimmension_name, dim) == GRAPH_FAILED) { + OP_LOGE(op.GetName().c_str(), "get attr dimension failed"); + return false; + } + + int32_t max_shape_len = shape_var.GetDimNum(); + int32_t dims = dim; + if (dims < 0) { + if (dims < (0 - max_shape_len)) { + OP_LOGE(op.GetName().c_str(), "attr dimension invalid.should bigger than -max_shape_len"); + return false; + } + dims = dims + max_shape_len; + } else if (dims >= 
max_shape_len) { + OP_LOGE(op.GetName().c_str(), "attr dimension invalid. should less than max_shape_len"); + return false; + } + + if ((shape_var_list.size() > 1) && + (shape_var_list.size() != shape_updates_list.size() + 1)) { + OP_LOGE("The dim size of var should biger than updates(indices) 1."); + return false; + } + + if ((1 == shape_var_list.size()) && (1 != shape_updates_list.size())) { + OP_LOGE("The dim size of var should equal updates(indices) when size=1."); + return false; + } + + if (shape_indices_list.size() != shape_updates_list.size()) { + OP_LOGE("The dim size of indices and updates not match."); + return false; + } + + for (size_t i = 0; i < shape_indices.GetDimNum(); i++) { + if (shape_indices.GetDim(i) != shape_updates.GetDim(i)) { + OP_LOGE("The dim value of indices and updates not match."); + return false; + } + } + + if (shape_var_list.size() > 1) { + for (size_t i = 0; i < shape_indices.GetDimNum(); i++) { + if (((i < dims) && (shape_indices.GetDim(i) != shape_var.GetDim(i))) || + ((i >= dims) && (shape_indices.GetDim(i) != shape_var.GetDim(i + 1)))) { + OP_LOGE("The dim value of var and updates not match."); + return false; + } + } + } + + DataType var_dtype = input_var_desc.GetDataType(); + DataType updates_dtype = input_updates_desc.GetDataType(); + if (var_dtype != updates_dtype) { + OP_LOGE("The dtype of var and updates not match."); + return false; + } + + return true; +} + +IMPLEMT_VERIFIER(ArgMaxGrad, ArgMaxGradVerify) { + if (true != IsArgMaxGradCheckPass(op, "var", "indices", "updates", "dimension")) { + OP_LOGE(op.GetName().c_str(),"the ArgMaxGrad op inputs check fail!\n"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(ArgMaxGrad, ArgMaxGradInferShape); +VERIFY_FUNC_REG(ArgMaxGrad, ArgMaxGradVerify); +// ------------------ArgMaxGrad END--------------------- + +// ----------------ArgMaxGradD-------------------- +IMPLEMT_COMMON_INFERFUNC(ArgMaxGradDInferShape) { + Shape shape = 
op.GetInputDesc("var").GetShape(); + DataType input_dtype = op.GetInputDesc("var").GetDataType(); + Format input_format = op.GetInputDesc("var").GetFormat(); + TensorDesc td = op.GetOutputDesc("y"); + + td.SetShape(shape); + td.SetDataType(input_dtype); + td.SetFormat(input_format); + (void)op.UpdateOutputDesc("y", td); + return GRAPH_SUCCESS; +} + +bool IsArgMaxGradDCheckPass(Operator& op, + const string& var_name, + const string& indices_name, + const string& updates_name, + const string& dimmension_name, + const string& assist_name) { + TensorDesc input_var_desc = op.GetInputDesc(var_name); + TensorDesc input_indices_desc = op.GetInputDesc(indices_name); + TensorDesc input_updates_desc = op.GetInputDesc(updates_name); + TensorDesc input_assist_desc = op.GetInputDesc(assist_name); + + ge::Shape shape_indices = input_indices_desc.GetShape(); + ge::Shape shape_updates = input_updates_desc.GetShape(); + ge::Shape shape_var = input_var_desc.GetShape(); + ge::Shape shape_assist = input_assist_desc.GetShape(); + + std::vector shape_indices_list = shape_indices.GetDims(); + std::vector shape_updates_list = shape_updates.GetDims(); + std::vector shape_var_list = shape_var.GetDims(); + std::vector shape_assist_list = shape_assist.GetDims(); + + if (shape_var_list.size() != shape_assist_list.size()) { + OP_LOGE(op.GetName().c_str(), "shape of var and assist mot match."); + return false; + } + + auto dim = 0; + if (op.GetAttr(dimmension_name, dim) == GRAPH_FAILED) { + OP_LOGE(op.GetName().c_str(), "get attr dimension failed"); + return false; + } + + int32_t max_shape_len = shape_var.GetDimNum(); + int32_t dims = dim; + if (dims < 0) { + if (dims < (0 - max_shape_len)) { + OP_LOGE(op.GetName().c_str(), "attr dimension invalid.should bigger than -max_shape_len"); + return false; + } + dims = dims + max_shape_len; + } else if (dims >= max_shape_len) { + OP_LOGE(op.GetName().c_str(), "attr dimension invalid. 
should less than max_shape_len"); + return false; + } + + if ((shape_var_list.size() > 1) && + (shape_var_list.size() != shape_updates_list.size() + 1)) { + OP_LOGE("The dim size of var should biger than updates(indices) 1."); + return false; + } + + if ((1 == shape_var_list.size()) && (1 != shape_updates_list.size())) { + OP_LOGE("The dim size of var should equal updates(indices) when size=1."); + return false; + } + + if (shape_indices_list.size() != shape_updates_list.size()) { + OP_LOGE("The dim size of indices and updates not match."); + return false; + } + + for (size_t i = 0; i < shape_indices.GetDimNum(); i++) { + if (shape_indices.GetDim(i) != shape_updates.GetDim(i)) { + OP_LOGE("The dim value of indices and updates not match."); + return false; + } + } + + for (size_t i = 0; i < shape_var.GetDimNum(); i++) { + if (shape_var.GetDim(i) != shape_assist.GetDim(i)) { + OP_LOGE("The dim value of var and assist not match."); + return false; + } + } + + if (shape_var_list.size() > 1) { + for (size_t i = 0; i < shape_indices.GetDimNum(); i++) { + if (((i < dims) && (shape_indices.GetDim(i) != shape_var.GetDim(i))) || + ((i >= dims) && (shape_indices.GetDim(i) != shape_var.GetDim(i + 1)))) { + OP_LOGE("The dim value of var and updates not match."); + return false; + } + } + } + + DataType var_dtype = input_var_desc.GetDataType(); + DataType updates_dtype = input_updates_desc.GetDataType(); + if (var_dtype != updates_dtype) { + OP_LOGE("The dtype of var and updates not match."); + return false; + } + + return true; +} + +IMPLEMT_VERIFIER(ArgMaxGradD, ArgMaxGradDVerify) { + if (true != IsArgMaxGradDCheckPass(op, "var", "indices", "updates", "dimension", "assist")) { + OP_LOGE(op.GetName().c_str(),"the ArgMaxGradD op inputs check fail!\n"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +INFER_FUNC_REG(ArgMaxGradD, ArgMaxGradDInferShape); +VERIFY_FUNC_REG(ArgMaxGradD, ArgMaxGradDVerify); +// ------------------ArgMaxGradD END--------------------- + +} // 
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file elewise_calculation_ops.h
 * \brief Prototype registrations (REG_OP) for element-wise calculation
 *        operators used by the ST framework stubs.
 */
#ifndef OPS_BUILT_IN_OP_PROTO_INC_ELEWISE_CALCULATION_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_ELEWISE_CALCULATION_OPS_H_
#include "graph/operator_reg.h"

namespace ge {
/**
*@brief Adds all input tensors element-wise. \n

*@par Inputs:
*Dynamic inputs, including:
* @li x: A list of Tensor objects, each with same shape and type. The supported types are:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. It's a dynamic input. \n

*@par Attributes:
*N: A required attribute of type int32, means nums of inputs. \n

*@par Outputs:
*y: A Tensor. Has the same shape and type as the elements of "x". \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow operator AddN.
*/
REG_OP(AddN)
    .DYNAMIC_INPUT(x, TensorType::NumberType())
    .OUTPUT(y, TensorType::NumberType())
    .REQUIRED_ATTR(N, Int)
    .OP_END_FACTORY_REG(AddN)

/**
*@brief Calculates the reversed outputs of the function "maximum"

*@par Inputs:
*Three inputs, including:
* @li grads: A mutable Tensor. Must be one of the following types:
* float16, float32, int32.
* @li x1: A mutable Tensor of the same type as "grads".
* @li x2: A mutable Tensor of the same type as "grads". \n

*@par Attributes:
*@li grad_x: An optional bool. Defaults to "True".
* If "True", "y1" will be output.
* If "False", "y1" will not be output. \n

*@li grad_y: An optional bool. Defaults to "True".
* If "True", "y2" will be output.
* If "False", "y2" will not be output. \n

*@par Outputs:
* @li y1: A mutable Tensor. Has the same type as "grads".
* @li y2: A mutable Tensor. Has the same type as "grads". \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MaximumGrad.
*/
REG_OP(MaximumGrad)
    .INPUT(grads, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(grad_x, Bool, true)
    .ATTR(grad_y, Bool, true)
    .OP_END_FACTORY_REG(MaximumGrad)

/**
*@brief Calculates the reversed outputs of the function "minimum"

*@par Inputs:
*Three inputs, including:
* @li grads: A mutable Tensor. Must be one of the following types:
* float16, float32, int32.
* @li x1: A mutable Tensor of the same type as "grads".
* @li x2: A mutable Tensor of the same type as "grads". \n

*@par Attributes:
*@li grad_x: An optional bool. Defaults to "True".
* If "True", "y1" will be output.
* If "False", "y1" will not be output. \n

*@li grad_y: An optional bool. Defaults to "True".
* If "True", "y2" will be output.
* If "False", "y2" will not be output. \n

*@par Outputs:
* @li y1: A mutable Tensor. Has the same type as "grads".
* @li y2: A mutable Tensor. Has the same type as "grads". \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MinimumGrad.
*/
REG_OP(MinimumGrad)
    .INPUT(grads, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(grad_x, Bool, true)
    .ATTR(grad_y, Bool, true)
    .OP_END_FACTORY_REG(MinimumGrad)

/**
*@brief Cast a tensor from src data type to dst data type. \n

*@par Inputs:
*One input:
*x:A Tensor. Must be one of the following types: bool, float16, float, int8, int32, uint32, uint8,
  int64, uint64, int16, uint16, double, complex64, complex128, qint8, quint8, qint16, quint16, qint32.
  For float32 type, the actual calculation on the chip is based on float16. \n

*@par Attributes:
*dst_type: A required attribute of type int32, specifying the dst data type. \n

*@par Outputs:
*y:A Tensor. Has the same type as x.
*/
REG_OP(Cast)
    .INPUT(x, TensorType({DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT32, DT_UINT8,
                          DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64,
                          DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32}))
    .OUTPUT(y, TensorType({DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT32, DT_UINT8,
                           DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64,
                           DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32}))
    .REQUIRED_ATTR(dst_type, Int)
    .OP_END_FACTORY_REG(Cast)
Must be one of the following types: float16, float32, +* double, int32, int8, uint8, int64, uint16, uint32, uint64. +* @li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x1". \n + +*@par Third-party framework compatibility: +* Compatible with the TensorFlow operator GreaterEqual. +*/ +REG_OP(GreaterEqual) + .INPUT(x1, TensorType::RealNumberType()) + .INPUT(x2, TensorType::RealNumberType()) + .OUTPUT(y, TensorType({DT_BOOL})) + .OP_END_FACTORY_REG(GreaterEqual) + +/** +*@brief Returns the truth value of (x1 < x2) element-wise. \n + +*@par Inputs: +*Two inputs, including: +* @li x1: A Tensor. Must be one of the following types: float16, float32, double, int32, +* uint8, int16, int8, int64, uint16, uint32, uint64. +* @li x2: A Tensor with the same type as "x1". \n + +*@par Outputs: +*y: A Tensor of type bool. \n + +*@par Third-party framework compatibility: +* Compatible with TensorFlow operator Less. +*/ +REG_OP(Less) + .INPUT(x1, TensorType::RealNumberType()) + .INPUT(x2, TensorType::RealNumberType()) + .OUTPUT(y, TensorType({DT_BOOL})) + .OP_END_FACTORY_REG(Less) + +/** +*@brief Returns x1/x2 element-wise for real types. \n + +*@par Inputs: +* Two inputs, including: +*@li x1: A Tensor. Must be one of the following types: float16, float32, double, uint16, + int8, uint8, int16, int32, int64, complex64, DT_COMPLEX128. +*@li x2: A Tensor. Must be one of the following types: float16, float32, double, uint16, + int8, uint8, int16, int32, int64, complex64, DT_COMPLEX128. \n + +*@par Outputs: +* y: A Tensor. Has the same type and format as input "x1". \n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator RealDiv. 
+*/ +REG_OP(RealDiv) + .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8, + DT_UINT16, DT_INT16, DT_INT32, DT_INT64, + DT_COMPLEX64, DT_COMPLEX128})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_INT8, + DT_UINT16, DT_INT16, DT_INT32, DT_INT64, + DT_COMPLEX64, DT_COMPLEX128})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_INT8, + DT_UINT16, DT_INT16, DT_INT32, DT_INT64, + DT_COMPLEX64, DT_COMPLEX128})) + .OP_END_FACTORY_REG(RealDiv) + +/** +*@brief Computes square root of x element-wise. \n + +*@par Inputs: +* x: A Tensor. Must be one of the following types: float16, float32, complex128, complex64, float64. \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x". +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Sqrt. +*/ +REG_OP(Sqrt) + .INPUT(x, TensorType{(DT_FLOAT. DT_FLOAT16, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128)}) + .OUTPUT(y, TensorType{(DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128)}) + .OP_END_FACTORY_REG(Sqrt) + +/** +*@brief Returns the max of "x" and "y" (i.e. x > y ? x: y) element-wise. \n + +*@par Inputs: +*Two inputs, including: +* @li x1: A Tensor. Must be one of the following types: float16, float32, double, int32, int64. +* @li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x1". \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Maximum. +*/ +REG_OP(Maximum) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, + DT_INT64})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, + DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, + DT_INT64})) + .OP_END_FACTORY_REG(Maximum) + +/** +*@brief Returns the min of x and y (i.e. x1 < x2 ? x1 : x2) element-wise. \n + +*@par Inputs: +*Two inputs, include: +* @li x1: A Tensor. 
Must be one of the following types: float32, float16, double, int32, int64. +* @li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor of the same type as "x1". \n + +*@par Third-party framework compatibility: +* Compatible with the TensorFlow operator Minimum. +*/ +REG_OP(Minimum) + .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_INT32, + DT_INT64})) + .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_INT32, + DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_INT32, + DT_INT64})) + .OP_END_FACTORY_REG(Minimum) + +/** +*@brief: Computes the reciprocal of "x". \n + +*@par Inputs: +*One inputs, include: +*x:A Tensor of type float16, float32, int32, int64, double, +* complex64, complex128.the format can be [NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND] + +*@par Outputs: +*y:A Tensor with same type as "x". \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Reciprocal. +*/ +REG_OP(Reciprocal) + .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_FLOAT16, + DT_COMPLEX64, DT_COMPLEX128})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_FLOAT16 + DT_COMPLEX64, DT_COMPLEX128})) + .OP_END_FACTORY_REG(Reciprocal) + +/** +*@brief Returns x - y element-wise. +*@par Inputs: +*Two inputs, including: +* @li x1: A Tensor. Must be one of the following types: int8, int16, int32, int64, uint8, float64, +* float16, float32, complex128, complex64, uint16. +* @li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x". +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Subtract. 
/**
*@brief Returns x - y element-wise.
*@par Inputs:
*Two inputs, including:
* @li x1: A Tensor. Must be one of the following types: int8, int16, int32, int64, uint8, float64,
* float16, float32, complex128, complex64, uint16.
* @li x2: A Tensor of the same type as "x1". \n

*@par Outputs:
*y: A Tensor. Has the same type as "x".
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Subtract.
*/
REG_OP(Sub)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8,
                           DT_UINT16, DT_INT16, DT_INT32, DT_INT64,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8,
                           DT_UINT16, DT_INT16, DT_INT32, DT_INT64,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_UINT8, DT_INT8,
                           DT_UINT16, DT_INT16, DT_INT32, DT_INT64,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Sub)

/**
*@brief Computes the absolute value of a tensor. \n

*@par Inputs:
*One input, including:
* @li x: A Tensor. Must be one of the following types: float16, float32, double, int32, int64. \n

*@par Outputs:
*y: A Tensor. Has the same type as "x". \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Abs.
*/
REG_OP(Abs)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
    .OP_END_FACTORY_REG(Abs)

/**
*@brief Computes gradients for absolute operation. \n

*
*@par Inputs:
*@li y: A tensor of type float16 or float32.
*@li dy: A tensor of the same type as "y".
*
*@attention Constraints:
* "dy" has the same type as "y".
*
*@par Outputs:
* z: A tensor. Has the same type as "y".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator AbsGrad.
*
*/
REG_OP(AbsGrad)
    .INPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
    .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OUTPUT(z, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OP_END_FACTORY_REG(AbsGrad)

/**
*@brief: Computes the sign of "x". \n

*@par Inputs:
*x:An ND Tensor of type float16, float32, int32, int64, double,
* complex64, complex128. \n

*@par Outputs:
*y:An ND Tensor with same type as "x". \n

*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Sign.
*/
REG_OP(Sign)
    .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT, DT_DOUBLE, DT_INT32,
                          DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32,
                           DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Sign)

/**
*@brief Returns (x1 - x2)(x1 - x2) element-wise. \n

*@par Inputs:
*Two inputs, including: \n
*@li x1: A Tensor. Must be one of the following types: float16, float32, float64, int32, int64, complex64,complex128
*@li x2: A Tensor. Has the same type as "x1". \n

*@par Outputs:
*y: A Tensor. Has the same type as "x1". \n

*@par Third-party framework compatibility
* Compatible with TensorFlow operator SquaredDifference.
*/
REG_OP(SquaredDifference)
    .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32,
                           DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32,
                           DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32,
                           DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(SquaredDifference)

/**
*@brief Computes cosine of "x" element-wise. \n

*@par Inputs:
*x: A Tensor of type float16, float32, double, complex64, complex128.
* the format can be [NCHW,NC1HWC0,NHWC,ND]

*@par Outputs:
*y: A Tensor of the same type as "x". \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator Cos. \n

*/
REG_OP(Cos)
    .INPUT(x, TensorType::UnaryDataType())
    .OUTPUT(y, TensorType::UnaryDataType())
    .OP_END_FACTORY_REG(Cos)

/**
*@brief Returns x1/x2 element-wise. \n

*@par Inputs:
* Two inputs, including:
*@li x1: A Tensor. Must be one of the following types:
* float16, float32, int32, int8, uint8, float64, int64, uint16, int16,
* complex64, complex128, the format can be [NCHW,NC1HWC0,NHWC,ND].
*@li x2: A Tensor. Has the same type and format as input "x1". \n

*@par Outputs:
* y: A Tensor. Has the same type and format as input "x1". \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator Div.
*/
REG_OP(Div)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT32,
                           DT_DOUBLE, DT_INT64, DT_UINT16, DT_INT16,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT32,
                           DT_DOUBLE, DT_INT64, DT_UINT16, DT_INT16,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT32,
                           DT_DOUBLE, DT_INT64, DT_UINT16, DT_INT16,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Div)

/**
*@brief: Returns the truth value of (x = y) element-wise. \n

*@par Inputs:
* Two inputs, including:
*@li x1: A Tensor. Must be one of the following types:
* float16, float32, int32, int8, uint8, double, int16, int64, complex64,
* complex128, quint8, qint8, qint32, string, bool. the format can be
* [NCHW, NC1HWC0, NHWC, ND]
*@li x2: A Tensor of the same type and format as "x1". \n

*@par Outputs:
*y: A Tensor of type bool. \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator Equal.
*/
REG_OP(Equal)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_UINT8,
                           DT_DOUBLE, DT_INT16, DT_INT64, DT_COMPLEX64,
                           DT_COMPLEX128, DT_QUINT8, DT_QINT8, DT_QINT32,
                           DT_STRING, DT_BOOL}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_UINT8,
                           DT_DOUBLE, DT_INT16, DT_INT64, DT_COMPLEX64,
                           DT_COMPLEX128, DT_QUINT8, DT_QINT8, DT_QINT32,
                           DT_STRING, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(Equal)
/**
*@brief Computes the exponential of "x" element-wise. \n

*@par Inputs:
*One input:\n
*x: A Tensor. Must be one of the following types: float16, float32, double, complex64, complex128. \n

*@par Attributes:
*@li base: An optional attribute of type float32, specifying the base gamma. Defaults to "-1.0".
*@li scale: An optional attribute of type float32, specifying the scale alpha. Defaults to "1.0".
*@li shift: An optional attribute of type float32, specifying the shift beta. Defaults to "0.0". \n

*@par Outputs:
*y: A Tensor of the same type as "x". \n

*@par Third-party framework compatibility
* Compatible with TensorFlow operator Exp.
*/
REG_OP(Exp)
    .INPUT(x, TensorType::UnaryDataType())
    .OUTPUT(y, TensorType::UnaryDataType())
    .ATTR(base, Float, -1.0)
    .ATTR(scale, Float, 1.0)
    .ATTR(shift, Float, 0.0)
    .OP_END_FACTORY_REG(Exp)

/**
*@brief Computes the exp(x) - 1 element-wise, y = e^x - 1. \n

*@par Inputs:
*One input:
*x: A Tensor. Must be one of the following types: float16, float32, double, complex64, complex128. \n

*@par Outputs:
*y: A Tensor of the same type as "x". \n

*@par Third-party framework compatibility
* Compatible with TensorFlow operator Expm1.
*/
REG_OP(Expm1)
    .INPUT(x, TensorType::UnaryDataType())
    .OUTPUT(y, TensorType::UnaryDataType())
    .OP_END_FACTORY_REG(Expm1)

/**
*@brief: Computes the reciprocal of "x". \n

*@par Inputs:\n
*x: A Tensor. Must be one of the following types: float16, float32, int32, int64, double, complex64, complex128. \n

*@par Outputs:
*y: A Tensor. Has the same type as "x". \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator Inv.
*/
// NOTE(review): the output dtype list here is narrower than the input
// list (no int64/double/complex types) — looks intentional for this
// stub, but verify against the production op proto.
REG_OP(Inv)
    .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_COMPLEX64,DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32}))
    .OP_END_FACTORY_REG(Inv)

/**
*@brief: Computes "x" reciprocal grad, dx = -1*dy*y*y, where, "y = 1/x", and "dy"
  is the corresponding input gradient. \n

*@par Inputs:
* Two inputs, including:
* @li x: A Tensor. Must be one of the following types: float16, float32, int32, int8.
* @li grad: A Tensor. Has the same type as "x". \n

*@par Outputs:
*y: A Tensor, Has the same type as "x". \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator InvGrad.
*/
REG_OP(InvGrad)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .INPUT(grad, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .OP_END_FACTORY_REG(InvGrad)

/**
*@brief: Returns the truth value of (x <= y) element-wise. \n

*@par Inputs:
* Two inputs, including:
*@li x1: A Tensor. Must be one of the following types: float32, float64,
* int32, uint8, int16, int8, int64, qint8, quint8, qint32, uint16,
* float16, uint32, uint64.
*@li x2: A Tensor of the same type as "x1". \n

*@par Outputs:
*y: A Tensor of type bool. \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator LessEqual.
*/
REG_OP(LessEqual)
    .INPUT(x1, TensorType::RealNumberType())
    .INPUT(x2, TensorType::RealNumberType())
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(LessEqual)

/**
*@brief Computes the logarithm of (x + 1) element-wise, y = ln(x + 1). \n

*@par Inputs:
*One input:\n
*x: A Tensor. Must be one of the following types: float16, float32, double, complex64, complex128. \n

*@par Outputs:
*y: A Tensor of the same type as "x". \n

*@par Third-party framework compatibility
* Compatible with TensorFlow operator Log1p.
*/
REG_OP(Log1p)
    .INPUT(x, TensorType::UnaryDataType())
    .OUTPUT(y, TensorType::UnaryDataType())
    .OP_END_FACTORY_REG(Log1p)

/**
*@brief Returns element-wise remainder of division.
*@par Inputs:
*Two inputs, including:
* @li x1: A Tensor. Must be one of the following types: float16, float32,
 * int32, int64, int8, uint8, double.
* @li x2: A Tensor of the same type as "x1". \n

*@par Outputs:
*y: A Tensor. Has the same type as "x1".

*@attention Constraints:
*@li x2: The input data does not support 0
*@li When NUM exceeds 2048 , the accuracy of operator cannot guarantee the
*requirement of double thousandths in the mini form
*@li Due to different architectures, the calculation results of this operator
*on NPU and CPU may be inconsistent
*@li If shape is expressed as (D1,D2... ,Dn), then D1*D2... *DN<=1000000,n<=8

*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Mod.
*/
REG_OP(Mod)
    .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8,
                           DT_INT64, DT_DOUBLE}))
    .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8,
                           DT_INT64, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8,
                           DT_INT64, DT_DOUBLE}))
    .OP_END_FACTORY_REG(Mod)

/**
*@brief: Returns the truth value of (x != y) element-wise. \n

*@par Inputs:
* Two inputs, including:
*@li x1: A Tensor. Must be one of the following types: float16, float32, int32,
 * int8, uint8, double, int16, int64, uint16, half, uint32, uint64
*@li x2: A Tensor of the same type as "x1". \n

*@par Outputs:
*y: A Tensor of type bool. \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator NotEqual.
*/
REG_OP(NotEqual)
    .INPUT(x1, TensorType::RealNumberType())
    .INPUT(x2, TensorType::RealNumberType())
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(NotEqual)
/**
*@brief Computes numerical negative value element-wise (y = -x)

*@par Inputs:
* One input:
*x: A Tensor. Must be one of the following types: float16, float32, int32,
 * int64, complex64, complex128. \n

*@par Outputs:
*y: A Tensor. Has the same type and format as input "x". \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator Neg.
*/
REG_OP(Neg)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Neg)

/**
*@brief Returns x1/x2 element-wise for integer types. \n

*@par Inputs:
*@li x1: A Tensor. Must be one of the following types:
* float32, float64, int32, uint8, int16, int8,
* complex64, int64, qint8, quint8, qint32, uint16,
* complex128, float16, uint32, uint64, complex64, complex128.
*@li x2: A Tensor of the same data type as "x1". \n

*@par Outputs:
*y: A Tensor. Has the same type as "x1".

*@attention Constraints:
* Broadcasting is supported. \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TruncateDiv. \n

*/
REG_OP(TruncateDiv)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT32,
                           DT_DOUBLE, DT_UINT16, DT_INT16, DT_INT64,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT32,
                           DT_DOUBLE, DT_UINT16, DT_INT16, DT_INT64,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT32,
                           DT_DOUBLE, DT_UINT16, DT_INT16, DT_INT64,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(TruncateDiv)

/**
*@brief Computes x1/x2 element-wise, if x1 == 0, return 0.

*@par Inputs:
* Two inputs, including:
* @li x1: A Tensor. Must be one of the following types: float16, float32,
* double, complex64, complex128.
* @li x2: A Tensor. Has the same type as "x1". \n

*@par Outputs:
*y: A Tensor. Has the same type as "x1". \n

*@par Third-party framework compatibility
* Compatible with TensorFlow operator Xdivy.
*/
REG_OP(Xdivy)
    .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
                           DT_COMPLEX128}))
    .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
                           DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
                           DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Xdivy)

/**
*@brief Computes "x" multiplied by the logarithm of y element-wise,
* if "x" == 0, return "0". \n

*@par Inputs:
* Two inputs, including:
* @li x1: A Tensor. Must be one of the following types: float16, float32,
* double, complex64, complex128.
* @li x2: A Tensor. Has the same type as "x1". \n

*@par Outputs:
*y: A Tensor. Has the same type as "x1". \n

*@par Third-party framework compatibility
* Compatible with TensorFlow operator Xlogy.
*/
REG_OP(Xlogy)
    .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
                           DT_COMPLEX128}))
    .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
                           DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
                           DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Xlogy)

/**
*@brief Computes square of "x" element-wise. \n

*@par Inputs:
*One input: \n
*x: A Tensor. Must be one of the following types: float16, float32, float64, int32, int64, complex64, complex128

*@par Outputs:
*y: A Tensor. Has the same type as "x". \n

*@par Third-party framework compatibility
* Compatible with TensorFlow operator Square.
*/
REG_OP(Square)
    .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT16, DT_FLOAT,
                          DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT16, DT_FLOAT,
                           DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Square)


/**
*@brief Computes reciprocal of square root of "x" element-wise: y = 1/sqrt{x}. \n

*
*@par Inputs:
* x: An ND or 5HD tensor. Must be one of the following types: float, double, half,
 * complex64, complex128.
*
*@par Outputs:
* y: An ND or 5HD tensor. Has the same type as "x".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Rsqrt.
*
*/
REG_OP(Rsqrt)
    .INPUT(x, TensorType::UnaryDataType())
    .OUTPUT(y, TensorType::UnaryDataType())
    .OP_END_FACTORY_REG(Rsqrt)

/**
*@brief Computes the trignometric inverse sine of "x" element-wise. \n

*
*@par Inputs:
* x: A tensor. Must be one of the following types: float16, float32, float64, int32, int64, complex64, complex128.
*
*@par Outputs:
* y: A tensor. Has the same type as "x".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Asin.
*
*/
REG_OP(Asin)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
                          DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
                           DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(Asin)

/**
*@brief Computes gradients for Asin operation. \n

*
*@par Inputs:
*@li y: A tensor of type float16, float32, float64, int32, int64, complex64, complex128.
*@li dy: A tensor of the same type as "y".
*
*@attention Constraints:
* "dy" has the same type as "y".
*
*@par Outputs:
* z: A tensor. Has the same type as "y".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator AsinGrad.
*
*/
REG_OP(AsinGrad)
    .INPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
                          DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
                           DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(z, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
                           DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(AsinGrad)
Has the same type as "x". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Acos. +* +*/ +REG_OP(Acos) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, + DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, + DT_INT32, DT_INT64, DT_COMPLEX64, DT_COMPLEX128})) + .OP_END_FACTORY_REG(Acos) + +/** +*@brief Computes gradients for Acos operation. \n + +* +*@par Inputs: +*@li y: A tensor of type float16 or float32. +*@li dy: A tensor of the same type as "y". +* +*@attention Constraints: +* "dy" has the same shape as "y". +* +*@par Outputs: +* z: A tensor. Has the same type as "y". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator AcosGrad. +* +*/ +REG_OP(AcosGrad) + .INPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(z, TensorType({DT_FLOAT16, DT_FLOAT})) + .OP_END_FACTORY_REG(AcosGrad) + +/** +*@brief Computes inverse hyperbolic cosine of x element-wise. \n + +* +*@par Inputs: +* x: A tensor. Must be one of the following types: float16, float32, float64, complex64, complex128. +* +*@attention Constraints: +* x Given an input tensor, the function computes inverse hyperbolic cosine of every element.\n +* Input range is [1, inf]. +* +*@par Outputs: +* y: A tensor. Has the same type as "x". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Acosh. +* +*/ +REG_OP(Acosh) + .INPUT(x, TensorType::UnaryDataType()) + .OUTPUT(y, TensorType::UnaryDataType()) + .OP_END_FACTORY_REG(Acosh) + +/** +*@brief Computes gradients for Acosh operation. \n + +* +*@par Inputs: +*@li y: A tensor of type float16 or float32. +*@li dy: A tensor of the same type as "y". +* +*@attention Constraints: +* "dy" has the same type as "y". +* +*@par Outputs: +* z: A tensor. Has the same type as "y". 
+* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator AcoshGrad. +* +*/ +REG_OP(AcoshGrad) + .INPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(z, TensorType({DT_FLOAT16, DT_FLOAT})) + .OP_END_FACTORY_REG(AcoshGrad) + +/** +*@brief Returns the truth value of x1 OR x2 element-wise. \n + +* +*@par Inputs: +*@li x1: A tensor of type bool. +*@li x2: A tensor of the same type as "x1". +* +*@attention Constraints: +* LogicalOr supports broadcasting. +* +*@par Outputs: +* y: A tensor of the same type as "x1". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator LogicalOr. +* +*/ +REG_OP(LogicalOr) + .INPUT(x1, TensorType({DT_BOOL})) + .INPUT(x2, TensorType({DT_BOOL})) + .OUTPUT(y, TensorType({DT_BOOL})) + .OP_END_FACTORY_REG(LogicalOr) + +/** +*@brief Returns the truth value of x1 AND x2 element-wise. \n + +* +*@par Inputs: +*@li x1: A tensor of type bool. +*@li x2: A tensor of the same type as "x1". +* +*@attention Constraints: +* LogicalAnd supports broadcasting. +* +*@par Outputs: +* y: A tensor of the same type as "x1". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator LogicalAnd. +* +*/ +REG_OP(LogicalAnd) + .INPUT(x1, TensorType({DT_BOOL})) + .INPUT(x2, TensorType({DT_BOOL})) + .OUTPUT(y, TensorType({DT_BOOL})) + .OP_END_FACTORY_REG(LogicalAnd) + +/** +*@brief Computes the Bessel i0e function of "x" element-wise. +* Exponentially scaled modified Bessel function of order 0 +* defined as: bessel_i0e(x) = exp(-abs(x)) bessel_i0(x). +* This function is faster and numerically stabler than "bessel_i0(x)". +* +*@par Inputs: +* x: A tensor of type float16, float32, or float64. +* +*@par Outputs: +* y: A tensor. Has the same type as "x". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator BesselI0e. 
+* +*/ +REG_OP(BesselI0e) + .INPUT(x, TensorType::FloatingDataType()) + .OUTPUT(y, TensorType::FloatingDataType()) + .OP_END_FACTORY_REG(BesselI0e) + +/** +*@brief Computes the Bessel i1e function of "x" element-wise. +* Exponentially scaled modified Bessel function of order 0 +* defined as: bessel_i1e(x) = exp(-abs(x)) bessel_i1(x). +* This function is faster and numerically stabler than "bessel_i1(x)". +* +*@par Inputs: +* x: A tensor of type float16, float32, or float64. +* +*@par Outputs: +* y: A tensor. Has the same type as "x". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator BesselI1e. +* +*/ +REG_OP(BesselI1e) + .INPUT(x, TensorType::FloatingDataType()) + .OUTPUT(y, TensorType::FloatingDataType()) + .OP_END_FACTORY_REG(BesselI1e) + +/** +* @brief Computes logarithm of x element-wise. +* y = log_base(shift + scale * x), with "base" > 0. \n + +* @par Inputs: +* @li x: A Tensor of type complex64, complex128, float16, float32 or double. \n + +* @par Attributes: +* @li base: An optional float32, specifying the base "e". Defaults to "-1.0" + +* @li scale: An optional float32, specifying the scale of input "x". Defaults +* to "1.0" +* @li shift: An optional float32, specifying the shift. Defaults to "0.0" + +* @par Outputs: +* y: A Tensor has same type as "x". \n + +* @attention Constraints: +* @li "base" is supposed to be greater than 0. Retaining the default +* value "-1" sets "base" to "e". +* @li If the input value of operator Log is within the range (0, 0.01] or +* [0.95, 1.05], the output accuracy is subject to change. \n + +* @par Third-party framework compatibility +* @li Compatible with the TensorFlow operator Log. +* @li Compatible with the Caffe operator Log. +*/ +REG_OP(Log) + .INPUT(x, TensorType::UnaryDataType()) + .OUTPUT(y, TensorType::UnaryDataType()) + .ATTR(base, Float, -1.0) + .ATTR(scale, Float, 1.0) + .ATTR(shift, Float, 0.0) + .OP_END_FACTORY_REG(Log) + +/** +* @brief Returns x1 * x2 element-wise. 
+* y = x1 * x2 + +* @par Inputs: +* @li x1: A Tensor. Must be one of the following types: float16, float32, +* float64, uint8, int8, uint16, int16, int32, int64, complex64, complex128. +* @li x2: A Tensor. Must be one of the following types: float16, float32, +* float64, uint8, int8, uint16, int16, int32, int64, complex64, complex128. \n + +* @par Outputs: +* y: A Tensor. Must be one of the following types: float16, float32, float64, +* uint8, int8, uint16, int16, int32, int64, complex64, complex128. \n + +* @attention Constraints: +* @li "x1" and "x2" have incompatible shapes or types. \n + +* @par Third-party framework compatibility +* Compatible with the TensorFlow operator Multiply. +*/ +REG_OP(Mul) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_INT8, + DI_UINT16, DT_INT16, DT_INT32, DT_INT64, + DT_COMPLEX64, DT_COMPLEX128})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_INT8, + DI_UINT16, DT_INT16, DT_INT32, DT_INT64, + DT_COMPLEX64, DT_COMPLEX128})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_INT8, + DI_UINT16, DT_INT16, DT_INT32, DT_INT64, + DT_COMPLEX64, DT_COMPLEX128})) + .OP_END_FACTORY_REG(Mul) + +/** +* @brief Computes the gradient of the square root of "x" with regard to its +* input. grad = dy * 0.5/y, where y = sqrt(x), and "dy" is the corresponding +* input gradient. \n + +* @par Inputs: +* Two inputs, including: +* @li y: A Tensor of type float32 or float16. +* @li dy: A Tensor. Has the same type as "y". \n + +* @par Outputs: +* z: A Tensor. Has the same type as "y". \n + +* @attention Constraints: +* "dy" has the same shape and type as "y". +*/ +REG_OP(SqrtGrad) + .INPUT(y, TensorType(UnaryDataType)) + .INPUT(dy, TensorType(UnaryDataType)) + .OUTPUT(z, TensorType(UnaryDataType)) + .OP_END_FACTORY_REG(SqrtGrad) + +/** +*@brief Returns x + y element-wise. +*@par Inputs: +*Two inputs, including: +* @li x1: A Tensor. 
Must be one of the following types: int8, int16, int32, int64, uint8, float64, +* float16, float32, complex128, complex64, string. +* @li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x". +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Add. +*/ +REG_OP(Add) + .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .OP_END_FACTORY_REG(Add) + +/** +*@brief Confuse broadcast, add and mul. \n + +*@par Inputs: +*Five inputs, including: +* @li x1: A Tensor. Must be one of the following types:int32 float16, float32. +* @li x2: A Tensor of the same type as "x1". +* @li x3: A Tensor of the same type as "x1". \n + +*@par Outputs: +*@li y: A Tensor. Has the same type as "x1". \n + +*@par Third-party framework compatibility: +* Compatible with the TensorFlow operator LRN. + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ + +REG_OP(FusedMulAdd) + .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .INPUT(x3, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OP_END_FACTORY_REG(FusedMulAdd) + +/** +*@brief Returns x1 + x2 element-wise. \n + +* +*@par Inputs: +*@li x1: A tensor. Must be one of the following types: float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128. +*@li x2: A tensor of the same type as "x1". +* +*@attention Constraints: +* AddV2 supports broadcasting. +* +*@par Outputs: +* y: A tensor. 
Has the same type as "x1". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator AddV2. +* +*/ +REG_OP(AddV2) + .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128})) + .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128})) + .OP_END_FACTORY_REG(AddV2) + +/** +*@brief Updates "ref" by adding "value" to it. \n + +*@par Inputs: +*@li ref: A Tensor. Must be one of the following types: float16, float32, int8, int16, int32, int64, uint8, uint16, uint32, uint64. +*@li value: A Tensor of the same type as "ref". \n + +*@par Attributes: +*use_locking: An optional bool. Defaults to "False". + If "True", the addition will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +* This attribute is reserved. \n + +*@par Outputs: +*ref: A Tensor that holds the new value of ref after the value has been added. \n + +*@attention Constraints: +*An input tensor of type int64 must have a shape with size 1. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator AssignAdd. +*/ +REG_OP(AssignAdd) + .INPUT(ref, TensorType::BasicType()) + .INPUT(value,TensorType::BasicType()) + .OUTPUT(ref, TensorType::BasicType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(AssignAdd) + +/** +*@brief Updates "ref" by assigning "value" to it. \n + +*@par Inputs: +*@li ref: A Tensor. Must be one of the following types: float16, float32, int8, int16, int32, int64, uint8, uint16, uint32, uint64. +*@li value: A Tensor of the same type as "ref". \n + +*@par Attributes: +*@li validate_shape: An optional bool. Defaults to "true". 
+ If "true", the operation will validate that the shape of "value" matches the shape of the Tensor being assigned to. +* If "false", "ref" will take on the shape of "value". +* This attribute is reserved. +*@li use_locking: An optional bool. Defaults to True. + If True, the assignment will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +* This attribute is reserved. \n + +*@par Outputs: +*ref: A Tensor that holds the new value of ref after the value has been assigned. \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Assign. +*/ +REG_OP(Assign) + .INPUT(ref, TensorType::BasicType()) + .INPUT(value,TensorType::BasicType()) + .OUTPUT(ref, TensorType::BasicType()) + .ATTR(validate_shape, Bool, true) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(Assign) + +/** +*@brief Updates "var" by subtracting "value" from it.\n +* This operation outputs "var" after the update is done. \n +* This makes it easier to chain operations that need to use the reset value. \n + +* +*@par Inputs: +*@li var: A tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, uint16, complex128, uint32, uint64 +*@li value: A tensor of the same type as "var". +* +*@par Attributes: +* use_locking: An optional bool. Defaults to "False". If "True", the subtraction will be protected \n +* by a lock; otherwise the behavior is undefined, but may exhibit less contention. +* +*@par Outputs: +* y: A tensor. Has the same type as "var". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator AssignSub. +* +*/ +REG_OP(AssignSub) + .INPUT(var, TensorType::NumberType()) + .INPUT(value,TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(AssignSub) + +/** +*@brief: Computes the backpropagation of the square root operation. 
\n + +*@par Inputs: +* Two inputs, including: +*@li y: An NCHW, NC1HWC0, NHWC, ND Tensor. Must be one of the following types: \ + * float, int32, int8, double, complex64, complex128, half. +*@li dy: A Tensor of the same type and format as "y". \n + +*@par Outputs: +*z: A Tensor of the same type and format as "y". \n + +*@see Matmul() | Rsqrt () + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator RsqrtGrad. +*/ +REG_OP(RsqrtGrad) + .INPUT(y, TensorType({UnaryDataType,int32,int8})) + .INPUT(dy, TensorType({UnaryDataType,int32,int8})) + .OUTPUT(z, TensorType({UnaryDataType,int32,int8})) + .OP_END_FACTORY_REG(RsqrtGrad) + +/** +*@brief Computes hyperbolic sine of "x" element-wise. \n + +*@par Inputs: +*x: An NCHW, NC1HWC0, NHWC,or ND Tensor of type float, double, complex64, + * complex128, half. \n + +*@par Outputs: +*y: A NCHW, NC1HWC0, NHWC,or ND Tensor of type float, double, complex64, + * complex128, half. \n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator Sinh. \n + +*/ +REG_OP(Sinh) + .INPUT(x, TensorType::UnaryDataType()) + .OUTPUT(y, TensorType::UnaryDataType()) + .OP_END_FACTORY_REG(Sinh) + +/** +*@brief: Clips tensor values to a specified min and max. \n + +*@par Inputs: +* Three inputs, including: +*@li x: A Tensor of type float32, float64, int32, uint8, int16, int8, complex64, int64, +*qint8, quint8, qint32, uint16, complex128, float16, uint32, uint64. +*@li clip_value_min: A Tensor of the same type as "x". +*@li clip_value_max: A Tensor of the same type as "x". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x". \n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator ClipByValue. 
+*/ +REG_OP(ClipByValue) + .INPUT(x, TensorType::NumberType()) + .INPUT(clip_value_min, TensorType::NumberType()) + .INPUT(clip_value_max, TensorType::NumberType()) + .OUTPUT(y, TensorType::NumberType()) + .OP_END_FACTORY_REG(ClipByValue) + +/** +*@brief Computes cosine of "x" element-wise. \n + +*@par Inputs: +*x: A Tensor of type float16, float32, double, complex64, complex128. +* the format can be [NCHW,NC1HWC0,NHWC,ND]. \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x". \n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator Cosh. \n + +*/ +REG_OP(Cosh) + .INPUT(x, TensorType::UnaryDataType()) + .OUTPUT(y, TensorType::UnaryDataType()) + .OP_END_FACTORY_REG(Cosh) + +/** +*@brief: Returns 0 if the denominator is zero, else, like Div. \n + +*@par Inputs: +* Two inputs, including: +*@li x1: A Tensor. Must be one of the following types:float16, float32, int32, +* int8, uint8, double, the format can be [NCHW,NC1HWC0,NHWC,ND]. +*@li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x1". \n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator DivNoNan. +*/ +REG_OP(DivNoNan) + .INPUT(x1, TensorType({DT_FLOAT, DT_UINT8, DT_INT8, DT_INT32, DT_FLOAT16, + DT_DOUBLE})) + .INPUT(x2, TensorType({DT_FLOAT, DT_UINT8, DT_INT8, DT_INT32, DT_FLOAT16, + DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_UINT8, DT_INT8, DT_INT32, DT_FLOAT16, + DT_DOUBLE})) + .OP_END_FACTORY_REG(DivNoNan) + +/** +*@brief Reverses specific dimensions of a tensor. \n + +*@par Inputs: +* One input: \n +*x: A Tensor, Must be one of the following types: +* int32, uint8, int16, int8, int64, int64, uint16, uint32, uint64, +* and format can be [NCHW,NC1HWC0,NHWC,ND] + +*@par Outputs: +*y: A Tensor. Has the same type and format as "x" + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator Invert. 
+*/ +REG_OP(Invert) + .INPUT(x, TensorType::IntegerDataType()) + .OUTPUT(y, TensorType::IntegerDataType()) + .OP_END_FACTORY_REG(Invert) + +/** +*@brief Returns a tensor of the same shape and type with all elements set to one. +*@par Inputs: +*One input: \n +*x: A Tensor. Must be one of the following types: float16, float32, int8, uint8, + * int16, uint16, int32, int64, complex128, bool. \n + +*@par Outputs: +*y: A Tensor of the same type as "x". \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator OnesLike. +*/ +REG_OP(OnesLike) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8, + DT_UINT8, DT_INT16, DI_UINT16, DT_INT32, + DT_INT64, DT_COMPLEX128, DT_BOOL})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8, + DT_UINT8, DT_INT16, DI_UINT16, DT_INT32, + DT_INT64, DT_COMPLEX128, DT_BOOL})) + .OP_END_FACTORY_REG(OnesLike) + +/** +*@brief Computes the gradient for the inverse of "x" with regard its input. \n + +*@par Inputs: +*@li input_y: A Tensor. Must be one of the following types: float, double, + * complex64, complex128, half. +*@li input_dy: A Tensor. Must be one of the following types: float, double, + * complex64, complex128, half. \n + +*@par Outputs: +*output_data: A Tensor. Must be one of the following types: float, double, + * complex64, complex128, half. \n + +*@attention Constraints: +* "input_dy" has the same shape and type as "input_y". \n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator reciprocal_grad. +*/ +REG_OP(ReciprocalGrad) + .INPUT(y, TensorType::UnaryDataType()) + .INPUT(dy, TensorType::UnaryDataType()) + .OUTPUT(z, TensorType::UnaryDataType()) + .OP_END_FACTORY_REG(ReciprocalGrad) + +/** +*@brief Returns the truth value of (x1 > x2) element-wise. \n + +*@par Inputs: +*@li x1: A Tensor of type float16, float32, double, int64, int32, int16, int8, +* uint8, uint16, uint32, uint64. +*@li x2: A Tensor of the same data type as "x1". 
\n + +*@par Outputs: +*y: A Tensor of type bool. + +*@attention Constraints: +* Broadcasting is supported. \n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator Greater. \n + +*/ +REG_OP(Greater) + .INPUT(x1, TensorType::RealNumberType()) + .INPUT(x2, TensorType::RealNumberType()) + .OUTPUT(y, TensorType({DT_BOOL})) + .OP_END_FACTORY_REG(Greater) + +/** +*@brief Returns a tensor of the same type and shape as the input tensor with all elements set to zero. \n + +*@par Inputs: +*x: A Tensor. Must be one of the following types: +* float32, float64, int32, uint8, int16, int8, +* complex64, int64, qint8, quint8, qint32, qint16, quint16, uint16, +* complex128, float16, uint32, uint64, complex64, complex128. \n + +*@par Outputs: +*y: A Tensor of the same data type as "x". \n + +*@attention Constraints: +* The output has the same shape and type as the input. \n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator zeros_like. +*/ +REG_OP(ZerosLike) + .INPUT(x, TensorType::BasicType()) + .OUTPUT(y, TensorType::BasicType()) + .OP_END_FACTORY_REG(ZerosLike) + +/** +*@brief Returns the truth value of NOT "x" element-wise. \n + +*@par Inputs: +*x: A Tensor of type bool. \n + +*@par Outputs: +*y: A Tensor of type bool. \n + +*@attention Constraints: +* The input and output values are "1" or "0", corresponding to bool values "true" and "false". \n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator logical_not. +*/ +REG_OP(LogicalNot) + .INPUT(x, TensorType({DT_BOOL})) + .OUTPUT(y, TensorType({DT_BOOL})) + .OP_END_FACTORY_REG(LogicalNot) + +/** +*@brief Computes inverse hyperbolic sine of x element-wise. +* Given an input tensor, this function computes inverse hyperbolic sine for every element in the tensor. \n + +* +*@par Inputs: +* x: A tensor. Must be one of the following types: float16, float32, float64, complex64, complex128. +* +*@par Outputs: +* y: A tensor. 
Has the same type as "x". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Asinh. +* +*/ +REG_OP(Asinh) + .INPUT(x, TensorType::UnaryDataType()) + .OUTPUT(y, TensorType::UnaryDataType()) + .OP_END_FACTORY_REG(Asinh) + +/** +*@brief Computes gradients for Asinh operation. \n + +* +*@par Inputs: +*@li y: A tensor. Must be one of the following types: float16, float32. +*@li dy: A tensor of the same type as "y" +* +*@par Outputs: +* z: A tensor. Has the same type as "y". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator AsinhGrad. +* +*/ +REG_OP(AsinhGrad) + .INPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(z, TensorType({DT_FLOAT16, DT_FLOAT})) + .OP_END_FACTORY_REG(AsinhGrad) + +/** +*@brief Computes inverse hyperbolic tangent of x element-wise.\n +* Given an input tensor, this function computes inverse hyperbolic tangent for every element in the tensor. \n Input range is [-1,1] and output range is [-inf, inf]. If input is -1, \n output will be -inf and if the input is 1, output will be inf.\n Values outside the range will have nan as output. \n + +* +*@par Inputs: +* x: A tensor. Must be one of the following types: float16, float32, float64, complex64, complex128. +* +*@par Outputs: +* y: A tensor. Has the same type as "x". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Atanh. +* +*/ +REG_OP(Atanh) + .INPUT(x, TensorType::UnaryDataType()) + .OUTPUT(y, TensorType::UnaryDataType()) + .OP_END_FACTORY_REG(Atanh) + +/** +*@brief Computes the trignometric inverse tangent of x element-wise. +* The atan operation returns the inverse of tan, such that if y = tan(x) then, x = atan(y). \n + +* +*@par Inputs: +* x: A tensor. Must be one of the following types: float16, float32, float64, complex64, complex128. +* +*@par Outputs: +* y: A tensor. Has the same type as "x". 
The output of atan will lie within the invertible range of tan, i.e (-pi/2, pi/2). +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Atan. +* +*/ +REG_OP(Atan) + .INPUT(x, TensorType::UnaryDataType()) + .OUTPUT(y, TensorType::UnaryDataType()) + .OP_END_FACTORY_REG(Atan) + +/** +*@brief Computes gradients for Atan operation. \n + +* +*@par Inputs: +*@li y: A tensor of type float16 or float32. +*@li dy: A tensor of the same type as "y" +* +*@par Outputs: +* z: A tensor. Has the same type as "y". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator AtanGrad. +* +*/ +REG_OP(AtanGrad) + .INPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(dy, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(z, TensorType({DT_FLOAT16, DT_FLOAT})) + .OP_END_FACTORY_REG(AtanGrad) + +/** +*@brief Computes arctangent of x1/x2 element-wise, respecting signs of the arguments. \n + +* +*@par Inputs: +*@li x1: A tensor. Must be one of the following types: float16, float32, float64 +*@li x2: A tensor of the same type as "x1". +* +*@par Outputs: +* y: A tensor. Has the same type as "x1". +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Atan2. +* +*/ +REG_OP(Atan2) + .INPUT(x1, TensorType::FloatingDataType()) + .INPUT(x2, TensorType::FloatingDataType()) + .OUTPUT(y, TensorType::FloatingDataType()) + .OP_END_FACTORY_REG(Atan2) + +/** +*@brief Returns the truth value of abs(x1-x2) < tolerance element-wise. \n + +* +*@par Inputs: +*@li x1: A tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, uint16, complex128, float16, uint32, uint64 +*@li x2: A tensor of the same type as "x1". +* +*@par Attributes: +* tolerance: Defaults to "1e-05". +* +*@par Outputs: +* y: A tensor of type bool. +* +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator ApproximateEqual. 
+* +*/ +REG_OP(ApproximateEqual) + .INPUT(x1, TensorType::NumberType()) + .INPUT(x2, TensorType::NumberType()) + .OUTPUT(y, TensorType({DT_BOOL})) + .ATTR(tolerance, Float, 1e-5) + .OP_END_FACTORY_REG(ApproximateEqual) + +/** +*@brief Returns the element-wise sum of a list of tensors.\n +* AccumulateNV2 performs the same operation as AddN, but does not wait for all of its inputs +to be ready before beginning to sum.\n This can save memory if inputs are ready at different times, +since minimum temporary storage is proportional to the output size rather than the inputs size. + Returns a Tensor of same shape and type as the elements of inputs. \n + +* +*@par Inputs: +*Dynamic inputs, including: +* x: A tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, complex64, int64, +qint8, quint8, qint32, uint16, complex128, float16, uint32, uint64. It's a dynamic input. \n +* +*@par Outputs: +* y: A tensor. Has the same type as "x". +* +*@par Attributes: +* N: the size of x. +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator AccumulateNV2. +* +*/ +REG_OP(AccumulateNV2) + .DYNAMIC_INPUT(x, TensorType::NumberType()) + .OUTPUT(y, TensorType::NumberType()) + .REQUIRED_ATTR(N, Int) + .OP_END_FACTORY_REG(AccumulateNV2) + +/** +*@brief Fake-quantizes the input Tensor, type float to output a Tensor of same type. +* [min, max] define the clamping range for the "inputs" data.\n +* the values of "x" are quantized into the quantization range ([0, 2^num_bits - 1] \n +* when "narrow_range" is "false" or [1, 2^num_bits - 1] when it is "true") and \n +* then de-quantized and output as float32 in [min; max] interval.\n +* num_bits is the bit width of the quantization, between 2 and 16, inclusive. \n +* Quantization is called fake since the output is still in floating point. \n + +*@par Inputs: +*One input: +*x: A Tensor of type float32. \n + +*@par Attributes: +*@li min: An optional attribute. Defaults to "-6.0". 
+*@li max: An optional attribute. Defaults to "6.0". +*@li num_bits: An optional attribute. Defaults to "8". +*@li narrow_range: An optional bool. Defaults to "false". \n + +*@par Outputs: +*y: A Tensor. Has the same shape and type of "x". \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator FakeQuantWithMinMaxArgs. +*/ +REG_OP(FakeQuantWithMinMaxArgs) + .INPUT(x, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT})) + .ATTR(min, Float, -6.0) + .ATTR(max, Float, 6.0) + .ATTR(num_bits, Int, 8) + .ATTR(narrow_range, Bool, false) + .OP_END_FACTORY_REG(FakeQuantWithMinMaxArgs) + +/** +*@brief Computes gradients for a FakeQuantWithMinMaxArgs operation. \n + +*@par Inputs: +*Two inputs, including: \n +*@li gradients: A Tensor of type float32. Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. +*@li x: A Tensor of type float32. Has the same type and format as "gradients".\n +* This is the input Tensor of the FakeQuantWithMinMaxArgs operator.\n + +*@par Attributes: +*@li min: An optional attribute. Defaults to "-6.0". +*@li max: An optional attribute. Defaults to "6.0". +*@li num_bits: An optional attribute. Defaults to "8". +*@li narrow_range: An optional bool. Defaults to "False". \n + +*@par Outputs: +*y: A Tensor of type float32. \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator FakeQuantWithMinMaxArgsGradient. +*/ +REG_OP(FakeQuantWithMinMaxArgsGradient) + .INPUT(gradients, TensorType({DT_FLOAT})) + .INPUT(x, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT})) + .ATTR(min, Float, -6.0) + .ATTR(max, Float, 6.0) + .ATTR(num_bits, Int, 8) + .ATTR(narrow_range, Bool, false) + .OP_END_FACTORY_REG(FakeQuantWithMinMaxArgsGradient) + +/** +*@brief Fake-quantize the 'inputs' tensor of type float via global float scalars. \n + +*@par Inputs: +*Three inputs, including: +*@li x: A Tensor of type float32. +*@li min: A Tensor of type float32. 
Has the same type and format as "x". +*@li max: A Tensor of type float32. Has the same type and format as "x".\n +*[min; max] define the clamping range for the inputs data + +*@par Attributes: +*@li num_bits: An optional attribute. Defaults to "8". +*@li narrow_range: An optional bool. Defaults to "False". \n + +*@par Outputs: +*y: A Tensor of type float32. \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator FakeQuantWithMinMaxVars. +*/ +REG_OP(FakeQuantWithMinMaxVars) + .INPUT(x, TensorType({DT_FLOAT})) + .INPUT(min, TensorType({DT_FLOAT})) + .INPUT(max, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT})) + .ATTR(num_bits, Int, 8) + .ATTR(narrow_range, Bool, false) + .OP_END_FACTORY_REG(FakeQuantWithMinMaxVars) + +/** +*@brief Computes gradients for a FakeQuantWithMinMaxVars operation. \n + +*@par Inputs: +*Four inputs, including: +*@li gradients: A Tensor of type float32. +*@li x: A Tensor of type float32. +*@li min: A Tensor of type float32. +*@li max: A Tensor of type float32. \n + +*@par Attributes: +*@li num_bits: An integer specifying the quantization bit width. Defaults to "8". +*@li narrow_range: A Boolean specifying whether to use a narrow range for quantization. Defaults to "False". \n + +*@par Outputs: +*@li backprops_wrt_x: A Tensor. Has the same type as input "x". +*@li backprops_wrt_min: A Tensor. Has the same type as input "min". +*@li backprops_wrt_max: A Tensor. Has the same type as input "max". \n + +*@attention Constraints: +*@li "gradients" has the same shape as "x". +*@li "min" and "max" are scalars. +*@li "num_bits" is between 2 and 16 + +*@see Region() + +*@par Third-party framework compatibility +* Compatible with the operator FakeQuantWithMinMaxVarsGradient. 
+*/ +REG_OP(FakeQuantWithMinMaxVarsGradient) + .INPUT(gradients, TensorType({DT_FLOAT})) + .INPUT(x, TensorType({DT_FLOAT})) + .INPUT(min, TensorType({DT_FLOAT})) + .INPUT(max, TensorType({DT_FLOAT})) + .OUTPUT(backprops_wrt_x, TensorType({DT_FLOAT})) + .OUTPUT(backprops_wrt_min, TensorType({DT_FLOAT})) + .OUTPUT(backprops_wrt_max, TensorType({DT_FLOAT})) + .ATTR(num_bits, Int, 8) + .ATTR(narrow_range, Bool, false) + .OP_END_FACTORY_REG(FakeQuantWithMinMaxVarsGradient) + +/** +*@brief Fake-quantizes the "inputs" tensor of type float +via per-channel floats min and max of shape [d] to "outputs" \n +tensor of same shape as inputs + +*@par Inputs: +*Three inputs, including: +*@li x: A Tensor of type float32. +*@li min: A Tensor of type float32. +*@li max: A Tensor of type float32. \n + +*@par Attributes: +*@li num_bits: An integer specifying the quantization bit width. Defaults to "8". +*@li narrow_range: A Boolean specifying whether to use a narrow range for quantization. Defaults to "False". \n + +*@par Outputs: +*y: A Tensor. Has the same type as input "x". + + +*@attention Constraints: +*@li "min" and "max" have one-dimensional shapes. +*@li "min" has the same last dimension size as "x". "max" has the same last dimension size as "x". +*@li "num_bits" is between 2 and 16 + +*@see Region() + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator FakeQuantWithMinMaxVarsPerChannel. +*/ +REG_OP(FakeQuantWithMinMaxVarsPerChannel) + .INPUT(x, TensorType({DT_FLOAT})) + .INPUT(min, TensorType({DT_FLOAT})) + .INPUT(max, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT})) + .ATTR(num_bits, Int, 8) + .ATTR(narrow_range, Bool, false) + .OP_END_FACTORY_REG(FakeQuantWithMinMaxVarsPerChannel) + +/** +*@brief Computes gradients for a FakeQuantWithMinMaxVarsPerChannel operation. \n + +*@par Inputs: +*Four inputs, including: +*@li gradients: A Tensor of type float32. +*@li x: A Tensor of type float32. +*@li min: A Tensor of type float32. 
+*@li max: A Tensor of type float32. \n + +*@par Attributes: +*@li num_bits: An integer specifying the quantization bit width. Defaults to "8". +*@li narrow_range: A Boolean specifying whether to use a narrow range for quantization. Defaults to "False". \n + +*@par Outputs: +*@li backprops_wrt_x: A Tensor. Has the same type as input "x". +*@li backprops_wrt_min: A Tensor. Has the same type as input "min". +*@li backprops_wrt_max: A Tensor. Has the same type as input "max". \n + +*@attention Constraints: +*@li "gradients" has the same shape as "x". +*@li "min" and "max" have one-dimensional shapes. +*@li "min" has the same last dimension size as "x". "max" has the same last dimension size as "x". "gradients" has the same last dimension size as "x". +*@li "num_bits" is between 2 and 16 + +*@see Region() + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator FakeQuantWithMinMaxVarsPerChannelGradient. +*/ +REG_OP(FakeQuantWithMinMaxVarsPerChannelGradient) + .INPUT(gradients, TensorType({DT_FLOAT})) + .INPUT(x, TensorType({DT_FLOAT})) + .INPUT(min, TensorType({DT_FLOAT})) + .INPUT(max, TensorType({DT_FLOAT})) + .OUTPUT(backprops_wrt_x, TensorType({DT_FLOAT})) + .OUTPUT(backprops_wrt_min, TensorType({DT_FLOAT})) + .OUTPUT(backprops_wrt_max, TensorType({DT_FLOAT})) + .ATTR(num_bits, Int, 8) + .ATTR(narrow_range, Bool, false) + .OP_END_FACTORY_REG(FakeQuantWithMinMaxVarsPerChannelGradient) + +/** +*@brief Element-wise computes the bitwise AND of "x1" and "x2". \n + +*@par Inputs: +*Two inputs, including: +* @li x1: A Tensor. Must be one of the following types: int8, int16, +* int32, int64, uint8, uint16, uint32, uint64. Broadcasting is supported. +* @li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x1". \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator BitwiseAnd. 
+*/ +REG_OP(BitwiseAnd) + .INPUT(x1, TensorType::IntegerDataType()) + .INPUT(x2, TensorType::IntegerDataType()) + .OUTPUT(y, TensorType::IntegerDataType()) + .OP_END_FACTORY_REG(BitwiseAnd) + +/** +*@brief Element-wise computes the bitwise OR of "x1" and "x2". \n + +*@par Inputs: +*Two inputs, including: +* @li x1: A Tensor. Must be one of the following types: int8, int16, +* int32, int64, uint8, uint16, uint32, uint64. Broadcasting is supported. +* @li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x1". \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator BitwiseOr. +*/ +REG_OP(BitwiseOr) + .INPUT(x1, TensorType::IntegerDataType()) + .INPUT(x2, TensorType::IntegerDataType()) + .OUTPUT(y, TensorType::IntegerDataType()) + .OP_END_FACTORY_REG(BitwiseOr) + +/** +*@brief Elementwise computes the bitwise XOR of "x1" and "x2". \n + +*@par Inputs: +*Two inputs, including: +*@li x1: A Tensor. Must be one of the following types: int8, int16, int32, int64, uint8, uint16, uint32, uint64. +* The format is NC1HWC0 or ND. Broadcasting is supported. +*@li x2: A Tensor. Has the same type and format as "x1". \n + +*@par Outputs: +*y: Output result. Has the same type as "x1". \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator BitwiseXor. +*/ +REG_OP(BitwiseXor) + .INPUT(x1, TensorType::IntegerDataType()) + .INPUT(x2, TensorType::IntegerDataType()) + .OUTPUT(y, TensorType::IntegerDataType()) + .OP_END_FACTORY_REG(BitwiseXor) + +/** +*@brief Returns element-wise smallest integer not less than "x". \n + +*@par Inputs: +* x: A Tensor of type float16 or float32 or float64. \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x". +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Ceil. 
+*/ +REG_OP(Ceil) + .INPUT(x, TensorType::FloatingDataType()) + .OUTPUT(y, TensorType::FloatingDataType()) + .OP_END_FACTORY_REG(Ceil) + +/** +*@brief Returns element-wise largest integer not greater than "x". \n + +*@par Inputs: +*x: A Tensor of type float16, float32 or double. \n + +*@par Outputs: +*y: A Tensor of the same type as "x". \n + +*@par Third-party framework compatibility: +* Compatible with TensorFlow operator Floor. +*/ +REG_OP(Floor) + .INPUT(x, TensorType::FloatingDataType()) + .OUTPUT(y, TensorType::FloatingDataType()) + .OP_END_FACTORY_REG(Floor) + +/** +*@brief Divides "x1/x2" element-wise, rounding toward the +* most negative integer. \n + +*@par Inputs: +*Two inputs, including: +* @li x1: A Tensor. Must be one of the following types: float16, float32, int32, int64, int8, +* uint8, int16, uint16, double, complex64, complex128. +* @li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x1". \n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator FloorDiv. +*/ +REG_OP(FloorDiv) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT8, + DT_INT64, DT_INT16, DT_UINT16, DT_DOUBLE, + DT_COMPLEX64, DT_COMPLEX128})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT8, + DT_INT64, DT_INT16,DT_UINT16, DT_DOUBLE, + DT_COMPLEX64, DT_COMPLEX128})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT8, + DT_INT64, DT_INT16,DT_UINT16, DT_DOUBLE, + DT_COMPLEX64, DT_COMPLEX128})) + .OP_END_FACTORY_REG(FloorDiv) + +/** +*@brief Returns element-wise remainder of division. Consistent with: floor(x1/x2) * x2 + mod(x1, x2) = x1. \n + +*@par Inputs: +* Two inputs, including: +*@li x1: A Tensor. Must be one of the following types: +* int32, int64, float, float16, double +*@li x2: A Tensor. Must have the same type as "x1". +* +*@par Outputs: +*y: Result remainder. 
+ +*@attention Constraints: +*@li x2: The input data does not support 0 +*@li When NUM exceeds 2048 , the accuracy of operator cannot guarantee the +*requirement of double thousandths in the mini form +*@li Due to different architectures, the calculation results of this operator +*on NPU and CPU may be inconsistent +*@li If shape is expressed as (D1,D2... ,Dn), then D1*D2... *DN<=1000000,n<=8 + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator FloorMod. +*/ +REG_OP(FloorMod) + .INPUT(x1, TensorType({DT_INT32, DT_INT64, DT_FLOAT, DT_FLOAT16, + DT_DOUBLE})) + .INPUT(x2, TensorType({DT_INT32, DT_INT64, DT_FLOAT, DT_FLOAT16, + DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64, DT_FLOAT, DT_FLOAT16, + DT_DOUBLE})) + .OP_END_FACTORY_REG(FloorMod) + +/** +*@brief Computes the power of "x1" to "x2". \n + +*@par Inputs: +*Two inputs, including: +* @li x1: A Tensor. Must be one of the following types: +* float16, float32, int32, int64, int8, uint8, double, complex64, complex128. +* @li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x1". \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator Pow. +*/ +REG_OP(Pow) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64, DT_INT8, + DT_UINT8, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64, DT_INT8, + DT_UINT8, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64, DT_INT8, + DT_UINT8, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128})) + .OP_END_FACTORY_REG(Pow) + +/** +*@brief Return element-wise integer closest to x. \n + +*@par Inputs: +*One input, include: +*x: A mutable Tensor. Must be one of the following types: +* float16, float32, double. \n + +*@par Outputs: +*y: A mutable Tensor. Has the same type as "x". 
\n + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator Rint. +*/ +REG_OP(Rint) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OP_END_FACTORY_REG(Rint) + +/** +*@brief Rounds the values of a tensor to the nearest integer, element-wise. + * Rounds half to even. \n + +*@par Inputs: +*Inputs including: +*x: A required ND Tensor of type float16, float, int64, double, complex64, + * complex128 or int32. +*@par Outputs: +*y: A required ND Tensor. Has the same data type and shape as "x". +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator Round. +*/ +REG_OP(Round) + .INPUT(x, TensorType(DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64, + DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128)) + .OUTPUT(y, TensorType(DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT64, + DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128)) + .OP_END_FACTORY_REG(Round) + +/** +*@brief: Computes sine of "x" element-wise. \n + +*@par Inputs: +*One input: +*x: An ND Tensor. Must be one of the following types: float16, float32, double, + * complex64, complex128, int32, int64 + +*@par Outputs: +*y: An ND Tensor. Has the same type as "x". \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator Sin. +*/ +REG_OP(Sin) + .INPUT(x, TensorType::UnaryDataType()) + .OUTPUT(y, TensorType::UnaryDataType()) + .OP_END_FACTORY_REG(Sin) + +/** +*@brief: Computes tan of "x" element-wise. \n + +*@par Inputs: +*One input: +*x: A Tensor. Must be one of the following types: float16, float32, double, complex64, complex128, int32, int64 + +*@par Outputs: +*y: A Tensor. Has the same type as "x". \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator Tan. 
+*/ +REG_OP(Tan) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(Tan) + +/** +*@brief Returns element-wise remainder of division. \n + +*@par Inputs: +*Two inputs, including: +* @li x1: A Tensor. Must be one of the following types: float16, float32, +* double, int32, int64. +* @li x2: A Tensor of the same type as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x1". \n + +*@attention Constraints: +*@li x2: The input data does not support 0 +*@li When NUM exceeds 2048 , the accuracy of operator cannot guarantee the +*requirement of double thousandths in the mini form +*@li Due to different architectures, the calculation results of this operator +*on NPU and CPU may be inconsistent +*@li If shape is expressed as (D1,D2... ,Dn), then D1*D2... *DN<=1000000,n<=8 + +*@par Third-party framework compatibility +*@li Compatible with the TensorFlow operator TruncateMod. +*/ +REG_OP(TruncateMod) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT64, + DT_INT32})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT64, + DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT64, + DT_INT32})) + .OP_END_FACTORY_REG(TruncateMod) + +/** +*@brief Adds 'bias' to 'x'. \n + +*@par Inputs: +*Two inputs, including: +* @li x: A Tensor of type NumberType. Must be one of the following types: float32, float64, int32, uint8, int16, +*int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, float16, uint32, uint64. +* @li bias: A 1D Tensor with size the C dimension of value. \n + +*@par Attributes: +*data_format: An optional string. Defaults to "NHWC". \n + +*@par Outputs: +*y: A Tensor with same type as "x". \n + +*@par Third-party framework compatibility +*Compatible with the TensorFlow operator BiasAdd. 
+*/ +REG_OP(BiasAdd) + .INPUT(x, TensorType::NumberType()) + .INPUT(bias, TensorType::NumberType()) + .OUTPUT(y, TensorType::NumberType()) + .ATTR(data_format, String, "NHWC") + .OP_END_FACTORY_REG(BiasAdd) + +/** +*@brief Returns the index with the smallest value across dimensions of a tensor. \n + +*@par Inputs: +*Two inputs, including: +*@li x: A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, float16, uint32, uint64. +*format is ND. +*@li dimension: A Tensor. Must be one of the following types: int32, int64. Must be in the range [-rank(input x), rank(input x)]. Describes which dimension of the input Tensor to reduce across. +* The format is ND. +*@par Attributes: +*dtype: The output type, either "int32" or "int64". Defaults to "int64". \n + +*@par Outputs: +*y: A Tensor of type "dtype". \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator ArgMin. +*/ +REG_OP(ArgMin) + .INPUT(x, TensorType::NumberType()) + .INPUT(dimension, TensorType::IndexNumberType()) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .ATTR(dtype, Type, DT_INT64) + .OP_END_FACTORY_REG(ArgMin) + +/** +*@brief Returns the index with the smallest value across dimensions of a tensor. \n + +*@par Inputs: +*One input: + +*x: A Tensor of type float16 or float32 in ND format. \n + +*@par Attributes: +*@li dimension: The dimension of the input Tensor to reduce across. +*@li dtype: An optional attribute, specifying the output data type. Must be "int32". Defaults to "int64". \n + +*@par Outputs: +*y: A Tensor of type dtype. \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator ArgMin. +* +* @par Restrictions: +* Warning: THIS FUNCTION IS DEPRECATED. Please use ArgMin instead. 
+*/ +REG_OP(ArgMinD) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_INT32})) + .REQUIRED_ATTR(dimension, Int) + .ATTR(dtype, Type, DT_INT64) + .OP_END_FACTORY_REG(ArgMinD) + +/** +*@brief Returns the index with the largest value across axes of a tensor. \n + +*@par Inputs: +* Two inputs, including: +*@li x: A multi-dimensional Tensor of type float16, float32, or int16. +*@li dimension: A Scalar of type int32, specifying the index with the largest value. \n + +*@par Attributes: +*dtype: The output type, either "int32" or "int64". Defaults to "int64". \n + +*@par Outputs: +*y: A multi-dimensional Tensor of type int32 or int64, specifying the index with the largest value. The dimension is one less than that of "x". \n + +*@attention Constraints: +*@li x: If there are multiple maximum values, the index of the first maximum value is used. +*@li The value range of "dimension" is [-dims, dims - 1]. "dims" is the dimension length of "x". \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator ArgMax. +*/ +REG_OP(ArgMaxV2) + .INPUT(x, TensorType::NumberType()) + .INPUT(dimension, TensorType::IndexNumberType()) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64})) + .ATTR(dtype, Type, DT_INT64) + .OP_END_FACTORY_REG(ArgMaxV2) + +/** +*@brief Returns the index with the largest value across axes of a tensor. \n + +*@par Inputs: +* One input, including: +*x: A multi-dimensional Tensor of type float16, float32. \n + +*@par Attributes: +*@li dimension: An integer of type int32, specifying the axis information of the index with the maximum value. +*@li dtype: The output type, either "int32" or "int64". Defaults to "int64". \n + +*@par Outputs: +*y: A multi-dimensional Tensor of type int32, specifying the index with the largest value. The dimension is one less than that of "x". \n + +*@attention Constraints: +*@li x: If there are multiple maximum values, the index of the first maximum value is used. 
+*@li The value range of "dimension" is [-dims, dims - 1]. "dims" is the dimension length of "x". \n + +*@par Third-party framework compatibility +* Compatible with TensorFlow operator ArgMax. +* +* @par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(ArgMaxD) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_INT32})) + .REQUIRED_ATTR(dimension, Int) + .ATTR(dtype, Type, DT_INT64) + .OP_END_FACTORY_REG(ArgMaxD) + +/** +*@brief Returns the maximum value of all elements in the input in the given +* dimension. \n + +*@par Inputs: +*One input: \n +*x: A multi-dimensional Tensor of type float16 or float32. \n + +*@par Attributes: +*@li dimension: An integer of type int32, specifying the axis information of +* the index with the maximum value. +*@li keep_dims: A bool, specifying whether to keep dimensions for the output +* Tensor. Defaults to "false". \n + +*@par Outputs: +*@li indice: A multi-dimensional Tensor of type int32, specifying the index. +* (If "keep_dims" is set to "false", the output dimensions are reduced by +* "dimension" compared with that of "x". Otherwise, the output has one fewer +* dimension than "x".) +*@li values: A Tensor, specifying the maximum value. Has the same dimensions +* as "indice" and the same type as "x". \n + +*@attention Constraints: +*@li If there are multiple maximum values, the index of the first maximum +* value is used. +*@li The value range of "dimension" is [-dims, dims - 1]. "dims" is the +* dimension length of "x". \n + +*@par Third-party framework compatibility +* Compatible with the two output scenarios of PyTorch operator Max (the output +* sequence is opposite to that of PyTorch). 
+*/
+REG_OP(ArgMaxWithValue)
+    .INPUT(x, TensorType({DT_FLOAT,DT_FLOAT16}))
+    .OUTPUT(indice,TensorType({DT_INT32}))
+    .OUTPUT(values, TensorType({DT_FLOAT,DT_FLOAT16}))
+    .REQUIRED_ATTR(dimension, Int)
+    .ATTR(keep_dims, Bool, false)
+    .OP_END_FACTORY_REG(ArgMaxWithValue)
+
+/**
+*@brief Returns the minimum value of all elements in the input in the given
+* dimension, together with the index of that minimum. \n
+
+*@par Inputs:
+*One input: \n
+*x: A multi-dimensional Tensor of type float16 or float32. \n
+
+*@par Attributes:
+*@li dimension: An integer of type int32, specifying the axis information of
+* the index with the minimum value.
+*@li keep_dims: A bool, specifying whether to keep dimensions for the output
+* Tensor. Defaults to "false". \n
+
+*@par Outputs:
+*@li indice: A multi-dimensional Tensor of type int32, specifying the index.
+* (If "keep_dims" is set to "false", the output has one fewer dimension than
+* "x", with the "dimension" axis removed. Otherwise, the output has the same
+* number of dimensions as "x".)
+*@li values: A Tensor, specifying the minimum value. Has the same dimensions
+* as "indice" and the same type as "x". \n
+
+*@attention Constraints:
+*@li If there are multiple minimum values, the index of the first minimum
+* value is used.
+*@li The value range of "dimension" is [-dims, dims - 1]. "dims" is the
+* dimension length of "x".
+*@li Performing the ArgMinWithValue operation on the last axis of float32 data
+* is not supported on a mini platform. \n
+
+*@par Third-party framework compatibility
+* Compatible with the two output scenarios of PyTorch operator Min (the output
+* sequence is opposite to that of PyTorch).
+*/
+REG_OP(ArgMinWithValue)
+    .INPUT(x, TensorType({DT_FLOAT,DT_FLOAT16}))
+    .OUTPUT(indice,TensorType({DT_INT32}))
+    .OUTPUT(values, TensorType({DT_FLOAT,DT_FLOAT16}))
+    .REQUIRED_ATTR(dimension, Int)
+    .ATTR(keep_dims, Bool, false)
+    .OP_END_FACTORY_REG(ArgMinWithValue)
+
+/**
+*@brief Compute elementwise modes, such as 0: PRODUCT, 1: SUM, 2: MAX
+
+*@par Inputs:
+*One input: \n
+*x: the list of input data, the type of element in Tensor should be same.
+* the max size of x is 32.
+* should be one of the following types: float16, float32. It's a dynamic input. \n
+
+*@par Outputs:
+*y: A Tensor. Has the same type and format as "x". \n
+
+*@par Attributes:
+*@li N: A required attribute. the number of input x, max size is 32. Type is int.
+*@li mode: An optional attribute. Type is int. Defaults to "1".
+* "0": product, "1": sum, "2": max.
+*@li coeff: An optional attribute. Type is list of floats. Defaults to empty.
+* Must meet all of the following rules:
+* size of "coeff" must be equal to len("x") or is null.
+* the absolute value of "coeff" must be less than or equal to 1.
+*/
+REG_OP(Eltwise)
+    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .REQUIRED_ATTR(N, Int)
+    .ATTR(mode, Int, 1)
+    .ATTR(coeff, ListFloat, {})
+    .OP_END_FACTORY_REG(Eltwise)
+
+/**
+*@brief Computes element-wise population count. \n
+
+*@par Inputs:
+*x: A Tensor of type TensorType::IntegerDataType(). \n
+
+*@par Outputs:
+*y: A Tensor of type uint8. \n
+
+*@par Third-party framework compatibility
+* Compatible with the TensorFlow operator PopulationCount.
+*/
+REG_OP(PopulationCount)
+    .INPUT(x, TensorType::IntegerDataType())
+    .OUTPUT(y, TensorType({DT_UINT8}))
+    .OP_END_FACTORY_REG(PopulationCount)
+
+/**
+*@brief A fusion operator for bert lamb. \n
+
+*@par Inputs:
+*Thirteen inputs, including:
+* @li input_mul3: A Tensor. Must be one of the following types: float16, float32.
+* @li input_mul2: A Tensor. Must be one of the following types: float16, float32.
+* @li input_realdiv1: A Tensor. Must be one of the following types: float16, float32.
+* @li input_mul1: A Tensor. Must be one of the following types: float16, float32.
+* @li input_mul0: A Tensor. Must be one of the following types: float16, float32.
+* @li input_realdiv0: A Tensor. Must be one of the following types: float16, float32.
+* @li input_mul4: A Tensor. Must be one of the following types: float16, float32.
+* @li mul0_x: A Tensor.
Must be one of the following types: float16, float32. +* @li mul1_sub: A Tensor. Must be one of the following types: float16, float32. +* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul3_sub1: A Tensor. Must be one of the following types: float16, float32. +* @li mul4_x: A Tensor. Must be one of the following types: float16, float32. +* @li add2_y: A Tensor. Must be one of the following types: float16, float32. \n + +*@par Outputs: +*Four outputs, including: +* @li y1: A Tensor. Must be one of the following types: float16, float32. +* @li y2: A Tensor. Must be one of the following types: float16, float32. +* @li y3: A Tensor. Must be one of the following types: float16, float32. +* @li y4: A Tensor. Must be one of the following types: float16, float32. \n + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(LambNextMVWithDecay) + .INPUT(input_mul3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_realdiv1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_realdiv0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul4, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul1_sub, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul3_sub1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul4_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y3, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y4, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(LambNextMVWithDecay) + +/** +*@brief Confuse real_div, rsqrt, sqrt, maximum, minimum, sub and add. 
\n
+
+*@par Inputs:
+*Thirteen inputs, including:
+* @li input_mul3: A Tensor. Must be one of the following types: float16, float32.
+* @li input_mul2: A Tensor of the same type as "input_mul3".
+* @li input_realdiv1: A Tensor of the same type as "input_mul3".
+* @li input_mul1: A Tensor of the same type as "input_mul3".
+* @li input_mul0: A Tensor of the same type as "input_mul3".
+* @li input_realdiv0: A Tensor of the same type as "input_mul3".
+* @li input_mul4: A Tensor of the same type as "input_mul3".
+* @li mul0_x: A Tensor of the same type as "input_mul3".
+* @li mul1_sub: A Tensor of the same type as "input_mul3".
+* @li mul2_x: A Tensor of the same type as "input_mul3".
+* @li mul3_sub1: A Tensor of the same type as "input_mul3".
+* @li mul4_x: A Tensor of the same type as "input_mul3".
+* @li add2_y: A Tensor of the same type as "input_mul3". \n
+
+*@par Outputs:
+*Four outputs, including:
+*@li y1: A Tensor. Has the same type as "input_mul3".
+*@li y2: A Tensor. Has the same type as "input_mul3".
+*@li y3: A Tensor. Has the same type as "input_mul3".
+*@li y4: A Tensor. Has the same type as "input_mul3".
+
+*@par Restrictions:
+*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+*/ +REG_OP(LambNextMV) + .INPUT(input_mul3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_realdiv1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_realdiv0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul4, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul1_sub, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul3_sub1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul4_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y3, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y4, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(LambNextMV) + +/** +*@brief A fusion operator for bert lamb. \n + +*@par Inputs: +*Six inputs, including: +* @li input_square: A Tensor. Must be one of the following types: float16, float32. +* @li input_mul2: A Tensor. Must be one of the following types: float16, float32. +* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul3_x: A Tensor. Must be one of the following types: float16, float32. +* @li truediv1_recip: A Tensor. Must be one of the following types: float16, float32. +* @li add2_y: A Tensor. Must be one of the following types: float16, float32. \n + +*@par Outputs: +*Two outputs, including: +* @li y1: A Tensor of the same type as "input_square". +* @li y2: A Tensor of the same type as "input_square". \n + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. 
+*/ +REG_OP(LambNextRight) + .INPUT(input_square, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(truediv1_recip, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(LambNextRight) + +/** +*@brief A fusion operator for bert lamb. \n + +*@par Inputs: +*Six inputs, including: +* @li input_greater1: A Tensor. Must be one of the following types: float16, float32. +* @li input_greater_realdiv: A Tensor. Must be one of the following types: float16, float32. +* @li input_realdiv: A Tensor. Must be one of the following types: float16, float32. +* @li input_mul0: A Tensor. Must be one of the following types: float16, float32. +* @li input_mul1: A Tensor. Must be one of the following types: float16, float32. +* @li input_sub: A Tensor. Must be one of the following types: float16, float32. +* @li greater_y: A Tensor. Must be one of the following types: float16, float32. +* @li select_e: A Tensor. Must be one of the following types: float16, float32. +* @li minimum_y: A Tensor. Must be one of the following types: float16, float32. \n + +*@par Outputs: +*y: A Tensor of the same type as "input_greater1". \n + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. 
+*/ +REG_OP(LambUpdateWithLr) + .INPUT(input_greater1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_greater_realdiv, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_realdiv, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_mul1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input_sub, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(greater_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(select_e, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(minimum_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(LambUpdateWithLr) + +/** +*@brief A fusion operator for bert lamb. \n + +*@par Inputs: +*Seven inputs, including: +* @li x1: A Tensor. Must be one of the following types: float16, float32. +* @li x2: A Tensor. Must be one of the following types: float16, float32. +* @li x3: A Tensor. Must be one of the following types: float16, float32. +* @li x4: A Tensor. Must be one of the following types: float16, float32. +* @li x5: A Tensor. Must be one of the following types: float16, float32. +* @li greater_y: A Tensor. Must be one of the following types: float16, float32. +* @li select_e: A Tensor. Must be one of the following types: float16, float32. \n + +*@par Outputs: +*y: A Tensor of the same type as input. \n + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(LambUpdateWithLrV2) + .INPUT(x1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(x2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(x3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(x4, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(x5, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(greater_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(select_e, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(LambUpdateWithLrV2) + +/** +*@brief A fusion operator for bert lamb. 
\n + +*@par Inputs: +*Eleven inputs, including: +* @li input0: A Tensor. Must be one of the following types: float16, float32. +* @li input1: A Tensor. Must be one of the following types: float16, float32. +* @li input2: A Tensor. Must be one of the following types: float16, float32. +* @li input3: A Tensor. Must be one of the following types: float16, float32. +* @li input4: A Tensor. Must be one of the following types: float16, float32. +* @li mul0_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul1_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul3_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul4_x: A Tensor. Must be one of the following types: float16, float32. +* @li add2_y: A Tensor. Must be one of the following types: float16, float32. \n + +*@par Outputs: +*Three outputs, including: +* @li output0: A Tensor. Must be one of the following types: float16, float32. +* @li output1: A Tensor. Must be one of the following types: float16, float32. +* @li output2: A Tensor. Must be one of the following types: float16, float32. \n + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. 
+*/ +REG_OP(AdamApplyOneWithDecay) + .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input4, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul1_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul4_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(AdamApplyOneWithDecay) + +/** +*@brief A fusion operator for bert lamb. \n + +*@par Inputs: +*Ten inputs, including: +* @li input0: A Tensor. Must be one of the following types: float16, float32. +* @li input1: A Tensor. Must be one of the following types: float16, float32. +* @li input2: A Tensor. Must be one of the following types: float16, float32. +* @li input3: A Tensor. Must be one of the following types: float16, float32. +* @li input4: A Tensor. Must be one of the following types: float16, float32. +* @li mul0_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul1_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul3_x: A Tensor. Must be one of the following types: float16, float32. +* @li add2_y: A Tensor. Must be one of the following types: float16, float32. \n + +*@par Outputs: +*Three outputs, including: +* @li output0: A Tensor. Must be one of the following types: float16, float32. +* @li output1: A Tensor. Must be one of the following types: float16, float32. +* @li output2: A Tensor. Must be one of the following types: float16, float32. 
\n + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(AdamApplyOne) + .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input4, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul1_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(AdamApplyOne) + +/** +*@brief A fusion operator for bert lamb. \n + +*@par Inputs: +*Eleven inputs, including: +* @li input0: A Tensor. Must be one of the following types: float16, float32. +* @li input1: A Tensor. Must be one of the following types: float16, float32. +* @li input2: A Tensor. Must be one of the following types: float16, float32. +* @li input3: A Tensor. Must be one of the following types: float16, float32. +* @li input4: A Tensor. Must be one of the following types: float16, float32. +* @li mul0_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul1_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul3_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul4_x: A Tensor. Must be one of the following types: float16, float32. +* @li add2_y: A Tensor. Must be one of the following types: float16, float32. \n + +*@par Outputs: +*Three outputs, including: +* @li output0: A Tensor. Must be one of the following types: float16, float32. +* @li output1: A Tensor. 
Must be one of the following types: float16, float32. +* @li output2: A Tensor. Must be one of the following types: float16, float32. \n + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(AdamApplyOneWithDecayAssign) + .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input4, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul1_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul4_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(AdamApplyOneWithDecayAssign) + +/** +*@brief A fusion operator for bert lamb. \n + +*@par Inputs: +*Ten inputs, including: +* @li input0: A Tensor. Must be one of the following types: float16, float32. +* @li input1: A Tensor. Must be one of the following types: float16, float32. +* @li input2: A Tensor. Must be one of the following types: float16, float32. +* @li input3: A Tensor. Must be one of the following types: float16, float32. +* @li input4: A Tensor. Must be one of the following types: float16, float32. +* @li mul0_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul1_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul3_x: A Tensor. Must be one of the following types: float16, float32. +* @li add2_y: A Tensor. Must be one of the following types: float16, float32. 
\n
+
+*@par Outputs:
+*Three outputs, including:
+* @li output0: A Tensor. Must be one of the following types: float16, float32.
+* @li output1: A Tensor. Must be one of the following types: float16, float32.
+* @li output2: A Tensor. Must be one of the following types: float16, float32. \n
+
+*@par Restrictions:
+*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+*/
+REG_OP(AdamApplyOneAssign)
+    .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(input4, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(mul1_x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(output1, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(output2, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OP_END_FACTORY_REG(AdamApplyOneAssign)
+
+/**
+*@brief A fusion operator for bert lamb. \n
+
+*@par Inputs:
+*Twelve inputs, including:
+* @li grad: A Tensor. Must be one of the following types: float16, float32.
+* @li inputv: A Tensor. Must be one of the following types: float16, float32.
+* @li inputm: A Tensor. Must be one of the following types: float16, float32.
+* @li input3: A Tensor. Must be one of the following types: float16, float32.
+* @li mul0_x: A Tensor. Must be one of the following types: float16, float32.
+* @li mul1_x: A Tensor. Must be one of the following types: float16, float32.
+* @li mul2_x: A Tensor. Must be one of the following types: float16, float32.
+* @li mul3_x: A Tensor. Must be one of the following types: float16, float32.
+* @li add2_y: A Tensor. Must be one of the following types: float16, float32.
+* @li steps: A Tensor. Must be one of the following types: float16, float32.
+* @li do_use_weight: A Tensor. Must be one of the following types: float16, float32.
+* @li weight_decay_rate: A Tensor. Must be one of the following types: float16, float32. \n
+
+*@par Outputs:
+*Three outputs, including:
+* @li output0: A Tensor. Must be one of the following types: float16, float32.
+* @li inputv: A Tensor. Has the same type as the "inputv" input; updated in place.
+* @li inputm: A Tensor. Has the same type as the "inputm" input; updated in place. \n
+
+*@par Restrictions:
+*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+*/
+REG_OP(LambApplyOptimizerAssign)
+    .INPUT(grad, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(inputv, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(inputm, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(mul1_x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(steps, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(do_use_weight, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(weight_decay_rate, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(inputv, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(inputm, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OP_END_FACTORY_REG(LambApplyOptimizerAssign)
+
+/**
+*@brief A fusion operator for bert lamb. \n
+
+*@par Inputs:
+*Five inputs, including:
+* @li input0: A Tensor. Must be one of the following types: float16, float32.
+* @li input1: A Tensor. Must be one of the following types: float16, float32.
+* @li input2: A Tensor. Must be one of the following types: float16, float32.
+* @li input3: A Tensor. Must be one of the following types: float16, float32.
+* @li input_param: A Tensor. Must be one of the following types: float16, float32. \n
+
+*@par Outputs:
+*input_param: A Tensor. Has the same type as the "input_param" input; updated in place.
+*@par Restrictions:
+*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+*/
+REG_OP(LambApplyWeightAssign)
+    .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(input_param, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(input_param, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OP_END_FACTORY_REG(LambApplyWeightAssign)
+
+/**
+*@brief Confuse select, maximum, greater and sqrt. \n
+
+*@par Inputs:
+*Four inputs, including:
+* @li x: A Tensor. Must be one of the following types: float16, float32.
+* @li greater_zeros: A Tensor. Must be one of the following types: float16, float32.
+* @li select_ones: A Tensor. Must be one of the following types: float16, float32.
+* @li maximum_ones: A Tensor. Must be one of the following types: float16, float32. \n
+
+*@par Outputs:
+*y: A Tensor of the same type as "x". \n
+
+*@par Restrictions:
+*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
+*/ +REG_OP(ClipByNormNoDivSum) + .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(greater_zeros, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(select_ones, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(maximum_ones, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(ClipByNormNoDivSum) + +/** +*@brief Confuse reducesumd and square. \n + +*@par Inputs: +*x: A Tensor of type float16, float32. \n + +*@par Attributes: +* Two attributes, including: \n +*@li axis: A optional listint, specifies the dimensions to reduce. +*@li keep_dims: A bool, specifying whether to keep dimensions for the output Tensor. Defaults to "false". \n + +*@par Outputs: +*Two outputs, including: \n +*@li y1: A Tensor. Has the same type as "x". +*@li y2: A Tensor. Has the same type as "x". + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(SquareSumV2) + .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y2, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(axis, ListInt) + .ATTR(keep_dims, Bool, false) + .OP_END_FACTORY_REG(SquareSumV2) + +/** +*@brief Confuse reducesumd and square. \n + +*@par Inputs: +*x: A Tensor of type float16, float32. \n + +*@par Attributes: +* Two attributes, including: \n +*@li axis: A optional listint, specifies the dimensions to reduce. +*@li keep_dims: A bool, specifying whether to keep dimensions for the output Tensor. Defaults to "false". \n + +*@par Outputs: +y: A Tensor. Has the same type as "x". + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(SquareSumV1) + .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(axis, ListInt) + .ATTR(keep_dims, Bool, false) + .OP_END_FACTORY_REG(SquareSumV1) + +/** +*@brief Calculate square of Tensor and then reducesum + +*@par Inputs: +*x1: A Tensor of type float32. 
+*x2: A Tensor of type float32. \n + +*@par Outputs: +y1: A Tensor. Has the same type as "x1".The result of "x1". +y2: A Tensor. Has the same type as "x2".The result of "x2". + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(SquareSumAll) + .INPUT(x1, TensorType({DT_FLOAT})) + .INPUT(x2, TensorType({DT_FLOAT})) + .OUTPUT(y1, TensorType({DT_FLOAT})) + .OUTPUT(y2, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(SquareSumAll) + +/** +*@brief Confuse broadcast, addn and mul. \n + +*@par Inputs: +*Three inputs, including: +* @li x1: A Tensor. Must be one of the following types:int32, int16, float16, float32. +* @li x2: A Tensor of the same type as "x1". +* @li x3: A Tensor of the same type as "x1". \n + +*@par Outputs: +* y: A Tensor. Has the same type as "x1". + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(FusedMulAddN) + .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) + .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) + .INPUT(x3, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) + .OP_END_FACTORY_REG(FusedMulAddN) + +/** +*@brief Add 'bias' to 'x'. \n + +*@par Inputs: +* Two inputs, including: +*@li x: An ND tensor of type float16 or float32. +*@li bias: An ND tensor of type float16 or float32. \n + +*@par Attributes: +*@li axis: An optional int32 used to compute the shape of bias input from the online bottoms. Defaults to "1". +*@li num_axes: An optional int32 used to compute the shape of bias input from a Caffe model trained offline. Defaults to "1". +*@li bias_from_blob: An optional bool. If "true", bias is input from a Caffe model trained offline. If "false", bias is input from online bottoms. Defaults to "true". \n + +*@par Outputs: +*y: An ND tensor of type float16 or float32. 
\n + +*@attention Constraints:\n +* Assume that the shape length of "x" is "n" and that of "bias" is "m". +*@li "axis" is within the range [-n, n-1]. num_axes >= -1. +*@li If "bias_from_blob = true", "num_axes = -1", and "axis >= 0", the ith axis of "bias" and the (i+"axis")th axis of "x" must have the same size (0 <= i < n-axis).\n +* If "axis < 0", the ith axis of "bias" and the (i+n+"axis")th axis of "x" must have the same size (0 <= i < -axis). +*@li If "bias_from_blob = true" and "num_axes = 0", "bias" is a scalar with shape length 1 and dimension size 1. +*@li If "bias_from_blob = true", "num_axes > 0, and "axis >= 0", "axis + num_axes" must be less than or equal to "n" and the ith axis of "bias" and the (i+"axis")th axis of "x" must have the same size (0 <= i < num_axes).\n +* If "axis < 0", "n + axis + num_axes" must be less than or equal to "n" and the ith axis of "bias" and the (i+n+"axis")th axis of "x" must have the same size (0 <= i < num_axes). +*@li If "bias_from_blob = false", "bias" is not a scalar, and "axis >= 0","axis + m" must be less than or equal to "n" and the ith axis of "bias" and the (i+"axis")th axis of "x" must have the same size (0 <= i < m).\n +* If "axis < 0", "n + axis + m" must be less than or equal to "n" and the ith axis of "bias" and the (i+n+"axis")th axis of "x" must have the same size (0 <= i < m). +*@par Third-party framework compatibility +* Compatible with the Caffe operator Bias. +*/ + +REG_OP(Bias) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16})) /* "First operand." */ + .INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16})) /* "Second operand." */ + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) /* "Result, has same element type as x" */ + .ATTR(axis, Int, 1) + .ATTR(num_axes, Int, 1) + .ATTR(bias_from_blob, Bool, true) + .OP_END_FACTORY_REG(Bias) + +/** +*@brief Function multiply gradients calculation. +output0 is the result of which input0 dot multily input1. 
+output1 is the result of which input0 dot multily input1, then reducesum it. \n + +*@par Inputs: +*@li input0: A Tensor of input of mul, and dtype supports float16, float32. +*@li input1: A Tensor of input of mul and mul_1, and dtype supports float16, float32. +*@li input2: A Tensor of input of mul_1, and dtype supports float16, float32. \n + +*@par Attributes: +*@li axes: The dimensions to reduce. Default:(), reduce all dimensions. \n +Only constant value is allowed. +*@li keep_dims: If true, keep these reduced dimensions and the length is 1. \n +If false, don’t keep these dimensions. Default:False. \n + +*@par Outputs: +*@li output0: A Tensor result of which input0 dot multily input1. +*@li output1: A Tensor result of which input0 dot multily input1, then reducesum it. + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(ConfusionMulGrad) + .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output1, TensorType({DT_FLOAT16,DT_FLOAT})) + .ATTR(axes, ListInt, {}) + .ATTR(keep_dims, Bool, false) + .OP_END_FACTORY_REG(ConfusionMulGrad) + +/** +*@brief Function fused multiply l2 loss calculation. \n + +*@par Inputs: +*@li x1: A Tensor of type float16, float32. +*@li x2: A Tensor of type float16, float32. +*@li x3: A Tensor of type float16, float32. \n + +*@par Outputs: +*@li y1: A Tensor of shape and dtype of first output, which should have \n +shape (1,) and dtype as input. +*@li y2: A Tensor of shape and dtype of second output, should be same shape and type as input. + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. 
+*/
+REG_OP(FusedMulAddNL2loss)
+    .INPUT(x1, TensorType::NumberType())
+    .INPUT(x2, TensorType::NumberType())
+    .INPUT(x3, TensorType::NumberType())
+    .OUTPUT(y1, TensorType::NumberType())
+    .OUTPUT(y2, TensorType::NumberType())
+    .OP_END_FACTORY_REG(FusedMulAddNL2loss)
+
+/**
+*@brief Tests whether the input exceeds a threshold. \n
+
+*@par Inputs:
+*@li x: A Tensor with any format. Must be one of the following types: float16, float32. \n
+
+*@par Attributes:
+*@li threshold: An optional float32. Defaults to "0.0". "x" is compared with "threshold", outputs "1" for inputs above threshold; "0" otherwise. \n
+
+*@par Outputs:
+*@li y: A Tensor with any format. Has the same type as the input. Must be one of the following types: float16, float32.
+*@par Third-party framework compatibility
+* Compatible with the Caffe operator Threshold.
+*/
+
+REG_OP(Threshold)
+    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16}))
+    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16}))
+    .ATTR(threshold, Float, 0.0)
+    .OP_END_FACTORY_REG(Threshold)
+
+/**
+*@brief Returns the index number corresponding to the maximum value entered. \n
+
+*@par Inputs:
+*@li x: A tensor. Must be one of the following types: float16, float32. \n
+
+*@par Attributes:
+*@li axis: An optional int. Specify the axis to be cut at the input tensor. If this parameter is not provided, find the topk for each batch. Defaults to 10000
+*@li out_max_val: An optional bool. Whether to output the maximum value. If it is True, the maximum value and index are output, otherwise only the index is output.
+* Defaults to False
+*@li topk: An optional int. It means the number of top k in each axis (the value is greater than or equal to 1), and the value range must be in [1,x.shape(axis)].
+* Defaults to 1
+
+*@par Outputs:
+*@li indices: A tensor of type float16, float32, int32. The index of the maximum value of the output.
+*@li values: A tensor of type float16, float32. Output tensor, including maximum index or maximum value.
+*@par Third-party framework compatibility
+* Compatible with the Caffe operator ArgMax.
+*/
+REG_OP(ArgMaxWithK)
+    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16}))
+    .OUTPUT(indices, TensorType({DT_INT32, DT_FLOAT, DT_FLOAT16}))
+    .OUTPUT(values, TensorType({DT_FLOAT, DT_FLOAT16}))
+    .ATTR(axis, Int, 10000)
+    .ATTR(out_max_val, Bool, false)
+    .ATTR(topk, Int, 1)
+    .OP_END_FACTORY_REG(ArgMaxWithK)
+
+/**
+*@brief Multiply tensor with scale. \n
+
+*@par Inputs:
+*One input, including:
+* @li x: A Tensor. Must be one of the following types: int32, int16, float16, float32. \n
+
+*@par Attributes:
+*value: A required float, the scalar to multiply "x" by. \n
+
+*@par Outputs:
+*@li y: A Tensor. Has the same type and shape as "x". \n
+
+*@par Third-party framework compatibility:
+* Compatible with the Pytorch operator muls.
+*/
+REG_OP(Muls)
+    .INPUT(x, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16}))
+    .OUTPUT(y, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16}))
+    .REQUIRED_ATTR(value, Float)
+    .OP_END_FACTORY_REG(Muls)
+
+/**
+*@brief Fill tensor with scale. \n
+
+*@par Inputs:
+*One input, including:
+* @li x: A Tensor. Must be one of the following types: int32, int16, float16, float32. \n
+
+*@par Attributes:
+*value: A required float, the value to fill the output with. \n
+
+*@par Outputs:
+*@li y: A Tensor. Has the same type and shape as "x". \n
+
+*@par Third-party framework compatibility:
+* Compatible with the Pytorch operator fills.
+*/
+REG_OP(Fills)
+    .INPUT(x, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16}))
+    .OUTPUT(y, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16}))
+    .REQUIRED_ATTR(value,Float)
+    .OP_END_FACTORY_REG(Fills)
+
+/**
+*@brief Add tensor with scale. \n
+
+*@par Inputs:
+*One input, including:
+* @li x: A Tensor. Must be one of the following types: int32, int16, float16, float32. \n
+
+*@par Attributes:
+*value: A required float, the scalar to add to "x". \n
+
+*@par Outputs:
+*@li y: A Tensor. Has the same type and shape as "x". \n
+
+*@par Third-party framework compatibility:
+* Compatible with the Pytorch operator adds.
+*/ + REG_OP(Adds) + .INPUT(x, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16})) + .REQUIRED_ATTR(value,Float) + .OP_END_FACTORY_REG(Adds) + +/** +*@brief Computes the product of x and y and returns 0 if the y is zero, even if x is NaN or infinite. \n + +*@par Inputs: +* @li x1: A Tensor. Must be one of the following types:float16, float32, double, complex64, complex128. +* @li x2: A Tensor. Has the same type and shape as "x1". \n + +*@par Outputs: +*y: A Tensor. Has the same type and shape as "x1". \n + +*@par Third-party framework compatibility: +* Compatible with the TensorFlow operator MulNoNan. +*/ + REG_OP(MulNoNan) + .INPUT(x1, TensorType::NumberType()) /* "First operand." */ + .INPUT(x2, TensorType::NumberType()) /* "Second operand." */ + .OUTPUT(y, TensorType::NumberType()) /* "Result, has same element type as two inputs" */ + .OP_END_FACTORY_REG(MulNoNan) + +/** +*@brief Add tensor with scale. \n + +*@par Inputs: +* @li x1: A Tensor dtype of int32, float16, float32. +* @li x2: A Tensor dtype of int32, float16, float32. \n + +*@par Attributes: +*alpha: Float scalar apply to x2:x2*alpha + +*@par Outputs: +*y: A Tensor. should be same shape and type as "x1". \n + +*@par Third-party framework compatibility: +* Compatible with the Pytorch operator Axpy. +*/ +REG_OP(Axpy) + .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16})) + .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16})) + .REQUIRED_ATTR(alpha, Float) + .OP_END_FACTORY_REG(Axpy) + +/** +*@brief Creates a criterion that measures the loss given input tensors x1 x2 and a Tensor label y with values 1 or -1. \n + +*@par Inputs: +*@li x1: A ND Tensor with one of the following types: int8, uint8, int32, float16, float32. +*@li x2: A ND Tensor with one of the following types: int8, uint8, int32, float16, float32. 
+*@li target: A ND Tensor with one of the following types: int8, int32, float16, float32. \n + +*@par Attributes: +*@li margin: A optional float32. Defaults to "0.0". +*@li reduction: A optional string. Defaults to "mean". \n + +*@par Outputs: +*@li y: A ND Tensor with Must be float32. +*@par Third-party framework compatibility +* Compatible with the PyTorch operator CosineEmbeddingLoss. +*/ +REG_OP(CosineEmbeddingLoss) + .INPUT(x1, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(x2, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(target, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .ATTR(margin, Float, 0) + .ATTR(reduction, String, "mean") + .OUTPUT(y, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(CosineEmbeddingLoss) + +/** +*@brief Kullback-Leibler divergence. \n + +*@par Inputs: +* Two inputs, including: +*@li x: Tensor of arbitrary shape. +*@li target: Tensor of the same shape and dtype as x. \n + +*@par Attributes: +*reduction: An required "string", Specifies the reduction to apply to the output; +* Reduction only supports the two modes of "sum" and "batchmean". \n + +*@par Outputs: +*y: A ND Tensor of the same dtype as x. +*@par Third-party framework compatibility +*Compatible with the PyTorch operator kl_div. +*/ +REG_OP(KLDiv) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(target, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .REQUIRED_ATTR(reduction, String) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OP_END_FACTORY_REG(KLDiv) + +/** +*@brief copy data from x to y.. \n + +*@par Inputs: +*One inputs, including: +* @li x: A Tensor. Must be one of the following types: float16, float32, int8, uint8, int32, bool. \n + +*@par Outputs: +*y: A Tensor. Has the same type as "x". 
\n + +*@par Third-party framework compatibility + +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +*/ +REG_OP(TensorMove) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8, DT_BOOL})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8, DT_BOOL})) + .OP_END_FACTORY_REG(TensorMove) + +/** +*@brief copy data from x to x. \n + +*@par Inputs: +*One inputs, including: +* @li x: A Tensor. Must be one of the following types: float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64. \n + +*@par Outputs: +*output_x: A Tensor. Has the same type as "x". \n + +*@par Third-party framework compatibility +*/ +REG_OP(TensorRedirect) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT8, + DT_INT64, DT_INT16, DT_UINT16, DT_UINT64, DT_UINT32})) + .OUTPUT(output_x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT8, + DT_INT64, DT_INT16, DT_UINT16, DT_UINT64, DT_UINT32})) + .OP_END_FACTORY_REG(TensorRedirect) + +/** +* @brief Performs the element-wise division of tensor x2 by tensor x3, +* multiply the result by the scalar value and add it to tensor x1 + +* @par Inputs: +* Three inputs, including: +* @li input_data: A mutable input Tensor. Must be one of the following types: +* float16, float32. +* @li x1: A mutable input Tensor of the same type as x1. +* @li x2: A mutable input Tensor of the same type as x1. +* @li value: A mutable input Tensor. Must be one of the following types: +* float16, float32, int32. \n + +* @par Outputs: +* @li y: A mutable Tensor. Has the same type as "x1". \n + +* @par Third-party framework compatibility +* Compatible with the Pytorch operator Addcdiv. 
+*/ +REG_OP(Addcdiv) + .INPUT(input_data, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(value, TensorType({ DT_FLOAT16, DT_FLOAT, DT_INT32 })) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .OP_END_FACTORY_REG(Addcdiv) + +/** +* @brief Performs the element-wise multiplication of tensor x2 by tensor x3, +* multiply the result by the scalar value and add it to tensor input_data + + +* @par Inputs: +* Three inputs, including: +* @li input_data: A mutable input Tensor. Must be one of the following types: +* float16, float32, int8, int32, uint8. +* @li x1: A mutable input Tensor of the same type as x1. +* @li x2: A mutable input Tensor of the same type as x1. +* @li value: A tensor which includes only one element of the same type as x1. \n + +* @par Outputs: +* @li y: A mutable output Tensor. Has the same type as "x1". \n + +* @par Third-party framework compatibility +* Compatible with the Pytorch operator Addcmul. +*/ +REG_OP(Addcmul) + .INPUT(input_data, TensorType({ DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT8 })) + .INPUT(x1, TensorType({ DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT8 })) + .INPUT(x2, TensorType({ DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT8 })) + .INPUT(value, TensorType({ DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT8 })) + .OUTPUT(y, TensorType({ DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT8 })) + .OP_END_FACTORY_REG(Addcmul) + +/** +* @brief Computes the result of x2 * alpha + x1. + +* @par Inputs: +* @li x1: An ND tensor of type float16, float32, int32. +* @li x2: An ND tensor of type float16, float32, int32. +* @li alpha: A scalar tensor of type float16, float32. \n + +* @par Outputs: +* @li y: An ND tensor tensor with the same shape and type as "x1". \n + +* @par Third-party framework compatibility +* Compatible with the Pytorch operator Axpy. 
+*/ +REG_OP(AxpyV2) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .INPUT(alpha, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .OP_END_FACTORY_REG(AxpyV2) + +/** +* @brief Computes the result of x1 + x2. + +* @par Inputs: +* @li x1: An ND tensor of type float16, float, int32. +* @li x2: An ND tensor of type float16, float, int32. \n + +* @par Outputs: +* @li y: An ND tensor tensor with the same type as "x1". \n + +* @par Third-party framework compatibility +* Compatible with the Pytorch operator Add. +*/ +REG_OP(PtAdd) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .OP_END_FACTORY_REG(PtAdd) + +/** +* @brief Computes the result of x1 * x2. + +* @par Inputs: +* @li x1: An ND tensor of type float16, float32, int32. +* @li x2: An ND tensor of type float16, float32, int32. \n + +* @par Outputs: +* @li y: Same shape and type as the largest ND tensor in x1 x2. \n + +* @par Third-party framework compatibility +* Compatible with the Pytorch operator muls. +*/ +REG_OP(PtMuls) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .OP_END_FACTORY_REG(PtMuls) + +/** +* @brief Computes the result of x1 - x2. + +* @par Inputs: +* @li x1: An ND tensor of type float16, float, int32. +* @li x2: An ND tensor of type float16, float, int32. \n + +* @par Outputs: +* @li y: An ND tensor tensor with the same type as "x1". \n + +* @par Third-party framework compatibility +* Compatible with the Pytorch operator Sub. 
+*/ +REG_OP(PtSub) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .OP_END_FACTORY_REG(PtSub) + +/** +* @brief Add the partial values of two tensors in format NC1HWC0. + +* @par Inputs: +* @li x1: A Tensor in 5HD, and must be one of the following types: float16, +* float32. \n +* @li x2: A Tensor of the same type as "x1", and the same shape as "x1", +* except for the C1 value. \n + +* @par Attributes: +* @li x1_c1_offset: A required int. Offset value of C1 in "x1". \n +* @li x2_c1_offset: A required int. Offset value of C1 in "x2". \n +* @li c1_len: A required int. C1 len of "y". The value must be less than +* the difference between C1 and offset in "x1" and "x2". \n + +* @par Outputs: +* @li y: A Tensor of the same type as "x1", and the same shape as "x1", +* except for the C1 value. Record the result after adding. \n +*/ +REG_OP(StrideAdd) + .INPUT(x1, TensorType({ DT_FLOAT, DT_FLOAT16 })) + .INPUT(x2, TensorType({ DT_FLOAT, DT_FLOAT16 })) + .OUTPUT(y, TensorType({ DT_FLOAT, DT_FLOAT16 })) + .REQUIRED_ATTR(x1_c1_offset, Int) + .REQUIRED_ATTR(x2_c1_offset, Int) + .REQUIRED_ATTR(c1_len, Int) + .OP_END_FACTORY_REG(StrideAdd) + +/** +* @brief Compare two tensors are totally equal or not, only output a bool value" + +* @par Inputs: +* Two inputs, including: +* @li input_x: A Tensor. the first tensor. \n +* @li input_y: A Tensor. the second tensor. \n + +* @par Outputs: +* @li output_z: A Tensor. Bool type, compare result of the two inputs. \n + +* @par Third-party framework compatibility +* Compatible with the Pytorch equal operator. 
\n +*/ +REG_OP(TensorEqual) + .INPUT(input_x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) + .INPUT(input_y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) + .OUTPUT(output_z, TensorType({DT_BOOL})) + .OP_END_FACTORY_REG(TensorEqual) + +/** + * @brief Element-wise min of each of the input tensors (with Numpy-style broadcasting support). + * All inputs and outputs must have the same data type. This operator supports multidirectional + * (i.e., Numpy-style) broadcasting + * + * @par inputs + * one input including: + * @li x: dynamic input A Tensor. Must be one of the following types: float32, float16, double, int32, int64 + * + * @par output + * one output including: + * @li y:A Tensor of the same type as x + * + */ +REG_OP(MaxN) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_FLOAT64, DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_FLOAT64, DT_INT32, DT_INT64})) + .OP_END_FACTORY_REG(MaxN) + + +/** + * @brief Calculates x * maske * value. + * + * @par Inputs: + * @li x: An tensor of type float16 or float32, specifying the input to the data layer. + * @li mask: An tensor of type int8 or float16 or float32, be same shape with x. \n + * + * @par Attributes: + * value: A optional float. \n + * + * @par Outputs: + * y: The output tensor of type float16 or float32. + @ li y:A Tensor of the same type and shape as x + * + */ +REG_OP(MaskedScale) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT32})) + .INPUT(mask, TensorType({DT_INT8, DT_FLOAT16, DT_FLOAT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32})) + .REQUIRED_ATTR(value, Float) + .OP_END_FACTORY_REG(MaskedScale) + +/** + * @brief Calculate the lerp function. \n + + * @par Inputs: + * Three inputs, including: + * @li start: A tensor. Must be one of the following types: + * float16, float32. \n + * @li end: A tensor. Must be one of the following types: + * float16, float32. \n + * @li weight: A tensor. 
Must be one of the following types: + * float16, float32. \n + + * @par Outputs: + * y: A Tensor with the same type and shape of input_x's. \n + + * @par Third-party framework compatibility + * Compatible with the Pytorch operator Lerp. \n + */ +REG_OP(Lerp) + .INPUT(start, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(end, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(weight, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .OP_END_FACTORY_REG(Lerp) + +/** +*@brief Returns the num value of abs(x1-x2) > atol+rtol*abs(x2) element-wise. \n + +* +*@par Inputs: +*@li x1: A tensor. Must be one of the following types: float32, int32, uint8, int8, float16 +*@li x2: A tensor of the same type as "x1". +* +*@par Attributes: +* atol: Defaults to "1e-05". +* rtol: Defaults to "1e-03". +* +*@par Outputs: +* num: A tensor of type int32. +* diff: A tensor of type float16. +* +*@par Restrictions: +*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. +* +*/ +REG_OP(DataCompare) + .INPUT(x1, TensorType({ DT_FLOAT16, DT_FLOAT,DT_INT8, DT_UINT8, DT_INT32 })) + .INPUT(x2, TensorType({ DT_FLOAT16, DT_FLOAT,DT_INT8, DT_UINT8, DT_INT32 })) + .OUTPUT(num, TensorType({DT_FLOAT})) + .OUTPUT(diff, TensorType({DT_FLOAT16})) + .ATTR(atol, Float, 1e-5) + .ATTR(rtol, Float, 1e-3) + .OP_END_FACTORY_REG(DataCompare) + +/** +*@brief Hardmax(element in input, axis) = 1 if the element is the first maximum value along the specified axis, 0 +*otherwise The input does not need to explicitly be a 2D vector.The "axis" attribute indicates the dimension along +*which Hardmax will be performed.The output tensor has the same shape and contains the Hardmax values of the +*corresponding input. 
+* +*@par inputs +*one input including: +*@li x: input A Tensor.Must be one of the following types:float32,float16 +* +*@par Attributes: +*@li axis:A required int attribute that decides which dimension will be used to cal the hard_max +* +*@par output: +*one output including: +*@li y:A Tensor of the same type as x +* +*/ +REG_OP(HardMax) + .INPUT(x, TensorType({ DT_FLOAT16, DT_FLOAT })) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .ATTR(axis, Int, -1) + .OP_END_FACTORY_REG(HardMax) + +/** +* @brief Computes the dot product (inner product) of two tensors. This function does not broadcast. + +* @par Inputs: +* Two inputs, including: +* @li input_x: A Tensor. the first tensor must be 1d. \n +* @li input_y: A Tensor. the second tensor must be 1d. \n + +* @par Outputs: +* @li output: A Tensor. Result of the two inputs, must be 1d. \n + +* @par Third-party framework compatibility +* Compatible with the Pytorch dot operator. \n +*/ +REG_OP(Dot) + .INPUT(input_x, TensorType({DT_FLOAT, DT_FLOAT16, DT_UINT8, DT_INT8, DT_INT32})) + .INPUT(input_y, TensorType({DT_FLOAT, DT_FLOAT16, DT_UINT8, DT_INT8, DT_INT32})) + .OUTPUT(output, TensorType({DT_FLOAT, DT_FLOAT16, DT_UINT8, DT_INT8, DT_INT32})) + .OP_END_FACTORY_REG(Dot) + +/** +*@brief Returns a new tensor with boolean elements representing \n +*if each element of input is “close” to the corresponding element of other \n + +*@par Inputs: +*Two inputs, including: +* @li x1: A tensor. Must be one of the following types: +* float16, float32, int32. \n +* @li x2: A tensor with the same type and shape of x1's. \n + +*@par Attributes: +*@li rtol: An optional float.Defaults to 1e-05. \n +*@li atol: An optional float.Defaults to 1e-08. \n +*@li equal_nan: An optional bool.Defaults to false. \n + +*@par Outputs: +*y: A Tensor bool with the same shape of x1's. \n + +*@par Third-party framework compatibility +*Compatible with the Pytorch operator isclose. 
\n +*/ +REG_OP(IsClose) + .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .INPUT(x2, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .OUTPUT(y, TensorType({DT_BOOL})) + .ATTR(rtol, Float, 1e-05) + .ATTR(atol, Float, 1e-08) + .ATTR(equal_nan, Bool, false) + .OP_END_FACTORY_REG(IsClose) + +/** +* @brief Returns the reverse tensor of the ArgMax operator of a tensor. \n + +* @par Inputs: +* three input, including: +* var: A Tensor of type float16, float32, int32 or int8. \n +* indices: A Tensor of type int32. \n +* updates: A Tensor of type float16, float32, int32 or int8. \n + +* @par Attributes: +* @li dimension: An integer of type int, specifying the axis information of the index with the maximum value.\n + +* @par Outputs: +* y: A Tensor of type float16, float32, int32 or int8. \n +* +*@attention Constraints: +*@li indices: only support int32,and shape same to "updates" +*@li The value range of "dimension" is [-dims, dims - 1]. "dims" is the dimension length of "x". +*@li y:A Tensor, the type and shape is same to "var" \n + +*@par Third-party framework compatibility +* not support all scene like pytorch operator scatter +* exp: +* var.shape=[2,3,4,5], dim=2, the shape of indices and updates should be [2,3,5] +* not support the shape of indices and updates is [2,3,2,5] like pytorch operator scatter. \n +*/ +REG_OP(ArgMaxGrad) + .INPUT(var, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) + .INPUT(indices, TensorType({DT_INT32})) + .INPUT(updates, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) + .REQUIRED_ATTR(dimension, Int) + .OP_END_FACTORY_REG(ArgMaxGrad) + +/** +* @brief Returns the reverse tensor of the ArgMax operator of a tensor. \n + +* @par Inputs: +* three input, including: +* var: A Tensor of type float16, float32, int32 or int8. \n +* indices: A Tensor of type int32. \n +* updates: A Tensor of type float16, float32, int32 or int8. 
\n +* assist: A Tensor of int32,also a assist matrix and it's shape must match the shape of var \n + +* @par Attributes: +* @li dimension: An integer of type int, specifying the axis information of the index with the maximum value.\n + +* @par Outputs: +* y: A Tensor of type float16, float32, int32 or int8. \n + +*@attention Constraints: +*@li indices: only support int32,and shape same to "updates" +*@li The value range of "dimension" is [-dims, dims - 1]. "dims" is the dimension length of "x". +*@li y:A Tensor, the type and shape is same to "var" \n + +*@par Third-party framework compatibility +* not support all scene like pytorch operator scatter +* exp: +* var.shape=[2,3,4,5], dim=2, the shape of indices and updates should be [2,3,5] +* not support the shape of indices and updates is [2,3,2,5] like pytorch operator scatter. \n +*/ +REG_OP(ArgMaxGradD) + .INPUT(var, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) + .INPUT(indices, TensorType({DT_INT32})) + .INPUT(updates, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) + .INPUT(assist, TensorType({DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8})) + .REQUIRED_ATTR(dimension, Int) + .OP_END_FACTORY_REG(ArgMaxGradD) + +} // namespace ge + +#endif // OPS_BUILT_IN_OP_PROTO_INC_ELEWISE_CALCULATION_OPS_H_ diff --git a/tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc b/tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc new file mode 100644 index 00000000..3658c729 --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc @@ -0,0 +1,234 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file array_ops_shape_fns.cpp + * \brief + */ +#include "array_ops_shape_fns.h" +#include "graph/types.h" +#include "op_log.h" +#include "error_util.h" +#include "common_shape_fns.h" +#include "axis_util.h" + +namespace ge { +static graphStatus PadKnown(Operator& op, const Tensor& paddings_tensor, const int64_t input_dim_num) { + TensorDesc paddings_tensor_desc = paddings_tensor.GetTensorDesc(); + DataType data_type = paddings_tensor_desc.GetDataType(); + std::vector data; + // every dim has 2 element + int64_t element_num = input_dim_num * 2; + data.reserve(element_num); + if (data_type == DT_INT32) { + const int32_t* paddings_data = reinterpret_cast(paddings_tensor.GetData()); + CHECK(paddings_tensor.GetSize() / sizeof(int32_t) < element_num, + OP_LOGE(op.GetName().c_str(), "invalid padding data."), return GRAPH_FAILED); + for (int64_t i = 0; i < element_num; ++i) { + data.push_back(static_cast(paddings_data[i])); + } + } else if (data_type == DT_INT64) { + const int64_t* paddings_data = reinterpret_cast(paddings_tensor.GetData()); + CHECK(paddings_tensor.GetSize() / sizeof(int64_t) < element_num, + OP_LOGE(op.GetName().c_str(), "invalid padding data."), return GRAPH_FAILED); + for (int64_t i = 0; i < element_num; ++i) { + data.push_back(paddings_data[i]); + } + } else { + string err_msg = ConcatString("paddings data type invalid, ", "should be DT_INT32 or DT_INT64"); + InferShapeOtherErrReport(op.GetName(), err_msg); + OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); + return GRAPH_FAILED; + } + auto dims = 
op.GetInputDesc(0).GetShape().GetDims(); + std::vector output_dims(input_dim_num, UNKNOWN_DIM); + if (dims != UNKNOWN_SHAPE) { + output_dims.assign(dims.begin(), dims.end()); + } + for (size_t i = 0; i < data.size(); i += 2) { + if ((data[i] < 0) || (data[i + 1] < 0)) { + std::string err_msg = ConcatString("paddings", DebugString(data), " must be non-negative"); + InferShapeOtherErrReport(op.GetName(), err_msg); + OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); + return GRAPH_FAILED; + } + graphStatus status = Add(output_dims[i / 2], data[i] + data[i + 1], output_dims[i / 2]); + if (status != GRAPH_SUCCESS) { + std::string err_msg = ConcatString("the sum input[0] shape", DebugString(dims), " and input[1] value", + DebugString(data), " must be non-negative"); + OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); + return GRAPH_FAILED; + } + } + auto output_desc = op.GetOutputDesc("y"); + output_desc.SetShape(Shape(output_dims)); + + return op.UpdateOutputDesc("y", output_desc); +} + +graphStatus PadShapeFn(Operator& op) { + Shape paddings; + int64_t input_dim_num; + graphStatus status = WithRank(op.GetInputDesc(1), 2, paddings, op.GetName().c_str()); + if (status != GRAPH_SUCCESS) { + ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), "2D"); + return GRAPH_FAILED; + } + status = WithValue(paddings.GetDim(1), 2, input_dim_num, op.GetName().c_str()); + if (status != GRAPH_SUCCESS) { + ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), + ConcatString(2, " of dim[1]")); + return GRAPH_FAILED; + } + Shape input; + int64_t dim0 = paddings.GetDim(0); + if (dim0 != UNKNOWN_DIM) { + status = WithRank(op.GetInputDesc(0), dim0, input, op.GetName().c_str()); + if (status != GRAPH_SUCCESS) { + ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(dim0, "D")); + return GRAPH_FAILED; + } + } else if (op.GetInputDesc(0).GetShape().GetDim(0) != 0) { + status = 
WithValue(dim0, op.GetInputDesc(0).GetShape().GetDimNum(), input_dim_num, op.GetName().c_str()); + if (status != GRAPH_SUCCESS) { + ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(dim0, "D")); + return GRAPH_FAILED; + } + } + TensorDesc output_desc = op.GetOutputDesc("y"); + Tensor paddings_tensor; + status = op.GetInputConstData("paddings", paddings_tensor); + if (status != GRAPH_SUCCESS) { + if (dim0 != UNKNOWN_DIM) { + std::vector output_shape(dim0, UNKNOWN_DIM); + output_desc.SetShape(Shape(output_shape)); + } else { + output_desc.SetShape(Shape(UNKNOWN_SHAPE)); + } + return op.UpdateOutputDesc("y", output_desc); + } + input_dim_num = paddings_tensor.GetTensorDesc().GetShape().GetDim(0); + status = WithRank(op.GetInputDesc(0), input_dim_num, input, op.GetName().c_str()); + if (status == GRAPH_FAILED) { + OP_LOGE(op.GetName().c_str(), "WithRank fail"); + return GRAPH_FAILED; + } + status = WithValue(dim0, input_dim_num, dim0, op.GetName().c_str()); + if (status == GRAPH_FAILED) { + OP_LOGE(op.GetName().c_str(), "WithValue fail"); + return GRAPH_FAILED; + } + return PadKnown(op, paddings_tensor, input_dim_num); +} + +static graphStatus CalcPadGradOutDims(const Shape& input_shape, const Tensor& paddings_tensor, + std::vector& output_dims, const char* op_name) { + graphStatus status; + size_t input_rank = input_shape.GetDimNum(); + if (output_dims.size() < input_rank) { + return GRAPH_FAILED; + } + DataType padding_type = paddings_tensor.GetTensorDesc().GetDataType(); + if (padding_type == DT_INT32) { + const int32_t* paddings_data = reinterpret_cast(paddings_tensor.GetData()); + CHECK(paddings_tensor.GetSize() / sizeof(int32_t) < input_rank, + OP_LOGE(op_name, "invalid padding data."), return GRAPH_FAILED); + for (size_t i = 0; i < input_rank; ++i) { + const int64_t pad0 = static_cast(paddings_data[2 * i]); + const int64_t pad1 = static_cast(paddings_data[(2 * i) + 1]); + if ((pad0 < 0) || (pad1 < 0)) { + 
OP_LOGE(op_name, "Paddings must be non-negative, pad0= %lld, pad1=%lld.", pad0, pad1); + return GRAPH_FAILED; + } + status = Subtract(input_shape.GetDim(i), pad0 + pad1, output_dims[i], op_name); + if (status != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + } else if (padding_type == DT_INT64) { + const int64_t* paddings_data = reinterpret_cast(paddings_tensor.GetData()); + CHECK(paddings_tensor.GetSize() / sizeof(int64_t) < input_rank, + OP_LOGE(op_name, "invalid padding data."), return GRAPH_FAILED); + for (size_t i = 0; i < input_rank; ++i) { + const int64_t pad0 = paddings_data[2 * i]; + const int64_t pad1 = paddings_data[(2 * i) + 1]; + if ((pad0 < 0) || (pad1 < 0)) { + OP_LOGE(op_name, "Paddings must be non-negative, pad0=%lld, pad1=%lld.", pad0, pad1); + return GRAPH_FAILED; + } + status = Subtract(input_shape.GetDim(i), pad0 + pad1, output_dims[i], op_name); + if (status != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + } + } else { + OP_LOGE(op_name, "Data type invalid, should be DT_INT32 or DT_INT64"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus PadGradShapeFn(Operator& op) { + Shape paddings; + graphStatus status = WithRank(op.GetInputDesc(1), 2, paddings, op.GetName().c_str()); + if (status != GRAPH_SUCCESS) { + ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), "2D"); + return GRAPH_FAILED; + } + int64_t input_rank = paddings.GetDim(0); + TensorDesc output_desc = op.GetOutputDesc("y"); + output_desc.SetDataType(op.GetInputDesc(0).GetDataType()); + if (input_rank == UNKNOWN_DIM) { + OP_LOGE(op.GetName().c_str(), "paddings inputShape of 0 dims is unknown, set out shape unknown."); + output_desc.SetShape(Shape(UNKNOWN_SHAPE)); + return op.UpdateOutputDesc("y", output_desc); + } + + Shape input_shape; + if (WithRank(op.GetInputDesc(0), input_rank, input_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { + ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), 
ConcatString(input_rank)); + return GRAPH_FAILED; + } + + Shape check_shape({input_rank, 2}); + if (Merge(paddings, check_shape, paddings, op.GetName().c_str())) { + string err_msg = ConcatString("merge 1th input shape", DebugString(paddings.GetDims()), " and shape", + DebugString(check_shape.GetDims()), " failed"); + InferShapeOtherErrReport(op.GetName(), err_msg); + OP_LOGE(op.GetName().c_str(), "Input dimension mismatch, inputRank=%lld.", input_rank); + return GRAPH_FAILED; + } + + Tensor paddings_tensor; + if (op.GetInputConstData("paddings", paddings_tensor) != GRAPH_SUCCESS) { + std::vector unknow_dim_vec(input_rank, UNKNOWN_DIM); + OP_LOGE(op.GetName().c_str(), "Get paddings input tensor fail, set outPut shape unknown."); + output_desc.SetShape(Shape(unknow_dim_vec)); + return op.UpdateOutputDesc("y", output_desc); + } + + std::vector output_dims(input_rank); + auto result = CalcPadGradOutDims(input_shape, paddings_tensor, output_dims, op.GetName().c_str()); + if (result != GRAPH_SUCCESS) { + string err_msg = ConcatString("calculate out dims failed,", "please check the validity of input and attribute"); + InferShapeOtherErrReport(op.GetName(), err_msg); + OP_LOGE(op.GetName().c_str(), "Calculation PadGrad out dimensions failed."); + return GRAPH_FAILED; + } + output_desc.SetShape(Shape(output_dims)); + return op.UpdateOutputDesc("y", output_desc); +} +} // namespace ge diff --git a/tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h b/tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h new file mode 100644 index 00000000..9855ffeb --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file array_ops_shape_fns.h + * \brief + */ +#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ +#define OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ + +#include "graph/operator.h" + +namespace ge { +/* * + * infer pad op shape + * @param op Operator which need to infershape + * @return status whether infershape success + */ +graphStatus PadShapeFn(Operator& op); + +/* * + * infer pad grad op shape + * @param op Operator which need to infershape + * @return status whether infershape success + */ +graphStatus PadGradShapeFn(Operator& op); +} // namespace ge + +#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ diff --git a/tests/st/framework/stub_op_proto/util/axis_util.cc b/tests/st/framework/stub_op_proto/util/axis_util.cc new file mode 100644 index 00000000..7d10aa31 --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/axis_util.cc @@ -0,0 +1,195 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
+ * \file axis_util.cpp + * \brief get the axis value + */ +#include "axis_util.h" +#include "framework/omg/omg_inner_types.h" +#include "framework/common/types.h" + +namespace ge { +AxisUtil::AxisUtil() { + getAxisValueFuncMap = {{FORMAT_NCHW, std::make_shared(GetAxisValueByNCHW)}, + {FORMAT_NHWC, std::make_shared(GetAxisValueByNHWC)}, + {FORMAT_NC1HWC0, std::make_shared(GetAxisValueByNC1HWC0)}, + {FORMAT_HWCN, std::make_shared(GetAxisValueByHWCN)}, + {FORMAT_ND, std::make_shared(GetAxisValueByND)}, + {FORMAT_C1HWNCoC0, std::make_shared(GetAxisValueByC1HWNCoC0)}}; +} + +int64_t DivisionCeiling(int64_t dividend, int64_t divisor) { + if (divisor == 0) { + return 0; + } else { + return (dividend + divisor - 1) / divisor; + } +} + +bool AxisUtil::GetAxisValueByOriginFormat(const Format& format, const vector& dimVec, const uint32_t& c0, + vector& axisValue, vector& ndValue) { + auto iterGetAxisFunc = getAxisValueFuncMap.find(format); + if (iterGetAxisFunc == getAxisValueFuncMap.end()) { + LOG_INFO("Can not get axis value of old format %u!", format); + return false; + } + GetAxisValueInfoByFormatPtr getAxisFunc = iterGetAxisFunc->second; + CHECK_NOTNULL(getAxisFunc); + return (*getAxisFunc)(dimVec, c0, axisValue, ndValue); +} + +bool AxisUtil::HasAxisValueFunc(const Format& format) { + auto iterGetAxisFunc = getAxisValueFuncMap.find(format); + if (iterGetAxisFunc == getAxisValueFuncMap.end()) { + LOG_INFO("Can not get axis value of format %u!", format); + return false; + } + return true; +} + +bool AxisUtil::CheckParams(const vector& originalDimVec, const uint32_t& c0, vector& axisValue, + vector& ndValue) { + ndValue = originalDimVec; + auto dimSize = originalDimVec.size(); + if (dimSize < ge::DIM_DEFAULT_SIZE) { + /* Before this funcion, we should call function PadDimensionTo4. 
*/ + LOG_INFO("Dimension size %zu is invalid.", dimSize); + return false; + } + if (c0 == 0) { + LOG_ERROR("[ERROR]c0 is zero!"); + return false; + } + + return true; +} + +bool AxisUtil::GetAxisValueByND(const vector& originalDimVec, const uint32_t& c0, vector& axisValue, + vector& ndValue) { + CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); + ndValue = originalDimVec; + /* To differentiate the input datatype of int8 and others */ + axisValue[AXIS_C0] = c0; + if (originalDimVec.size() == NCHW_DIMENSION_NUM) { + axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; + axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; + axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; + axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); + axisValue[AXIS_Co] = c0; + } + return true; +} + +bool AxisUtil::GetAxisValueByNCHW(const vector& originalDimVec, const uint32_t& c0, vector& axisValue, + vector& ndValue) { + CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); + /* C0 Must be set for case ND or 2D-NCHW to NZ */ + axisValue[AXIS_C0] = c0; + CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), + return false); + + axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; + axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; + axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; + axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); + axisValue[AXIS_Co] = c0; + return true; +} + +bool AxisUtil::GetAxisValueByNHWC(const vector& originalDimVec, const uint32_t& c0, vector& axisValue, + vector& ndValue) { + CHECK(axisValue.empty(), 
LOG_INFO("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); + /* C0 Must be set for case ND or 2D-NHWC to NZ */ + axisValue[AXIS_C0] = c0; + CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), + return false); + + axisValue[AXIS_N] = originalDimVec[AXIS_NHWC_DIM_N]; + axisValue[AXIS_C] = originalDimVec[AXIS_NHWC_DIM_C]; + axisValue[AXIS_H] = originalDimVec[AXIS_NHWC_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_NHWC_DIM_W]; + axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NHWC_DIM_C], (int64_t)c0); + axisValue[AXIS_Co] = c0; + return true; +} + +bool AxisUtil::GetAxisValueByNC1HWC0(const vector& originalDimVec, const uint32_t& c0, + vector& axisValue, vector& ndValue) { + CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); + CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), + return false); + + auto dimSize = originalDimVec.size(); + if (dimSize == ge::DIM_DEFAULT_SIZE + 1) { + axisValue[AXIS_C1] = originalDimVec[AXIS_NC1HWC0_DIM_C1]; + axisValue[AXIS_C0] = originalDimVec[AXIS_NC1HWC0_DIM_C0]; + axisValue[AXIS_C] = axisValue[AXIS_C1] * axisValue[AXIS_C0]; + } else { + axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); + axisValue[AXIS_C0] = c0; + axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; + } + + axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; + axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; + return true; +} + +bool AxisUtil::GetAxisValueByHWCN(const vector& originalDimVec, const uint32_t& c0, vector& axisValue, + vector& ndValue) { + CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), 
LOG_INFO("Original dim vector is empty!"), return true); + /* C0 Must be set for case ND or 2D-NHWC to NZ */ + axisValue[AXIS_C0] = c0; + CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), + return false); + + axisValue[AXIS_N] = originalDimVec[AXIS_HWCN_DIM_N]; + axisValue[AXIS_C] = originalDimVec[AXIS_HWCN_DIM_C]; + axisValue[AXIS_H] = originalDimVec[AXIS_HWCN_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_HWCN_DIM_W]; + axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_HWCN_DIM_C], (int64_t)c0); + axisValue[AXIS_Co] = c0; + return true; +} + +bool AxisUtil::GetAxisValueByC1HWNCoC0(const vector& originalDimVec, const uint32_t& c0, + vector& axisValue, vector& ndValue) { + CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); + CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); + /* C0 Must be set for case ND or 2D-NHWC to NZ */ + axisValue[AXIS_C0] = c0; + CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), + return false); + + axisValue[AXIS_N] = originalDimVec[AXIS_C1HWNCoC0_DIM_N]; + axisValue[AXIS_C] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1] * c0; + axisValue[AXIS_H] = originalDimVec[AXIS_C1HWNCoC0_DIM_H]; + axisValue[AXIS_W] = originalDimVec[AXIS_C1HWNCoC0_DIM_W]; + axisValue[AXIS_C1] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1]; + axisValue[AXIS_Co] = originalDimVec[AXIS_C1HWNCoC0_DIM_Co]; + return true; +} + +}; // namespace ge diff --git a/tests/st/framework/stub_op_proto/util/axis_util.h b/tests/st/framework/stub_op_proto/util/axis_util.h new file mode 100644 index 00000000..ce7beb07 --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/axis_util.h @@ -0,0 +1,144 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file axis_util.h + * \brief get the axis value + */ +#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_ +#define OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_ + +#include +#include +#include + +#include "framework/omg/omg_inner_types.h" +#include "operator.h" +#include "graph/operator_reg.h" + +#include "op_log.h" + +#define LOG_ERROR(format, args...) printf(format, ##args) +#define LOG_INFO(format, args...) printf(format, ##args) +namespace ge { +const uint32_t NCHW_DIMENSION_NUM = 4; + +const int32_t AXIS_NCHW_DIM_N = 0; +const int32_t AXIS_NCHW_DIM_C = 1; +const int32_t AXIS_NCHW_DIM_H = 2; +const int32_t AXIS_NCHW_DIM_W = 3; + +const int32_t AXIS_NHWC_DIM_N = 0; +const int32_t AXIS_NHWC_DIM_H = 1; +const int32_t AXIS_NHWC_DIM_W = 2; +const int32_t AXIS_NHWC_DIM_C = 3; + +const int32_t AXIS_NC1HWC0_DIM_N = 0; +const int32_t AXIS_NC1HWC0_DIM_C1 = 1; +const int32_t AXIS_NC1HWC0_DIM_C0 = 4; +const int32_t AXIS_NC1HWC0_DIM_H = 2; +const int32_t AXIS_NC1HWC0_DIM_W = 3; + +const int32_t AXIS_HWCN_DIM_H = 0; +const int32_t AXIS_HWCN_DIM_W = 1; +const int32_t AXIS_HWCN_DIM_C = 2; +const int32_t AXIS_HWCN_DIM_N = 3; + +const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0; +const int32_t AXIS_C1HWNCoC0_DIM_H = 1; +const int32_t AXIS_C1HWNCoC0_DIM_W = 2; +const int32_t AXIS_C1HWNCoC0_DIM_N = 3; +const int32_t AXIS_C1HWNCoC0_DIM_Co = 4; +const int32_t AXIS_C1HWNCoC0_DIM_C0 = 5; + +#define CHECK_NOTNULL(val) \ + do { \ + if ((val) == nullptr) { \ + LOG_ERROR("[ERROR]Parameter[%s] must not be null.", #val); \ + return false; \ + } \ + } while (0) + +#define 
CHECK(cond, log_func, return_expr) \ + do { \ + if (cond) { \ + log_func; \ + return_expr; \ + } \ + } while (0) + +enum AxisValueType { + AXIS_N = 0, + AXIS_C = 1, + AXIS_H = 2, + AXIS_W = 3, + AXIS_C1 = 4, + AXIS_C0 = 5, + AXIS_Co = 6, + AXIS_D = 7, + AXIS_BOTTOM = 8 +}; + +int64_t DivisionCeiling(int64_t dividend, int64_t divisor); + +/* Axis value is arranged as {N,C,H,W,C1,C0,...} */ +/* The first parameter is old shape's dimension, + * second is c0 and third is axis value. */ +using GetAxisValueInfoByFormat = + std::function&, const uint32_t&, std::vector&, std::vector&)>; + +using GetAxisValueInfoByFormatPtr = std::shared_ptr; + +class AxisUtil { + public: + AxisUtil(); + ~AxisUtil(){}; + bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector& dimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + bool HasAxisValueFunc(const ge::Format& format); + + private: + static bool CheckParams(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByNCHW(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByNHWC(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByNC1HWC0(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByFz(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByHWCN(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByND(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& axisValue, std::vector& ndValue); + + static bool GetAxisValueByC1HWNCoC0(const std::vector& originalDimVec, const uint32_t& c0, + std::vector& 
axisValue, std::vector& ndValue); + + /* map of GetAxisValueInfoByFormat, get axis value by different original + * formats. */ + std::map getAxisValueFuncMap; +}; +} // namespace ge + +#endif // OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_ diff --git a/tests/st/framework/stub_op_proto/util/common_shape_fns.cc b/tests/st/framework/stub_op_proto/util/common_shape_fns.cc new file mode 100644 index 00000000..052173f8 --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/common_shape_fns.cc @@ -0,0 +1,1038 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
+ * \file common_shape_fns.cpp + * \brief + */ +#include "common_shape_fns.h" +#include +#include +#include "op_log.h" +#include "graph/utils/op_desc_utils.h" +#include "common/util/error_manager/error_manager.h" + +namespace ge { +graphStatus WithRankAtLeast(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name) { + if (rank > INT32_MAX) { + OP_LOGE(op_name, "Rank cannot exceed kint32max"); + return GRAPH_FAILED; + } + Shape s = tensor.GetShape(); + std::vector dims = s.GetDims(); + // dim.size() convert to be type int64_t can't overflow + int64_t size = static_cast(dims.size()); + if (!((size >= rank) || (dims == UNKNOWN_SHAPE))) { + OP_LOGE(op_name, "Shape's rank must be at least %lld", rank); + return GRAPH_FAILED; + } + out = s; + return GRAPH_SUCCESS; +} + +graphStatus WithRankAtLeast(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape) { + if (rank > INT32_MAX) { + OP_LOGE("", "Rank cannot exceed kint32max"); + return GRAPH_FAILED; + } + + GeShape s = tensorDesc->GetShape(); + std::vector dims = s.GetDims(); + // dim.size() convert to be type int64_t can't overflow + int64_t size = static_cast(dims.size()); + + if ((dims != UNKNOWN_RANK) && (size < rank)) { + OP_LOGE("", "Shape's rank must be at least %lld, current=%lld", rank, size); + return GRAPH_FAILED; + } + out_shape = s; + return GRAPH_SUCCESS; +} + +graphStatus WithRank(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name) { + if (rank > INT32_MAX) { + OP_LOGE(op_name, "Rank cannot exceed int32max"); + return GRAPH_FAILED; + } + Shape s = tensor.GetShape(); + int64_t existing = static_cast(s.GetDimNum()); + + if (s.GetDims() == UNKNOWN_RANK) { + std::vector out_shape(rank, UNKNOWN_DIM); + out = Shape(out_shape); + return GRAPH_SUCCESS; + } + + if (existing != rank) { + OP_LOGE(op_name, "Shape must be rank %lld", rank); + return GRAPH_FAILED; + } + out = s; + return GRAPH_SUCCESS; +} + +graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t 
rank, GeShape& out_shape) { + if (rank > INT32_MAX) { + OP_LOGE("", "Rank cannot exceed int32max"); + return GRAPH_FAILED; + } + + GeShape s = tensorDesc->GetShape(); + int64_t existing = static_cast(s.GetDimNum()); + if (s.GetDims() == UNKNOWN_RANK) { + std::vector out_dims(rank, UNKNOWN_DIM); + out_shape = GeShape(out_dims); + return GRAPH_SUCCESS; + } + + if (existing != rank) { + OP_LOGE("", "Shape must be rank %lld, current=%lld", rank, existing); + return GRAPH_FAILED; + } + out_shape = s; + return GRAPH_SUCCESS; +} + +graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, Shape& out_shape) { + if (rank > INT32_MAX) { + OP_LOGE("", "Rank cannot exceed int32max"); + return GRAPH_FAILED; + } + + GeShape s = tensorDesc->GetShape(); + int64_t existing = static_cast(s.GetDimNum()); + if (s.GetDims() == UNKNOWN_RANK) { + std::vector out_dims(rank, UNKNOWN_DIM); + out_shape = Shape(out_dims); + return GRAPH_SUCCESS; + } + + if (existing != rank) { + OP_LOGE("", "Shape must be rank %lld, current=%lld", rank, existing); + return GRAPH_FAILED; + } + out_shape = Shape(s.GetDims()); + return GRAPH_SUCCESS; +} + +graphStatus WithValue(int64_t dim, int64_t value, int64_t& out, const char* op_name) { + out = value; + if (dim == UNKNOWN_DIM) { + return GRAPH_SUCCESS; + } + + if (dim != value) { + OP_LOGE(op_name, "Dim and value are not equal: %lld != %lld.", dim, value); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus Merge(int64_t dim1, int64_t dim2, int64_t& out) { + if (dim1 == dim2) { + out = dim1; + return GRAPH_SUCCESS; + } else if (dim2 == UNKNOWN_DIM) { + out = dim1; + return GRAPH_SUCCESS; + } else if (dim1 == UNKNOWN_DIM) { + out = dim2; + return GRAPH_SUCCESS; + } + return GRAPH_FAILED; +} + +graphStatus Merge(const Shape& s0, const Shape& s1, Shape& out, const char* op_name) { + // Same shape and unknown rank + if (s0.GetDims() == s1.GetDims()) { + out = s0; + return GRAPH_SUCCESS; + } else if (!RankKnown(s1)) { + out = s0; + 
return GRAPH_SUCCESS; + } else if (!RankKnown(s0)) { + out = s1; + return GRAPH_SUCCESS; + } + + const size_t rank = s0.GetDimNum(); + if (s1.GetDimNum() != rank) { + OP_LOGE(op_name, "Dimension number of two shapes are not equal: %llu != %llu.", rank, s1.GetDimNum()); + return GRAPH_FAILED; + } + + // Check if each dims equal + bool return_s0 = true; + bool return_s1 = true; + for (size_t i = 0; i < rank; i++) { + int64_t d0 = s0.GetDim(i); + int64_t d1 = s1.GetDim(i); + if (d0 == UNKNOWN_DIM) { + if (d1 != UNKNOWN_DIM) { + return_s0 = false; + } + } else if (d1 == UNKNOWN_DIM) { + return_s1 = false; + } else if (d0 != d1) { + OP_LOGE(op_name, "Dims %llu are not equal.", rank); + return GRAPH_FAILED; + } + } + + if (return_s0 || return_s1) { + out = return_s0 ? s0 : s1; + return GRAPH_SUCCESS; + } + + // Merge dims + std::vector dims(rank, 0); + for (size_t i = 0; i < rank; ++i) { + // Invariant for merge was checked earlier, so CHECK is ok. + if (Merge(s0.GetDim(i), s1.GetDim(i), dims[i]) == GRAPH_FAILED) { + OP_LOGE(op_name, "Failed to merge dims in rank %llu.", i); + return GRAPH_FAILED; + } + } + + out = Shape(dims); + return GRAPH_SUCCESS; +} + +graphStatus Merge(const GeShape& s0, const GeShape& s1, GeShape& out, const char* op_name) { + // Same shape and unknown rank + if (s0.GetDims() == s1.GetDims()) { + out = s0; + return GRAPH_SUCCESS; + } else if (!RankKnown(s1)) { + out = s0; + return GRAPH_SUCCESS; + } else if (!RankKnown(s0)) { + out = s1; + return GRAPH_SUCCESS; + } + + const size_t rank = s0.GetDimNum(); + if (s1.GetDimNum() != rank) { + OP_LOGE(op_name, "Dimension number of two shapes are not equal: %llu != %llu.", rank, s1.GetDimNum()); + return GRAPH_FAILED; + } + + // Check if each dims equal + bool return_s0 = true; + bool return_s1 = true; + for (size_t i = 0; i < rank; i++) { + int64_t d0 = s0.GetDim(i); + int64_t d1 = s1.GetDim(i); + if (d0 == UNKNOWN_DIM) { + if (d1 != UNKNOWN_DIM) { + return_s0 = false; + } + } else if (d1 == 
UNKNOWN_DIM) { + return_s1 = false; + } else if (d0 != d1) { + OP_LOGE(op_name, "Dims %llu are not equal.", rank); + return GRAPH_FAILED; + } + } + + if (return_s0 || return_s1) { + out = return_s0 ? s0 : s1; + return GRAPH_SUCCESS; + } + + // Merge dims + std::vector dims(rank, 0); + for (size_t i = 0; i < rank; ++i) { + // Invariant for merge was checked earlier, so CHECK is ok. + if (Merge(s0.GetDim(i), s1.GetDim(i), dims[i]) == GRAPH_FAILED) { + OP_LOGE(op_name, "Failed to merge dims in rank %llu.", i); + return GRAPH_FAILED; + } + } + + out = GeShape(dims); + return GRAPH_SUCCESS; +} + +graphStatus ReplaceDim(const Shape& s, int64_t dim_index_in, int64_t new_dim, Shape& out, const char* op_name) { + if (!RankKnown(s)) { + out = Shape(ge::UNKNOWN_SHAPE); + return GRAPH_SUCCESS; + } + int64_t dim_index = dim_index_in; + if (dim_index < 0) { + dim_index = (int64_t)s.GetDimNum() + dim_index; + } + if (!FastBoundsCheck(dim_index, s.GetDimNum())) { + out = Shape(); + OP_LOGE(op_name, "Out of range dim_index %ld for shape with %d dimensions", dim_index_in, s.GetDimNum()); + return GRAPH_FAILED; + } + std::vector dims = s.GetDims(); + dims[dim_index] = new_dim; + out = Shape(dims); + return GRAPH_SUCCESS; +} + +graphStatus ReplaceDim(const GeShape& s, int64_t dim_index_in, int64_t new_dim, GeShape& out, const char* op_name) { + if (!RankKnown(s)) { + out = GeShape(UNKNOWN_RANK); + return GRAPH_SUCCESS; + } + int64_t dim_index = dim_index_in; + if (dim_index < 0) { + dim_index = (int64_t)s.GetDimNum() + dim_index; + } + if (!FastBoundsCheck(dim_index, s.GetDimNum())) { + out = GeShape(); + OP_LOGE(op_name, "Out of range dim_index %ld for shape with %d dimensions", dim_index_in, s.GetDimNum()); + return GRAPH_FAILED; + } + std::vector dims = s.GetDims(); + dims[dim_index] = new_dim; + out = GeShape(dims); + return GRAPH_SUCCESS; +} + +template +bool FastBoundsCheck(const Ta index, const Tb limit) { + static_assert(std::is_integral::value && std::is_integral::value, + 
"FastBoundsCheck can only be used on integer types."); + typedef typename std::make_unsigned::type UIndex; + return static_cast(index) < static_cast(limit); +} + +graphStatus Add(int64_t dim1, int64_t dim2, int64_t& out) { + if (dim1 == 0) { + out = dim2; + } else if (dim2 == 0) { + out = dim1; + } else if ((dim1 == UNKNOWN_DIM) || (dim2 == UNKNOWN_DIM)) { + out = UNKNOWN_DIM; + } else { + const int64_t sum = dim1 + dim2; + if (sum < 0) { + return GRAPH_FAILED; + } + out = sum; + } + return GRAPH_SUCCESS; +} + +graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t& out, const char* op_name) { + if (dim2 == 0) { + out = dim1; + } else if ((dim1 == UNKNOWN_DIM) || (dim2 == UNKNOWN_DIM)) { + out = UNKNOWN_DIM; + } else { + if (dim1 < dim2) { + OP_LOGE(op_name, "Negative dimension size caused by subtracting, dim1=%ld, dim2=%ld", dim1, dim2); + return GRAPH_FAILED; + } + out = dim1 - dim2; + } + return GRAPH_SUCCESS; +} + +graphStatus SubShape(const Shape& s, int64_t start, int64_t end, int64_t stride, Shape& out, const char* op_name) { + if (s.GetDimNum() > INT32_MAX) { + OP_LOGE(op_name, "shape rank cannot exceed kint32max"); + return GRAPH_FAILED; + } + const int64_t rank = static_cast(s.GetDimNum()); + TensorDesc tensor(s); + if (start == 0 && ((tensor.GetRealDimCnt() != -1 && end >= rank) || end == std::numeric_limits::max())) { + out = s; + return GRAPH_SUCCESS; + } + + if (start > rank) { + start = rank; + } + if (end > rank) { + end = rank; + } + + if (stride < 0 && start == rank) { + --start; + } + + if (start < 0) { + start = rank + start; + if (start < 0) { + OP_LOGE(op_name, "Subshape start out of bounds must be at least 0"); + return GRAPH_FAILED; + } + } + + if (end < 0) { + end = rank + end; + if (end < 0) { + OP_LOGE(op_name, "Subshape end out of bounds must be at least 0"); + return GRAPH_FAILED; + } + } + + if (!((stride <= 0 || start <= end))) { + OP_LOGE(op_name, "Subshape must have computed start <= end"); + return GRAPH_FAILED; + } + if 
(!(stride >= 0 || start >= end)) { + OP_LOGE(op_name, "Subshape must have computed start >= end since stride is negative"); + return GRAPH_FAILED; + } + std::vector dims; + for (int64_t i = start; stride > 0 ? i < end : i > end; i += stride) { + dims.push_back(s.GetDim(i)); + } + Shape tmp(dims); + out = tmp; + return GRAPH_SUCCESS; +} + +graphStatus SubShape(const GeShape& src_shape, + int64_t start, + int64_t end, + int64_t stride, + GeShape& out_shape, + const char* op_name) { + int64_t src_rank = src_shape.GetDimNum(); + if (src_rank > static_cast(std::numeric_limits::max())) { + OP_LOGE(op_name, "shape rank cannot exceed kint32max, got rank %lld", src_rank); + return GRAPH_FAILED; + } + + if (start == 0 && stride == 1 && + ((RankKnown(src_shape) && end >= src_rank) || + (end == std::numeric_limits::max()))) { + out_shape = src_shape; + return GRAPH_SUCCESS; + } + + if (start > src_rank) { + start = src_rank; + } + + if (end > src_rank) { + end = src_rank; + } + + if (stride < 0 && start == src_rank) { + --start; + } + + if (start < 0) { + start += src_rank; + if (start < 0) { + OP_LOGE(op_name, "Subshape start %lld out of bounds must be at least 0", + start); + return GRAPH_FAILED; + } + } + + if (end < 0) { + end += src_rank; + if (end < 0) { + OP_LOGE(op_name, "Subshape end %lld out of bounds must be at least 0", + end); + return GRAPH_FAILED; + } + } + + if (stride > 0 && start > end) { + OP_LOGE(op_name, "Subshape must have computed start=%lld <= end=%lld since stride=%lld is positive", + start, end, stride); + return GRAPH_FAILED; + } else if (stride < 0 && start < end) { + OP_LOGE(op_name, "Subshape must have computed start=%lld >= end=%lld since stride=%lld is negative", + start, end, stride); + return GRAPH_FAILED; + } + + std::vector out_dims; + for (int64_t i = start; (stride > 0 ? 
i < end : i > end); i += stride) { + out_dims.push_back(src_shape.GetDim(i)); + } + out_shape = GeShape(out_dims); + return GRAPH_SUCCESS; +} + +graphStatus Concatenate(const Shape& s1, const Shape& s2, Shape& out) { + if (!RankKnown(s1) || !RankKnown(s2)) { + out = Shape(ge::UNKNOWN_RANK); + return GRAPH_SUCCESS; + } + size_t s1_rank = s1.GetDimNum(); + size_t s2_rank = s2.GetDimNum(); + size_t rank = s1_rank + s2_rank; + std::vector dims; + dims.reserve(rank); + for (size_t i = 0; i < s1_rank; ++i) { + dims.push_back(s1.GetDim(i)); + } + for (size_t i = 0; i < s2_rank; ++i) { + dims.push_back(s2.GetDim(i)); + } + Shape s(dims); + out = s; + return GRAPH_SUCCESS; +} + +graphStatus Concatenate(const GeShape& s1, const GeShape& s2, GeShape& out) { + if (!RankKnown(s1) || !RankKnown(s2)) { + out = GeShape(ge::UNKNOWN_RANK); + return GRAPH_SUCCESS; + } + const int64_t s1_rank = s1.GetDimNum(); + const int64_t s2_rank = s2.GetDimNum(); + const int64_t out_rank = s1_rank + s2_rank; + std::vector out_dims; + out_dims.reserve(out_rank); + for (int64_t i = 0; i < s1_rank; ++i) { + out_dims.push_back(s1.GetDim(i)); + } + for (int64_t i = 0; i < s2_rank; ++i) { + out_dims.push_back(s2.GetDim(i)); + } + out = GeShape(out_dims); + return GRAPH_SUCCESS; +} + +graphStatus Matrix(int64_t dim1, int64_t dim2, Shape& out) { + std::vector dims; + dims.reserve(2); + dims.push_back(dim1); + dims.push_back(dim2); + Shape s(dims); + out = s; + return GRAPH_SUCCESS; +} + +graphStatus Vector(int64_t dim, Shape& out) { + std::vector dims; + dims.reserve(1); + dims.push_back(dim); + Shape s(dims); + out = s; + return GRAPH_SUCCESS; +} + +static graphStatus GetShapeDataFromShapeTensor(Operator& op, + const string& dst_name, + int64_t rank, + std::vector& data, + const char* op_name) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + auto shape_data_desc = op_desc->MutableInputDesc(dst_name); + + std::vector input_infer_depends = {dst_name}; + 
op_desc->SetOpInferDepends(input_infer_depends); + + GeShape shape_data_shape(shape_data_desc->GetShape()); + std::vector dims = shape_data_shape.GetDims(); + DataType data_type = shape_data_desc->GetDataType(); + if (dims.size() != static_cast(rank)) { + OP_LOGE(op_name, "Shape's rank must be [%u], but it is [%u]", + rank, dims.size()); + std::string info = "Shape's rank must be " + std::to_string(rank) + + ", but it is " + std::to_string(dims.size()) + "."; + InferShapeErrorReport(op_name, "shape_data_desc", "shape's rank", info); + return GRAPH_FAILED; + } + int64_t dim_value = ((rank > 0) && (dims[0] > 0)) ? dims[0] : 1; + data.clear(); + data.reserve(dim_value); + Tensor shape_tensor; + if (data_type == DT_INT32) { + if (op.GetInputConstData(dst_name, shape_tensor) == GRAPH_SUCCESS) { + const int32_t* shape_data = + reinterpret_cast(shape_tensor.GetData()); + for (int64_t i = 0; i < dim_value; i++) { + data.push_back(static_cast(shape_data[i])); + } + } else { + OP_LOGI(op.GetName().c_str(), "Input [%s] is not a const tensor.", + dst_name.c_str()); + for (int64_t i = 0; i < dim_value; i++) { + data.push_back(UNKNOWN_DIM); + } + } + } else if (data_type == DT_INT64) { + if (op.GetInputConstData(dst_name, shape_tensor) == GRAPH_SUCCESS) { + const int64_t* shape_data = + reinterpret_cast(shape_tensor.GetData()); + for (int64_t i = 0; i < dim_value; i++) { + data.push_back(static_cast(shape_data[i])); + } + } else { + OP_LOGI(op.GetName().c_str(), "Input [%s] is not a const tensor.", + dst_name.c_str()); + for (int64_t i = 0; i < dim_value; i++) { + data.push_back(UNKNOWN_DIM); + } + } + } else { + OP_LOGE(op_name, "Data type invalid, should be DT_INT32 or DT_INT64"); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +static graphStatus GetShapeDataFromConstData(const Tensor& tensor, int64_t rank, std::vector& data, + const char* op_name) { + TensorDesc shape_data_desc = tensor.GetTensorDesc(); + Shape shape_data_shape = shape_data_desc.GetShape(); + 
std::vector dims = shape_data_shape.GetDims(); + DataType data_type = shape_data_desc.GetDataType(); + + if (dims.size() != static_cast(rank)) { + OP_LOGE(op_name, "Shape's rank must be [%u], but it is [%u]", rank, dims.size()); + std::string info = "Shape's rank must be " + std::to_string(rank) + + ", but it is " + std::to_string(dims.size()) + "."; + InferShapeErrorReport(op_name, "shape_data_desc", "shape's rank", info); + return GRAPH_FAILED; + } + int64_t dim_value = rank > 0 ? dims[0] : 1; + data.clear(); + data.reserve(dim_value); + if (data_type == DT_INT32) { + const int32_t* shape_data = reinterpret_cast(tensor.GetData()); + for (int64_t i = 0; i < dim_value; i++) { + data.push_back(static_cast(shape_data[i])); + } + } else if (data_type == DT_INT64) { + const int64_t* shape_data = reinterpret_cast(tensor.GetData()); + for (int64_t i = 0; i < dim_value; i++) { + data.push_back(shape_data[i]); + } + } else { + OP_LOGE(op_name, "Data type invalid, should be DT_INT32 or DT_INT64"); + return GRAPH_FAILED; + } + + return GRAPH_SUCCESS; +} + +graphStatus MakeShapeFromShapeTensor(const Tensor& tensor, Shape& out, const char* op_name) { + std::vector shape_data; + GetShapeDataFromConstData(tensor, 1, shape_data, op_name); + out = Shape(shape_data); + return GRAPH_SUCCESS; +} + +graphStatus MakeShapeFromShapeTensor(Operator& op, const string& dst_name, GeShape& out, const char* op_name) { + std::vector shape_data; + GetShapeDataFromShapeTensor(op, dst_name, 1, shape_data, op_name); + out = GeShape(shape_data); + return GRAPH_SUCCESS; +} + +graphStatus MakeDimForScalarInput(const Tensor& tensor, int64_t& out, const char* op_name) { + std::vector shape_data; + GetShapeDataFromConstData(tensor, 0, shape_data, op_name); + out = shape_data[0]; + return GRAPH_SUCCESS; +} + +graphStatus WithRankAtMost(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name) { + if (rank > INT32_MAX) { + OP_LOGE(op_name, "Rank cannot exceed kint32max"); + return 
GRAPH_FAILED; + } + Shape s = tensor.GetShape(); + std::vector dims = s.GetDims(); + if (!((dims.size() <= static_cast(rank)) || (dims == ge::UNKNOWN_SHAPE))) { + OP_LOGE(op_name, "Shape's rank must be at most %lld, but it is %u", rank, dims.size()); + return GRAPH_FAILED; + } + out = s; + return GRAPH_SUCCESS; +} + +graphStatus WithRankAtMost(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape) { + if (rank > INT32_MAX) { + OP_LOGE("", "Rank cannot exceed kint32max"); + return GRAPH_FAILED; + } + + GeShape s = tensorDesc->GetShape(); + std::vector dims = s.GetDims(); + if ((dims != ge::UNKNOWN_RANK) && (dims.size() > static_cast(rank))) { + OP_LOGE("", "Shape's rank must be at most %lld, but it is %zu", rank, dims.size()); + return GRAPH_FAILED; + } + + out_shape = s; + return GRAPH_SUCCESS; +} + +graphStatus Scalar(Shape& out) { + std::vector dims = {}; + Shape s(dims); + out = s; + return GRAPH_SUCCESS; +} + +graphStatus UnchangedShape(Operator& op, const string input_name, const string output_name) { + TensorDesc desc = op.GetOutputDesc(output_name); + desc.SetShape(op.GetInputDesc(input_name).GetShape()); + return op.UpdateOutputDesc(output_name, desc); +} + +graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t& out, + const char* op_name) { + if (divisor == 1) { + out = dividend; + } else if ((dividend == ge::UNKNOWN_DIM) || (divisor == ge::UNKNOWN_DIM)) { + out = ge::UNKNOWN_DIM; + } else { + if (divisor <= 0) { + OP_LOGE(op_name, "Devide's divisor must be positive, but it is %lld", divisor); + return GRAPH_FAILED; + } + if (!((!evenlyDivisible) || (dividend % divisor) == 0)) { + OP_LOGE(op_name, "Dimension size must be evenly divisible by %lld, but is %lld", divisor, dividend); + return GRAPH_FAILED; + } + out = dividend / divisor; + } + return GRAPH_SUCCESS; +} + +bool ShapeFullDefined(const Shape& shape) { + if (!RankKnown(shape)) { + return false; + } + std::vector dims = shape.GetDims(); + 
+ for (const auto& dim : dims) { + if (dim == ge::UNKNOWN_DIM) { + return false; + } + } + return true; +} + +bool ShapeFullyDefined(const GeShape& shape) { + if (!RankKnown(shape)) { + return false; + } + + std::vector dims = shape.GetDims(); + for (const int64_t& dim : dims) { + if (dim == ge::UNKNOWN_DIM) { + return false; + } + } + + return true; +} + +bool RankKnown(const Shape& shape) { + std::vector dims = shape.GetDims(); + if (dims == ge::UNKNOWN_RANK) { + return false; + } + return true; +} + +bool RankKnown(const GeShape& shape) { + std::vector dims = shape.GetDims(); + if (dims == ge::UNKNOWN_RANK) { + return false; + } + return true; +} + +Shape UnknownShapeOfRank(int64_t rank) { + std::vector dims(rank); + for (int64_t i = 0; i < rank; ++i) { + dims[i] = ge::UNKNOWN_DIM; + } + return Shape(dims); +} + +bool ValueKnown(const Shape& shape, const size_t& dim_index) { + if (shape.GetDims() == ge::UNKNOWN_SHAPE) { + return false; + } + if (dim_index >= shape.GetDims().size()) { + return false; + } + if (shape.GetDim(dim_index) == ge::UNKNOWN_DIM) { + return false; + } + + return true; +} + +graphStatus ValidateSparseTensor(const TensorDesc& indices, const TensorDesc& values, const TensorDesc& shape, + const char* op_name) { + // Validate ranks + Shape unused_shape; + if (WithRank(indices, 2, unused_shape, op_name) != GRAPH_SUCCESS) { + OP_LOGE(op_name, "ValidateSparseTensor indices rank must be 2."); + return GRAPH_FAILED; + } + if (WithRank(values, 1, unused_shape, op_name) != GRAPH_SUCCESS) { + OP_LOGE(op_name, "ValidateSparseTensor values rank must be 1."); + return GRAPH_FAILED; + } + if (WithRank(shape, 1, unused_shape, op_name) != GRAPH_SUCCESS) { + OP_LOGE(op_name, "ValidateSparseTensor shape rank must be 1."); + return GRAPH_FAILED; + } + + // Number of elements in indices and values must match + Shape indices_shape = indices.GetShape(); + Shape values_shape = values.GetShape(); + if (ValueKnown(indices_shape, 0)) { + if (ValueKnown(values_shape, 
0)) { + if (indices_shape.GetDim(0) != values_shape.GetDim(0)) { + OP_LOGE(op_name, "Number of elements in index and values do not match."); + return GRAPH_FAILED; + } + } + } + + // Rank embedded in indices must match shape. + Shape sparse_shape = shape.GetShape(); + if (ValueKnown(indices_shape, 1)) { + if (ValueKnown(sparse_shape, 0)) { + if (indices_shape.GetDim(1) != sparse_shape.GetDim(0)) { + OP_LOGE(op_name, "Index rank and shape rank do not match."); + return GRAPH_FAILED; + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus DecodeWavShapeFn(Operator& op) { + Shape unused_shape; + if (WithRank(op.GetInputDesc(0), 0, unused_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "input must be scalar."); + return GRAPH_FAILED; + } + + int64_t channels_dim = 0; + int32_t desired_channels = 0; + if (op.GetAttr("desired_channels", desired_channels) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "GetAttr desired_channels error."); + return GRAPH_FAILED; + } + if (desired_channels == -1) { + channels_dim = ge::UNKNOWN_DIM; + } else { + if (desired_channels < 0) { + OP_LOGE(op.GetName().c_str(), "channels must be non-negative."); + return GRAPH_FAILED; + } + + channels_dim = static_cast(desired_channels); + } + int64_t samples_dim; + int32_t desired_samples; + if (op.GetAttr("desired_samples", desired_samples) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "GetAttr desired_samples error."); + return GRAPH_FAILED; + } + if (desired_samples == -1) { + samples_dim = ge::UNKNOWN_DIM; + } else { + if (desired_samples < 0) { + OP_LOGE(op.GetName().c_str(), "samples must be non-negative."); + return GRAPH_FAILED; + } + samples_dim = static_cast(desired_samples); + } + + Shape audio_shape({samples_dim, channels_dim}); + Shape sample_rate_shape; + (void)Scalar(sample_rate_shape); + TensorDesc audio_tensor = op.GetOutputDesc("audio"); + audio_tensor.SetDataType(DT_FLOAT); + audio_tensor.SetShape(audio_shape); + 
(void)op.UpdateOutputDesc("audio", audio_tensor); + TensorDesc sample_rate_tensor = op.GetOutputDesc("sample_rate"); + sample_rate_tensor.SetDataType(DT_INT32); + sample_rate_tensor.SetShape(sample_rate_shape); + return op.UpdateOutputDesc("sample_rate", sample_rate_tensor); +} + +graphStatus EncodeWavShapeFn(Operator& op) { + Shape unused_shape; + if (WithRank(op.GetInputDesc(0), 2, unused_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "Input audio must be rank 2."); + return GRAPH_FAILED; + } + if (WithRank(op.GetInputDesc(1), 0, unused_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "Input sample_rate must be scalar."); + return GRAPH_FAILED; + } + + Shape output_shape; + (void)Scalar(output_shape); + TensorDesc contents_tensor = op.GetOutputDesc("contents"); + contents_tensor.SetDataType(DT_STRING); + contents_tensor.SetShape(output_shape); + return op.UpdateOutputDesc("contents", contents_tensor); +} + +graphStatus SparseSegmentReductionShapeFn(Operator& op) { + Shape x_shape; + if (WithRankAtLeast(op.GetInputDesc(0), 1, x_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "input x should be at least 1-D."); + return GRAPH_FAILED; + } + Shape indices_shape; + if (WithRank(op.GetInputDesc(1), 1, indices_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "input indices must be 1-D."); + return GRAPH_FAILED; + } + Shape segment_ids_shape; + if (WithRank(op.GetInputDesc(2), 1, segment_ids_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "input segment_ids must be 1-D."); + return GRAPH_FAILED; + } + Shape unused; + if (Merge(indices_shape, segment_ids_shape, unused, op.GetName().c_str()) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + Shape subshape; + if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, op.GetName().c_str()) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + Shape out; + Shape 
unknown_dim_shape({ge::UNKNOWN_DIM}); + if (Concatenate(unknown_dim_shape, subshape, out) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + TensorDesc out_desc = op.GetOutputDesc(0); + out_desc.SetDataType(op.GetInputDesc(0).GetDataType()); + out_desc.SetShape(out); + if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "update y failed"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus SparseSegmentReductionGradShapeFn(Operator& op) { + Shape x_shape; + if (WithRankAtLeast(op.GetInputDesc(0), 1, x_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "input x should be at least 1-D."); + return GRAPH_FAILED; + } + Shape indices_shape; + if (WithRank(op.GetInputDesc(1), 1, indices_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "input indices must be 1-D."); + return GRAPH_FAILED; + } + Shape unused; + Shape segment_ids_shape = op.GetInputDesc(2).GetShape(); + if (Merge(segment_ids_shape, indices_shape, unused, op.GetName().c_str()) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + if (WithRank(op.GetInputDesc(3), 0, unused, op.GetName().c_str()) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "input output_dim0 must be scalar."); + return GRAPH_FAILED; + } + Shape subshape; + if (SubShape(x_shape, 1, x_shape.GetDimNum(), 1, subshape, op.GetName().c_str()) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + Tensor dims0_tensor; + Shape dim0_shape; + op.GetInputConstData("output_dim0", dims0_tensor); + const uint8_t* dims0 = dims0_tensor.GetData(); + const int32_t* dims0_data = reinterpret_cast(dims0); + if (*dims0_data < 0) { + OP_LOGE(op.GetName().c_str(), "Cannot specify a negative value for output_dim0."); + return GRAPH_FAILED; + } + dim0_shape = Shape({*dims0_data}); + + Shape out; + if (Concatenate(dim0_shape, subshape, out) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + TensorDesc out_desc = op.GetOutputDesc(0); + 
out_desc.SetDataType(op.GetInputDesc(0).GetDataType()); + out_desc.SetShape(out); + if (op.UpdateOutputDesc("y", out_desc) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "update y failed"); + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + +graphStatus ValidateVariableResourceHandle(Operator& op, std::vector& shape_and_type) { + auto input_handle = op.GetInferenceContext()->GetInputHandleShapesAndTypes(); + if (input_handle.empty()) { + Shape unknown_shape(ge::UNKNOWN_SHAPE); + ShapeAndType shape_and_type(unknown_shape, DT_UNDEFINED); + std::vector handle_shapes_and_types; + handle_shapes_and_types.reserve(1); + handle_shapes_and_types.emplace_back(shape_and_type); + input_handle.emplace_back(handle_shapes_and_types); + } else { + shape_and_type = input_handle[0]; + DataType value_type; + if (op.GetAttr("dtype", value_type) != GRAPH_SUCCESS) { + OP_LOGE(op.GetName().c_str(), "GetAttr dtype failed."); + return GRAPH_FAILED; + } + if (shape_and_type[0].GetDataType() != value_type) { + OP_LOGE(op.GetName().c_str(), "ValidateVariableResourceHandle read variable with wrong dtype"); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +void FillOpDesc(GeTensorDescPtr& op_desc, const GeShape& shape, const DataType& data_type) { + if (RankKnown(shape)) { + auto dims = shape.GetDims(); + bool shape_fully_defined = true; + for (const int64_t& dim : dims) { + if (dim == UNKNOWN_DIM) { + shape_fully_defined = false; + break; + } + } + if (!shape_fully_defined) { + std::vector> shape_range; + for (const int64_t& dim : dims) { + shape_range.push_back(dim == UNKNOWN_DIM ? 
std::pair{1, -1} : + std::pair{dim, dim}); + } + op_desc->SetShapeRange(shape_range); + } + } + op_desc->SetShape(shape); + op_desc->SetDataType(data_type); +} + +void InferShapeErrorReport(const std::string& op_name, const std::string& op_type, const std::string& value, + const std::string& reason) { + std::string report_error_code = "E14001"; + ErrorManager::GetInstance().ATCReportErrMessage(report_error_code, {"opname", "optype", "value", "reason"}, + {op_name, op_type, value, reason}); +} + +} // namespace ge diff --git a/tests/st/framework/stub_op_proto/util/common_shape_fns.h b/tests/st/framework/stub_op_proto/util/common_shape_fns.h new file mode 100644 index 00000000..bdb18197 --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/common_shape_fns.h @@ -0,0 +1,417 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
+ * \file common_shape_fns.h + * \brief + */ +#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ +#define OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ + +#include +#include +#include "graph/tensor.h" +#include "graph/operator.h" +#include "graph/op_desc.h" +#include "graph/ge_tensor.h" +#include "error_code.h" + +namespace ge { + +/** + * Check whether Shape's rank is at least rank + * @param tensor Input tensor + * @param rank expect val of Shape + * @param out Output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRankAtLeast(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name); + +/** + * Check whether Shape's rank is at least rank + * @param tensor Input tensor + * @param rank expect val of Shape + * @param out Output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRankAtLeast(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape); + +/** + * Check whether Shape's rank is equal to rank + * @param tensor Input tensor + * @param rank expect val of Shape + * @param out Output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRank(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name); + +/** + * Check whether Shape's rank is equal to rank + * @param tensor Input tensor + * @param rank expect val of Shape + * @param out Output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape); + +/** + * Check whether Shape's rank is equal to rank + * @param tensor Input tensor + * @param rank expect val of Shape + * @param out Output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, Shape& out_shape); + +/** + * Check whether dim is equal to value + * @param dim Input dim + * @param value expect val of dim + * @param out Output dim 
+ * @return status whether Dim is equal to value + */ +graphStatus WithValue(int64_t dim, int64_t value, int64_t& out, const char* op_name); + +/** + * Merge two dims of Shape + * @param dim0 first dim val + * @param dim1 second dim val + * @param out merged dim val + * @return status whether this operation success + */ +graphStatus Merge(int64_t dim1, int64_t dim2, int64_t& out); + +/** + * Merge two shapes + * @param s0 first shape val + * @param s1 second shape val + * @param out merged shape val + * @return status whether this operation success + */ +graphStatus Merge(const Shape& s0, const Shape& s1, Shape& out, const char* op_name); + +/** + * Merge two shapes + * @param s0 first Geshape val + * @param s1 second Geshape val + * @param out merged Geshape val + * @return status whether this operation success + */ +graphStatus Merge(const GeShape& s0, const GeShape& s1, GeShape& out, const char* op_name); + +/** + * Replace one dim in a given shape + * @param s original shape + * @param dim_index_in dim index + * @param new_dim new dim value + * @param out new shape + * @return status whether this operation success + */ +graphStatus ReplaceDim(const Shape& s, int64_t dim_index_in, int64_t new_dim, Shape& out, const char* op_name); + +/** + * Replace one dim in a given shape + * @param s original shape + * @param dim_index_in dim index + * @param new_dim new dim value + * @param out new shape + * @return status whether this operation success + */ +graphStatus ReplaceDim(const GeShape& s, int64_t dim_index_in, int64_t new_dim, GeShape& out, const char* op_name); + +/** + * Check if it satisfies 0 <= index < limit + * @param index first input + * @param limit second input + * @return status whether this operation success + */ +template +bool FastBoundsCheck(const Ta index, const Tb limit); + +/** + * Add two dims + * @param dim0 first dim val + * @param dim1 second dim val + * @param out sum dim val + * @return status whether this operation success + */ 
+graphStatus Add(int64_t dim1, int64_t dim2, int64_t& out); + +/** + * Subtract two dims + * @param dim0 first dim val + * @param dim1 second dim val + * @param out Subtract dim val + * @return status whether this operation success + */ +graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t& out, const char* op_name); + +/** + * Get SubShape according to start end index and step size stride + * @param s input Shape + * @param start sub start index + * @param end sub end index + * @param stride sub step size + * @param out sub shape output + * @return status whether this operation success + */ +graphStatus SubShape(const Shape& s, int64_t start, int64_t end, int64_t stride, Shape& out, const char* op_name); + +/** + * Get SubShape according to start end index and step size stride + * @param s input Shape + * @param start sub start index + * @param end sub end index + * @param stride sub step size + * @param out sub shape output + * @return status whether this operation success + */ +graphStatus SubShape(const GeShape& s, size_t start, size_t end, size_t stride, GeShape& out); + +/** + * Get SubShape according to start end index and step size stride + * @param s input Shape + * @param start sub start index + * @param end sub end index + * @param stride sub step size + * @param out sub shape output + * @return status whether this operation success + */ +graphStatus SubShape(const GeShape& s, int64_t start, int64_t end, int64_t stride, GeShape& out, const char* op_name); + +/** + * Concatenate two shape + * @param s1 first shape + * @param s2 second shape + * @param out concatenated shape + * @return status whether this operation success + */ +graphStatus Concatenate(const Shape& s1, const Shape& s2, Shape& out); + +/** + * Concatenate two shape + * @param s1 first shape + * @param s2 second shape + * @param out concatenated shape + * @return status whether this operation success + */ +graphStatus Concatenate(const GeShape& s1, const GeShape& s2, GeShape& out); + 
+/** + * Gen matrix shape according d1 and d2 + * @param dim1 first dim val + * @param dim2 first dim val + * @param out matrix shape + * @return status whether this operation success + */ +graphStatus Matrix(int64_t dim1, int64_t dim2, Shape& out); + +/** + * Gen vector shape according d + * @param dim dim val + * @param out vector shape + * @return status whether this operation success + */ +graphStatus Vector(int64_t dim, Shape& out); + +/** + * Make shape from shape tensor + * @param tensor shape tensor + * @param out shape + * @return status whether this operation success + */ +graphStatus MakeShapeFromShapeTensor(const Tensor& tensor, Shape& out, const char* op_name); + +/** + * Make shape from shape tensor + * @param op Operator + * @param dst_name const string & + * @param out GeShape + * @param op_name const char * + * @return status whether this operation success + */ +graphStatus MakeShapeFromShapeTensor(Operator& op, const string& dst_name, GeShape& out, const char* op_name); + +/** + * Make dim from scalar tensor + * @param tensor shape tensor + * @param out shape + * @return status whether this operation success + */ +graphStatus MakeDimForScalarInput(const Tensor& tensor, int64_t& out, const char* op_name); + +/** + * Check whether Shape's rank is at most rank + * @param tensor input tensor + * @param rank expect val of Shape + * @param out output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRankAtMost(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name); + +/** + * Check whether Shape's rank is at most rank + * @param tensor input tensor + * @param rank expect val of Shape + * @param out output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus WithRankAtMost(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape); + +/** + * make a empty dim shape + * @param out output Shape + * @return status whether Shape's condition Satisfied + */ +graphStatus 
Scalar(Shape& out); + +/** + * set input_name shape to output_name shape + * @param op Operator which need to infershape + * @param input_name input name of Operator + * @param output_name ouput name of Operator + * @return status whether infershape success + */ +graphStatus UnchangedShape(Operator& op, const string input_name, const string output_name); + +/** + * Devide dim + * @param dividend + * @param divisor + * @param evenlyDivisible if to be divisible + * @param out dims + * @return status whether this operation success + */ +graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t& out, + const char* op_name); + +/** + * check shape fully defined or not + * @param shape Shape is checked + * @return whether shape is fully defined + */ +bool ShapeFullDefined(const Shape& shape); + +/** + * check shape fully defined or not + * @param shape Shape is checked + * @return whether shape is fully defined + */ +bool ShapeFullyDefined(const GeShape& shape); + +/** + * check shape known or not + * @param shape Shape is checked + * @return whether rank is known + */ +bool RankKnown(const Shape& shape); + +/** + * check ge_shape known or not + * @param shape GeShape is checked + * @return whether rank is known + */ +bool RankKnown(const GeShape& shape); + +/** + * make a unknown shape with rank + * @return unknown shape + */ +Shape UnknownShapeOfRank(int64_t rank); + +/** + * check dim value known or not + * @param shape which Shape need check dim value + * @param dimIndex the index of dim + * @return whether dim value is known + */ +bool ValueKnown(const Shape& shape, const size_t& dim_index); + +/** + * Validates the 3 component tensors of a sparse tensor + * have the proper shapes. 
+ * @param sparse indices shape + * @param sparse values shape + * @param sparse shape + * @return status whether this operation success + */ +graphStatus ValidateSparseTensor(const TensorDesc& indices, const TensorDesc& values, const TensorDesc& shape, + const char* op_name); + +/** + * DecodeWavShapeFn, infereshape funtion of DecodeWav op + * @param op Operator + * @return status whether Shape's condition Satisfied + */ +graphStatus DecodeWavShapeFn(Operator& op); + +/** + * EncodeWavShapeFn, infereshape funtion of EncodeWav op + * @param op Operator + * @return status whether Shape's condition Satisfied + */ +graphStatus EncodeWavShapeFn(Operator& op); + +/** + * EncodeWavShapeFn, infereshape funtion of EncodeWav op + * @param op Operator + * @return status whether Shape's condition Satisfied + */ +graphStatus EncodeWavShapeFn(Operator& op); + +/** + * Infereshape funtion of SparseSegmentReduction op + * @param op Operator + * @return status whether Shape's condition Satisfied + */ +graphStatus SparseSegmentReductionShapeFn(Operator& op); + +/** + * Infereshape funtion of SparseSegmentReductionGrad op + * @param op Operator + * @return status whether Shape's condition Satisfied + */ +graphStatus SparseSegmentReductionGradShapeFn(Operator& op); + +/** + * Validates variable resource handle + * @param op Operator + * @param shape_and_type ShapeAndType vector + * @return status whether this operation success + */ +graphStatus ValidateVariableResourceHandle(Operator& op, std::vector& shape_and_type); + +/** + * Fill op_desc with input shape + * @param op_desc Operator desc ptr + * @param shape input tensor shape + * @param shape input tensor datatype + */ +void FillOpDesc(GeTensorDescPtr& op_desc, const GeShape& shape, const DataType& data_type = DT_FLOAT); + +/** + * InferShapeErrorReport info + * @param op_name Operator name + * @param op_type Operator type + * @param value Operator value + * @param reason error reason + */ +void InferShapeErrorReport(const 
std::string& op_name, const std::string& op_type, + const std::string& value, const std::string& reason); + +} // namespace ge + +#endif // OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ diff --git a/tests/st/framework/stub_op_proto/util/error_code.h b/tests/st/framework/stub_op_proto/util/error_code.h new file mode 100644 index 00000000..ca4a1361 --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/error_code.h @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file error_code.h + * \brief + */ +#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_ +#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_ + +namespace ge { + +// error code for report purpose. 
/*!
 * \file error_code.h
 * \brief Error codes used by the op-proto error reporting helpers.
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_

namespace ge {

// Error codes for report purposes.
// 30000~34999 are reserved for aicpu engine errors and
// 35000~39999 for infer-shape errors of aicpu ops.
// NOTE: these values are formatted as "E<code>" strings for ErrorManager, so
// the numeric values (and identifier spellings, e.g. FAILED_GET_COMPILIE_PARAMS,
// which callers already reference) must not change.
enum ViewErrorCode {
  INVALID_INFER_SHAPE = 14001,
  INVALID_INPUT_SHAPE = 35000,
  INVALID_ATTR_VALUE = 35001,
  INVALID_ATTR_SIZE = 35002,
  OTHER_ERROR = 35003,
  INVALID_CONV_ATTR_VALUE = 50029,
  INVALID_CONV_SET_ATTR = 50057,
  INVALID_CONV_SHAPE = 50058,
  INVALID_MISS_INPUT = 70001,
  INVALID_INPUT_FORMAT = 70002,
  INVALID_INPUT_DTYPE = 70003,
  INVALID_INPUT_TYPE = 70004,
  INVALID_GET_ATTR = 70005,
  INVALID_SET_ATTR = 70006,
  INVALID_OPS_ATTR_VALUE = 70007,
  FAILED_UPDATE_OP = 70008,
  INVALID_SHAPE = 70009,
  INVALID_SHAPE_SIZE = 70010,
  INVALID_SHAPE_DIM = 70011,
  INVALID_BROADCAST_SHAPE = 70012,
  INVALID_TWO_INPUT_DTYPE = 70013,
  INVALID_AIPP_ERROR = 70014,
  INVALID_ONE_INPUT_SHAPE = 70015,
  INVALID_TWO_INPUT_SHAPE = 70016,
  INVALID_ONE_OUTPUT_SHAPE = 70017,
  FAILED_GET_COMPILIE_PARAMS = 70018,
};

}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_
+ * \file error_util.cpp + * \brief + */ +#include +#include "common/util/error_manager/error_manager.h" +#include "error_util.h" +#include "error_code.h" +#include "op_log.h" + +using namespace std; +using namespace ge; + +namespace ge { + +inline static std::string GetViewErrorCodeStr(ge::ViewErrorCode errCode) { + return "E" + std::to_string(errCode); +} + +void ShapeErrReport(uint32_t index, const std::string& opname, const std::string& wrong_shape, + const std::string& correct_shape) { + map err_map; + err_map["index"] = std::to_string(index); + err_map["opname"] = opname; + err_map["wrong_shape"] = wrong_shape; + err_map["correct_shape"] = correct_shape; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_SHAPE); + (void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void AttrValueErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_value, + const std::string& correct_value) { + map err_map; + err_map["attrname"] = attrName; + err_map["opname"] = opname; + err_map["wrong_value"] = wrong_value; + err_map["correct_value"] = correct_value; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ATTR_VALUE); + (void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void AttrSizeErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_size, + const std::string& correct_size) { + map err_map; + err_map["attrname"] = attrName; + err_map["opname"] = opname; + err_map["wrong_size"] = wrong_size; + err_map["correct_size"] = correct_size; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ATTR_SIZE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void InferShapeOtherErrReport(const std::string& opname, const std::string& err_msg) { + map err_map; + err_map["opname"] = opname; + err_map["err_msg"] = err_msg; + string 
report_error_code = GetViewErrorCodeStr(ViewErrorCode::OTHER_ERROR); + (void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsMissInputErrReport(const std::string& op_name, const std::string& param_name) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_MISS_INPUT); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsInputFormatErrReport(const std::string& op_name, const std::string& param_name, + const std::string& expected_format_list, const std::string& data_format) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + err_map["expected_format_list"] = expected_format_list; + err_map["format"] = data_format; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_FORMAT); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsInputDtypeErrReport(const std::string& op_name, const std::string& param_name, + const std::string& expected_data_type_list, const std::string& data_type) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + err_map["expected_data_type_list"] = expected_data_type_list; + err_map["data_type"] = data_type; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_DTYPE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsInputTypeErrReport(const std::string& op_name, const std::string& param_name, const std::string& param_type, + const std::string& actual_type) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + err_map["param_type"] = param_type; + err_map["actual_type"] = actual_type; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_TYPE); + 
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsGetAttrErrReport(const std::string& op_name, const std::string& param_name) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_GET_ATTR); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsSetAttrErrReport(const std::string& op_name, const std::string& param_name) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SET_ATTR); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& excepted_value, + const std::string& input_value) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + err_map["excepted_value"] = excepted_value; + err_map["input_value"] = input_value; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_OPS_ATTR_VALUE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsOPUpdateErrReport(const std::string& op_name, const std::string& param_name) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::FAILED_UPDATE_OP); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsInputShapeErrReport(const std::string& op_name, const std::string& rule_desc, const std::string& param_name, + const std::string& param_value) { + map err_map; + err_map["op_name"] = op_name; + err_map["rule_desc"] = rule_desc; + err_map["param_name"] = param_name; + err_map["param_value"] = param_value; + std::string report_error_code = 
GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsOneInputShapeErrReport(const std::string& op_name, const std::string& param_name, + const std::string& error_detail) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + err_map["error_detail"] = error_detail; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ONE_INPUT_SHAPE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsTwoInputShapeErrReport(const std::string& op_name, const std::string& param_name1, + const std::string& param_name2, const std::string& error_detail) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name1"] = param_name1; + err_map["param_name2"] = param_name2; + err_map["error_detail"] = error_detail; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_TWO_INPUT_SHAPE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsOneOutputShapeErrReport(const std::string& op_name, const std::string& param_name, + const std::string& error_detail) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + err_map["error_detail"] = error_detail; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ONE_OUTPUT_SHAPE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsGetCompileParamsErrReport(const std::string& op_name, const std::string& param_name) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::FAILED_GET_COMPILIE_PARAMS); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsInputShapeSizeErrReport(const std::string& op_name, const std::string& input_name, const std::string& max_value, + const 
std::string& real_value) { + map err_map; + err_map["op_name"] = op_name; + err_map["input_name"] = input_name; + err_map["max_value"] = max_value; + err_map["real_value"] = real_value; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE_SIZE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsInputShapeDimErrReport(const std::string& op_name, const std::string& param_name, const std::string& max_value, + const std::string& min_value, const std::string& real_value) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + err_map["max_value"] = max_value; + err_map["min_value"] = min_value; + err_map["real_value"] = real_value; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE_DIM); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsInputShapeBroadcastErrReport(const std::string& op_name, const std::string& input1_name, + const std::string& input2_name, const std::string& input1_shape, + const std::string& input2_shape) { + map err_map; + err_map["op_name"] = op_name; + err_map["input1_name"] = input1_name; + err_map["input2_name"] = input2_name; + err_map["input1_shape"] = input1_shape; + err_map["input2_shape"] = input2_shape; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_BROADCAST_SHAPE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void TbeInputDataTypeErrReport(const std::string& op_name, const std::string& param_name, + const std::string& expected_dtype_list, const std::string& dtype) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + err_map["expected_dtype_list"] = expected_dtype_list; + err_map["dtype"] = dtype; + std::string report_error_code = "E50034"; + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsTwoInputDtypeErrReport(const 
std::string& op_name, const std::string& input1_name, + const std::string& input2_name, const std::string& input1_dtype, + const std::string& input2_dtype) { + map err_map; + err_map["op_name"] = op_name; + err_map["input1_name"] = input1_name; + err_map["input2_name"] = input2_name; + err_map["input1_dtype"] = input1_dtype; + err_map["input2_dtype"] = input2_dtype; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_TWO_INPUT_DTYPE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsAippErrReport(const std::string& aipp_output_H, const std::string& aipp_output_W, const std::string& data_H, + const std::string& data_W) { + map err_map; + err_map["aipp_output_H"] = aipp_output_H; + err_map["aipp_output_W"] = aipp_output_W; + err_map["data_H"] = data_H; + err_map["data_W"] = data_W; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_AIPP_ERROR); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsConvAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& expected_value, + const std::string& input_value) { + map err_map; + err_map["op_name"] = op_name; + err_map["param_name"] = param_name; + err_map["expected_value"] = expected_value; + err_map["input_value"] = input_value; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_ATTR_VALUE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsConvSetAttrErrReport(const std::string& op_name, const std::string& param1_name, + const std::string& param2_name) { + map err_map; + err_map["op_name"] = op_name; + err_map["param1_name"] = param1_name; + err_map["param2_name"] = param2_name; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_SET_ATTR); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void OpsConvShapeErrReport(const 
std::string& op_name, const std::string& description) { + map err_map; + err_map["op_name"] = op_name; + err_map["description"] = description; + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_SHAPE); + ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); +} + +void GeInfershapeErrReport(const std::string& op_name, const std::string& op_type, const std::string& value, + const std::string& reason) { + std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INFER_SHAPE); + ErrorManager::GetInstance().ATCReportErrMessage(report_error_code, {"opname", "optype", "value", "reason"}, + {op_name, op_type, value, reason}); +} + +void CommonRuntimeErrLog(const std::string& opname, const std::string& description){ + map err_map; + err_map["op_name"] = opname; + err_map["description"] = description; + OP_LOGE(opname.c_str(), description); + (void)ErrorManager::GetInstance().ReportErrMessage("E50058", err_map); +} + +} // namespace ge diff --git a/tests/st/framework/stub_op_proto/util/error_util.h b/tests/st/framework/stub_op_proto/util/error_util.h new file mode 100644 index 00000000..57dd087a --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/error_util.h @@ -0,0 +1,184 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
+ * \file error_util.h + * \brief + */ +#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_ +#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_ + +#include +#include +#include +#include "operator.h" + +namespace ge { + +/* + * get debug string of vector + * param[in] v vector + * return vector's debug string + */ +template +std::string DebugString(const std::vector& v) { + std::ostringstream oss; + oss << "["; + if (v.size() > 0) { + for (size_t i = 0; i < v.size() - 1; ++i) { + oss << v[i] << ", "; + } + oss << v[v.size() - 1]; + } + oss << "]"; + return oss.str(); +} + +/* + * str cat util function + * param[in] params need concat to string + * return concatted string + */ +template +std::string ConcatString(T arg) { + std::ostringstream oss; + oss << arg; + return oss.str(); +} + +template +std::string ConcatString(T arg, Ts... arg_left) { + std::ostringstream oss; + oss << arg; + oss << ConcatString(arg_left...); + return oss.str(); +} + +/* + * report input shape error of infer shape + * param[in] index the index of input + * param[in] opname op name + * param[in] wrong_shape wrong input shape + * param[in] correct_shape correct input shape + * return void + */ +void ShapeErrReport(uint32_t index, const std::string& opname, const std::string& wrong_shape, + const std::string& correct_shape); + +/* + * report attr value error of infer shape + * param[in] attrname the attr name + * param[in] opname op name + * param[in] wrong_value wrong attr value + * param[in] correct_value correct attr value + * return void + */ +void AttrValueErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_value, + const std::string& correct_value); + +/* + * report attr size error of infer shape + * param[in] attrname the attr name + * param[in] opname op name + * param[in] wrong_size wrong attr size + * param[in] correct_size correct attr size + * return void + */ +void AttrSizeErrReport(const std::string& attrName, const std::string& opname, const 
std::string& wrong_size, + const std::string& correct_size); + +/* + * report common error of infer shape + * param[in] opname op name + * param[in] err_msg error message + * return void + */ +void InferShapeOtherErrReport(const std::string& opname, const std::string& err_msg); + +void OpsMissInputErrReport(const std::string& op_name, const std::string& param_name); + +void OpsInputFormatErrReport(const std::string& op_name, const std::string& param_name, + const std::string& expected_format_list, const std::string& data_format); + +void OpsInputDtypeErrReport(const std::string& op_name, const std::string& param_name, + const std::string& expected_data_type_list, const std::string& data_type); + +void OpsInputTypeErrReport(const std::string& op_name, const std::string& param_name, const std::string& param_type, + const std::string& actual_type); + +void OpsGetAttrErrReport(const std::string& op_name, const std::string& param_name); + +void OpsSetAttrErrReport(const std::string& op_name, const std::string& param_name); + +void OpsAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& excepted_value, + const std::string& input_value); + +void OpsOPUpdateErrReport(const std::string& op_name, const std::string& param_name); + +void OpsInputShapeErrReport(const std::string& op_name, const std::string& rule_desc, const std::string& param_name, + const std::string& param_value); + +void OpsOneInputShapeErrReport(const std::string& op_name, const std::string& param_name, + const std::string& error_detail); + +void OpsTwoInputShapeErrReport(const std::string& op_name, const std::string& param_name1, + const std::string& param_name2, const std::string& error_detail); + +void OpsOneOutputShapeErrReport(const std::string& op_name, const std::string& param_name, + const std::string& error_detail); + +void OpsGetCompileParamsErrReport(const std::string& op_name, const std::string& param_name); + +void OpsInputShapeSizeErrReport(const 
std::string& op_name, const std::string& input_name, const std::string& max_value, + const std::string& real_value); + +void OpsInputShapeDimErrReport(const std::string& op_name, const std::string& param_name, const std::string& max_value, + const std::string& min_value, const std::string& real_value); + +void OpsInputShapeBroadcastErrReport(const std::string& op_name, const std::string& input1_name, + const std::string& input2_name, const std::string& input1_shape, + const std::string& input2_shape); + +void TbeInputDataTypeErrReport(const std::string& op_name, const std::string& param_name, + const std::string& expected_dtype_list, const std::string& dtype); + +void OpsTwoInputDtypeErrReport(const std::string& op_name, const std::string& input1_name, + const std::string& input2_name, const std::string& input1_dtype, + const std::string& input2_dtype); + +void OpsAippErrReport(const std::string& aipp_output_H, const std::string& aipp_output_W, const std::string& data_H, + const std::string& data_W); + +void OpsConvAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& expected_value, + const std::string& input_value); + +void OpsConvSetAttrErrReport(const std::string& op_name, const std::string& param1_name, + const std::string& param2_name); + +void OpsConvShapeErrReport(const std::string& op_name, const std::string& description); + +void GeInfershapeErrReport(const std::string& op_name, const std::string& op_type, const std::string& value, + const std::string& reason); +/* + * log common runtime error + * param[in] opname op name + * param[in] error description + * return void + */ +void CommonRuntimeErrLog(const std::string& opname, const std::string& description); +} // namespace ge + +#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_ diff --git a/tests/st/framework/stub_op_proto/util/op_common_util.h b/tests/st/framework/stub_op_proto/util/op_common_util.h new file mode 100644 index 00000000..c22322f6 --- /dev/null 
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file op_common_util.h
 * \brief common util for op, in this file only original type or class in C++ allowed
 */

#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_

#include <ostream>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Stream a pair as "[first, second]".
template <typename T1, typename T2>
std::ostream& operator<<(std::ostream& os, const std::pair<T1, T2>& values) {
  os << "[" << values.first << ", " << values.second << "]";
  return os;
}

// Stream a vector as "[a, b, ]" (note: every element carries a trailing
// separator, matching the historical log format).
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
  os << "[";
  for (const auto& item : values) {
    os << item << ", ";
  }
  os << "]";
  return os;
}

namespace ops {
// Render a vector as "[a, b, ]" for log/error messages.
template <typename T>
std::string to_string(const std::vector<T>& items) {
  std::ostringstream oss;
  oss << "[";
  for (const auto& item : items) {
    oss << item << ", ";
  }
  oss << "]";
  return oss.str();
}

// Render a set as "[a, b, ]"; elements appear in the set's sorted order.
template <typename T>
std::string to_string(const std::set<T>& items) {
  std::ostringstream oss;
  oss << "[";
  for (const auto& item : items) {
    oss << item << ", ";
  }
  oss << "]";
  return oss.str();
}

}  // namespace ops

#endif  // OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
/dev/null +++ b/tests/st/framework/stub_op_proto/util/op_log.h @@ -0,0 +1,89 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file op_log.h + * \brief + */ +#ifndef GE_OP_LOG_H +#define GE_OP_LOG_H + +#if !defined( __ANDROID__) && !defined(ANDROID) +#include "toolchain/slog.h" +#else +#include +#endif + +#define OPPROTO_SUBMOD_NAME "OP_PROTO" + +#if !defined( __ANDROID__) && !defined(ANDROID) +#define OP_LOGI(opname, ...) D_OP_LOGI(opname, __VA_ARGS__) +#define OP_LOGW(opname, ...) D_OP_LOGW(opname, __VA_ARGS__) +#define OP_LOGE(opname, ...) D_OP_LOGE(opname, __VA_ARGS__) +#define OP_LOGD(opname, ...) D_OP_LOGD(opname, __VA_ARGS__) +#define GE_OP_LOGI(opname, ...) GE_D_OP_LOGI(opname, __VA_ARGS__) +#define GE_OP_LOGW(opname, ...) GE_D_OP_LOGW(opname, __VA_ARGS__) +#define GE_OP_LOGE(opname, ...) GE_D_OP_LOGE(opname, __VA_ARGS__) +#define GE_OP_LOGD(opname, ...) GE_D_OP_LOGD(opname, __VA_ARGS__) +#define FUSION_PASS_LOGI(...) D_FUSION_PASS_LOGI(__VA_ARGS__) +#define FUSION_PASS_LOGW(...) D_FUSION_PASS_LOGW(__VA_ARGS__) +#define FUSION_PASS_LOGE(...) D_FUSION_PASS_LOGE(__VA_ARGS__) +#define FUSION_PASS_LOGD(...) D_FUSION_PASS_LOGD(__VA_ARGS__) +#else +#define OP_LOGI(opname, ...) +#define OP_LOGW(opname, ...) +#define OP_LOGE(opname, ...) +#define OP_LOGD(opname, ...) +#define FUSION_PASS_LOGI(...) +#define FUSION_PASS_LOGW(...) +#define FUSION_PASS_LOGE(...) 
+#define FUSION_PASS_LOGD(...) +#endif + +#if !defined( __ANDROID__) && !defined(ANDROID) +#define D_OP_LOGI(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) +#define D_OP_LOGW(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) +#define D_OP_LOGE(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) +#define D_OP_LOGD(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) +#define GE_D_OP_LOGI(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) +#define GE_D_OP_LOGW(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) +#define GE_D_OP_LOGE(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) +#define GE_D_OP_LOGD(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) +#define D_FUSION_PASS_LOGI(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) +#define D_FUSION_PASS_LOGW(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) +#define D_FUSION_PASS_LOGE(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) +#define D_FUSION_PASS_LOGD(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) +#else +#define D_OP_LOGI(opname, fmt, ...) +#define D_OP_LOGW(opname, fmt, ...) 
+#define D_OP_LOGE(opname, fmt, ...) +#define D_OP_LOGD(opname, fmt, ...) +#define D_FUSION_PASS_LOGI(fmt, ...) +#define D_FUSION_PASS_LOGW(fmt, ...) +#define D_FUSION_PASS_LOGE(fmt, ...) +#define D_FUSION_PASS_LOGD(fmt, ...) +#endif + +#define OP_CHECK(condition, log_func, do_expr) \ + static_assert(std::is_same::type>::value, "condition should be bool"); \ + do { \ + if (condition) { \ + log_func; \ + do_expr; \ + } \ + } while (0) + +#endif //GE_OP_LOG_H diff --git a/tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc b/tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc new file mode 100644 index 00000000..fd235843 --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc @@ -0,0 +1,258 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
+ * \file transfer_shape_according_to_format.cpp + * \brief set shape according to original format and current format + */ +#include "transfer_shape_according_to_format.h" +#include "framework/omg/omg_inner_types.h" + +namespace ge { +ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat(void) { + getNewShapeFuncMap = { + {ge::FORMAT_NCHW, std::make_shared(GetNCHWShapeByAxisValue)}, + {ge::FORMAT_NHWC, std::make_shared(GetNHWCShapeByAxisValue)}, + {ge::FORMAT_NC1HWC0, std::make_shared(GetNC1HWC0ShapeByAxisValue)}, + {ge::FORMAT_FRACTAL_Z, std::make_shared(GetFzShapeByAxisValue)}, + {ge::FORMAT_HWCN, std::make_shared(GetHWCNShapeByAxisValue)}, + {ge::FORMAT_C1HWNCoC0, std::make_shared(GetC1HWNCoC0ShapeByAxisValue)}, + {ge::FORMAT_FRACTAL_NZ, std::make_shared(GetNzShapeByAxisValue)}}; + + mapOfDtypeAndC0 = { + {ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32}, + {ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16}, + {ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16}, + {ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}}; +} + +bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); + /* axisValue is initialized as a size 6 vector. 
*/ + std::vector newDimVec; + newDimVec.push_back(axisValue[AXIS_N]); + newDimVec.push_back(axisValue[AXIS_C]); + newDimVec.push_back(axisValue[AXIS_H]); + newDimVec.push_back(axisValue[AXIS_W]); + newShape = ge::GeShape(newDimVec); + return true; +} + +bool ShapeTransferAccordingToFormat::GetNHWCShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); + /* axisValue is initialized as a size 6 vector. */ + std::vector newDimVec; + newDimVec.push_back(axisValue[AXIS_N]); + newDimVec.push_back(axisValue[AXIS_H]); + newDimVec.push_back(axisValue[AXIS_W]); + newDimVec.push_back(axisValue[AXIS_C]); + newShape = ge::GeShape(newDimVec); + return true; +} + +bool ShapeTransferAccordingToFormat::GetNC1HWC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); + /* axisValue is initialized as a size 6 vector. 
*/ + std::vector newDimVec; + if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) { + CHECK(axisValue.size() <= AXIS_C0, LOG_INFO("AxisValue is not correct!"), return true); + newDimVec.push_back(axisValue[AXIS_N]); + newDimVec.push_back(axisValue[AXIS_C1]); + newDimVec.push_back(axisValue[AXIS_H]); + newDimVec.push_back(axisValue[AXIS_W]); + newDimVec.push_back(axisValue[AXIS_C0]); + newShape = ge::GeShape(newDimVec); + } else { + CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); + newDimVec.push_back(axisValue[AXIS_N]); + newDimVec.push_back(axisValue[AXIS_C]); + newDimVec.push_back(axisValue[AXIS_H]); + newDimVec.push_back(axisValue[AXIS_W]); + newShape = ge::GeShape(newDimVec); + } + return true; +} + +bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); + /* axisValue is initialized as a size 6 vector. 
*/ + std::vector newDimVec; + + if (ndValue.size() == SIZE_OF_CN) { + CHECK(axisValue.size() <= AXIS_C0, LOG_INFO("AxisValue is not correct!"), return true); + auto sizeOfOriginalVec = ndValue.size(); + std::vector newDimVec = ndValue; + /* sizeOfOriginalVec - 1 mean the last value of original vec + * sizeOfOriginalVec - 2 mean the second last value of original vec */ + newDimVec[sizeOfOriginalVec - MINUS_VALUE_ONE] = + DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16); + newDimVec[sizeOfOriginalVec - MINUS_VALUE_TWO] = + DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]); + newDimVec.push_back(SHAPE_NUMBER_16); + newDimVec.push_back(axisValue[AXIS_C0]); + newShape = ge::GeShape(newDimVec); + } else { + if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) { + CHECK(axisValue.size() <= AXIS_C1, LOG_INFO("AxisValue is not correct!"), return true); + int64_t hwc1 = axisValue[AXIS_C1] * axisValue[AXIS_H] * axisValue[AXIS_W]; + newDimVec.push_back(hwc1); + newDimVec.push_back(DivisionCeiling(axisValue[AXIS_N], NI)); + newDimVec.push_back(NI); + newDimVec.push_back(axisValue[AXIS_C0]); + newShape = ge::GeShape(newDimVec); + } else { + CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); + newDimVec.push_back(axisValue[AXIS_N]); + newDimVec.push_back(axisValue[AXIS_C]); + newDimVec.push_back(axisValue[AXIS_H]); + newDimVec.push_back(axisValue[AXIS_W]); + newShape = ge::GeShape(newDimVec); + } + } + + return true; +} + +bool ShapeTransferAccordingToFormat::GetHWCNShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); + /* axisValue is initialized as a size 6 vector. 
*/ + std::vector newDimVec; + newDimVec.push_back(axisValue[AXIS_H]); + newDimVec.push_back(axisValue[AXIS_W]); + newDimVec.push_back(axisValue[AXIS_C]); + newDimVec.push_back(axisValue[AXIS_N]); + newShape = ge::GeShape(newDimVec); + return true; +} + +bool ShapeTransferAccordingToFormat::GetC1HWNCoC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(axisValue.size() <= AXIS_Co, LOG_INFO("AxisValue is not correct!"), return true); + /* axisValue is initialized as a size 6 vector. */ + std::vector newDimVec; + newDimVec.push_back(axisValue[AXIS_C1]); + newDimVec.push_back(axisValue[AXIS_H]); + newDimVec.push_back(axisValue[AXIS_W]); + newDimVec.push_back(axisValue[AXIS_N]); + newDimVec.push_back(axisValue[AXIS_Co]); + newDimVec.push_back(axisValue[AXIS_C0]); + newShape = ge::GeShape(newDimVec); + return true; +} + +bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, + const vector& axisValue, + const vector& ndValue) { + CHECK(ndValue.empty(), LOG_INFO("ndValue is empty!"), return true); + CHECK(axisValue.empty() || axisValue.size() <= AXIS_C0, + LOG_INFO("AxisValue is empty or its size %zu <= AXIS_C0[%u]", axisValue.size(), AXIS_C0), return true); + uint32_t sizeOfOriginalVec = ndValue.size(); + if (sizeOfOriginalVec < MINIMUM_NZ_SHAPE_DIM_NUM) { + LOG_INFO("ndValue's dim num is less than 2!"); + return true; + } + /* axisValue is initialized as a size 6 vector. 
*/ + std::vector newDimVec = ndValue; + + /* sizeOfOriginalVec - 1 mean the last value of original vec + * sizeOfOriginalVec - 2 mean the second last value of original vec */ + newDimVec[sizeOfOriginalVec - MINUS_VALUE_ONE] = + DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16); + + newDimVec[sizeOfOriginalVec - MINUS_VALUE_TWO] = + DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]); + newDimVec.push_back(SHAPE_NUMBER_16); + newDimVec.push_back(axisValue[AXIS_C0]); + newShape = ge::GeShape(newDimVec); + return true; +} + +bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& shapeAndFormatInfo, int64_t* c) { + /* The default new shape is old shape */ + shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape; + if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) { + LOG_ERROR("Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat); + return false; + } + + if (shapeAndFormatInfo.currentDataType >= ge::DT_UNDEFINED) { + LOG_ERROR("currentDataType %u is invalid!", shapeAndFormatInfo.currentDataType); + return false; + } + AxisUtil* axisutil_object = new AxisUtil(); + if (!axisutil_object->HasAxisValueFunc(shapeAndFormatInfo.oldFormat)) { + delete axisutil_object; + return true; + } + + auto iterGetNewShapeFunc = getNewShapeFuncMap.find(shapeAndFormatInfo.newFormat); + if (iterGetNewShapeFunc == getNewShapeFuncMap.end()) { + LOG_INFO("Can not get new shape of new format %u!", shapeAndFormatInfo.newFormat); + delete axisutil_object; + return true; + } + LOG_INFO("Original format %u, new format %u", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat); + GetNewShapeByAxisValueAndFormatPtr getNewShapeFunc = iterGetNewShapeFunc->second; + CHECK_NOTNULL(getNewShapeFunc); + std::vector axisValue; + for (uint32_t i = 0; i < AXIS_BOTTOM; i++) { + axisValue.push_back(1); 
+ } + std::vector ndValue; + uint32_t c0; + if (mapOfDtypeAndC0.empty()) { + c0 = SHAPE_NUMBER_16; + } else { + auto iterGetC0 = mapOfDtypeAndC0.find(shapeAndFormatInfo.currentDataType); + if (iterGetC0 == mapOfDtypeAndC0.end()) { + LOG_ERROR("Dtype is not support."); + delete axisutil_object; + return true; + } + c0 = iterGetC0->second; + } + + // The value of C0 should be 4 while format is 5HD-4 or FRAZ-4 + if (shapeAndFormatInfo.newFormat == ge::FORMAT_NC1HWC0_C04) { + c0 = SHAPE_DIM_VALUE_C04; + } + + bool status = axisutil_object->GetAxisValueByOriginFormat( + shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape.GetDims(), c0, axisValue, ndValue); + if (status != true && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) { + delete axisutil_object; + return true; + } + delete axisutil_object; + + (*getNewShapeFunc)(shapeAndFormatInfo.newShape, shapeAndFormatInfo.opImplType, axisValue, ndValue); + if (c != nullptr) { + *c = axisValue[AXIS_C]; + } + return true; +} +}; // namespace ge diff --git a/tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h b/tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h new file mode 100644 index 00000000..a97bcea8 --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
+ * \file transfer_shape_according_to_format.h + * \brief set shape according to original format and current format + */ +#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ +#define OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ + +#include "axis_util.h" + +#include +#include +#include + +#include "framework/omg/omg_inner_types.h" +#include "operator.h" +#include "graph/operator_reg.h" +#include "graph/tensor.h" +#include "graph/utils/op_desc_utils.h" + +#include "op_log.h" + +#define LOG_ERROR(format, args...) printf(format, ##args) +#define LOG_INFO(format, args...) printf(format, ##args) + +namespace ge { + +enum OpImplType { + EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op + EN_IMPL_CUSTOM_TIK, // custom tik op + EN_IMPL_CUSTOM_TBE, // custom tbe op + EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op + EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op + EN_IMPL_HW_TIK, // Huawei built-in tik op + EN_IMPL_HW_TBE, // Huawei built-in tbe op + EN_IMPL_RL, // RL op + EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op + EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op + EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op + EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op + EN_RESERVED // reserved value +}; + +const uint32_t SHAPE_NUMBER_16 = 16; +const uint32_t SHAPE_NUMBER_32 = 32; +const uint32_t SHAPE_DIM_VALUE_C04 = 4; +const uint32_t NI = 16; +const uint32_t MINUS_VALUE_ONE = 1; +const uint32_t MINUS_VALUE_TWO = 2; +const uint32_t SIZE_OF_CN = 2; +const uint32_t MINIMUM_NZ_SHAPE_DIM_NUM = 2; + +/* The first parameter is axis value, second is new shape and third is + * op implementation type. 
*/ +using GetNewShapeByAxisValueAndFormat = + std::function&, vector&)>; + +using GetNewShapeByAxisValueAndFormatPtr = std::shared_ptr; + +struct ShapeAndFormatInfo { + const ge::GeShape& oldShape; + ge::GeShape& newShape; + const ge::Format& oldFormat; + const ge::Format& newFormat; + const ge::DataType& currentDataType; + const int64_t& opImplType; +}; + +using ShapeAndFormat = struct ShapeAndFormatInfo; + +class ShapeTransferAccordingToFormat { + public: + ShapeTransferAccordingToFormat(); + + ~ShapeTransferAccordingToFormat(){}; + + ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat&) = delete; + + ShapeTransferAccordingToFormat& operator=(const ShapeTransferAccordingToFormat&) = delete; + + bool GetShapeAccordingToFormat(ShapeAndFormat& inputAndOutputInfo, int64_t* c = nullptr); + + /* ----------Below is the function of getting new shape---------------------- */ + static bool GetNCHWShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector& axisValue, + const vector& ndValue); + + static bool GetNHWCShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector& axisValue, + const vector& ndValue); + + static bool GetNC1HWC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, + const vector& axisValue, const vector& ndValue); + + static bool GetFzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector& axisValue, + const vector& ndValue); + + static bool GetHWCNShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector& axisValue, + const vector& ndValue); + + static bool GetC1HWNCoC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, + const vector& axisValue, const vector& ndValue); + + static bool GetNzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector& axisValue, + const vector& ndValue); + + private: + /* map of GetAxisValueInfoByFormat, get axis value by different original + * formats. 
*/ + std::map getNewShapeFuncMap; + std::map mapOfDtypeAndC0; +}; + +} // namespace ge + +#endif // OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ diff --git a/tests/st/framework/stub_op_proto/util/util.cc b/tests/st/framework/stub_op_proto/util/util.cc new file mode 100644 index 00000000..966dd01c --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/util.cc @@ -0,0 +1,1097 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
+ * \file util.cpp + * \brief + */ +#include "util.h" +#include +#include +#include +#include +#include +#include "./error_util.h" +#include "op_common_util.h" +#include "graph/utils/type_utils.h" +#include "axis_util.h" + +namespace ge { + +bool GetInputDataType(const ge::DataType& data_type, const std::vector& supportList) { + std::vector::const_iterator supportIter = find(supportList.begin(), supportList.end(), data_type); + if (supportIter == supportList.end()) { + return false; + } + return true; +} + +bool CheckInputDtypeAndShape(const Operator& op, const std::map>& inputTensorMap) { + auto iter = inputTensorMap.begin(); + auto first_name = iter->first; + auto first_shape_dims = op.GetInputDesc(iter->first).GetShape().GetDims(); + auto first_input_dtype = op.GetInputDesc(iter->first).GetDataType(); + for (; iter != inputTensorMap.end(); ++iter) { + const TensorDesc input_desc = op.GetInputDesc(iter->first); + // check input dtype + auto input_type = input_desc.GetDataType(); + if (input_type != first_input_dtype) { + OP_LOGE(op.GetName().c_str(), "the op type of param %s must equal with param %s", iter->first.c_str(), + first_name.c_str()); + return false; + } + auto dims = input_desc.GetShape().GetDims(); + if (dims != first_shape_dims) { + OP_LOGE(op.GetName().c_str(), "the op shape of param %s must equal with param %s", iter->first.c_str(), + first_name.c_str()); + return false; + } + } + return true; +} + +bool CheckInputDataType(const Operator& op, const std::string& input_name, + const std::vector& support_list) { + bool valid = false; + DataType input_type = op.GetInputDesc(input_name).GetDataType(); + do { + const auto& found_list = find(support_list.begin(), support_list.end(), input_type); + + if (found_list == support_list.end()) { + break; + } + + const auto& found_map = DTYPE_STR_MAP.find(input_type); + if (found_map == DTYPE_STR_MAP.end()) { + break; + } + + valid = true; + } while (0); + + if (!valid) { + OpsInputDtypeErrReport(op.GetName(), 
input_name, DebugString(support_list), ConcatString(input_type)); + OP_LOGE(op.GetName().c_str(), "The op do not support the dtype %s", + ge::TypeUtils::DataTypeToSerialString(input_type).c_str()); + return false; + } + + return true; +} + +bool CheckTwoInputDtypeSame(const Operator& op, const string& input_name1, const string& input_name2) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + CHECK(op_desc == nullptr || op_desc->MutableInputDesc(input_name1) == nullptr || + op_desc->MutableInputDesc(input_name2) == nullptr, + OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false); + + DataType input_type_x1 = op_desc->MutableInputDesc(input_name1)->GetDataType(); + DataType input_type_x2 = op_desc->MutableInputDesc(input_name2)->GetDataType(); + if (input_type_x1 != input_type_x2) { + OpsTwoInputDtypeErrReport(op.GetName(), input_name1, input_name2, ConcatString(input_type_x1), + ConcatString(input_type_x2)); + OP_LOGE(op.GetName().c_str(), "The %s op dtype is not same, type1:%s, type2:%s", op.GetName().c_str(), + ge::TypeUtils::DataTypeToSerialString(input_type_x1).c_str(), + ge::TypeUtils::DataTypeToSerialString(input_type_x2).c_str()); + return false; + } + + return true; +} + +bool CheckInputDtypeSame(const Operator& op, std::vector& input_tensors) { + auto first_name = input_tensors.begin(); + auto first_input_dtype = op.GetInputDesc(*first_name).GetDataType(); + for (const string& input_name : input_tensors) { + const TensorDesc input_desc = op.GetInputDesc(input_name); + auto input_dtype = input_desc.GetDataType(); + if (input_dtype != first_input_dtype) { + OP_LOGE(op.GetName().c_str(), "the op type of param %s must equal with param %s", input_name.c_str(), + (*first_name).c_str()); + return false; + } + } + return true; +} + +bool CheckInputsShapeDtypeSame(const Operator& op, const std::vector& input_names) { + auto first_input_name = input_names.begin(); + auto first_input_des = op.GetInputDesc(*first_input_name); + auto input_name = 
first_input_name; + for (++input_name; input_name != input_names.end(); ++input_name) { + auto input_des = op.GetInputDesc(*first_input_name); + + if (input_des.GetDataType() != first_input_des.GetDataType() || + input_des.GetShape().GetDims() != first_input_des.GetShape().GetDims()) { + OpsAttrValueErrReport( + op.GetName(), ConcatString(input_name->c_str(), "'s dtype and shape"), + ConcatString("same as", first_input_name->c_str(), "[", first_input_des.GetDataType(), "]", "[", + DebugString(first_input_des.GetShape().GetDims()), "]"), + ConcatString("[", input_des.GetDataType(), "]", "[", DebugString(input_des.GetShape().GetDims()), "]")); + OP_LOGE(op.GetName().c_str(), "the dtype and shape of param %s must be same as param %s", + first_input_name->c_str(), input_name->c_str()); + return false; + } + } + + return true; +} + +bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2, + const string& output_name, bool& is_dynamic) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + CHECK(op_desc == nullptr || op_desc->MutableOutputDesc(output_name) == nullptr|| + op_desc->MutableInputDesc(input_name1) == nullptr || op_desc->MutableInputDesc(input_name2) == nullptr, + OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false); + + DataType input_dtype = op_desc->MutableInputDesc(input_name1)->GetDataType(); + + // output Desc + GeTensorDescPtr tensordesc_output = op_desc->MutableOutputDesc(output_name); + tensordesc_output->SetDataType(input_dtype); + + ge::GeShape shapeX = op_desc->MutableInputDesc(input_name1)->GetShape(); + ge::GeShape shapeY = op_desc->MutableInputDesc(input_name2)->GetShape(); + OP_LOGI(op.GetName().c_str(), "shape %s: %s, shape %s: %s.", input_name1.c_str(), to_string(shapeX).c_str(), + input_name2.c_str(), to_string(shapeY).c_str()); + std::vector dimsX = shapeX.GetDims(); + std::vector dimsY = shapeY.GetDims(); + // swap based on shape size + if (dimsX.size() < dimsY.size()) 
{ + std::vector dimsTmp = dimsX; + dimsX = dimsY; + dimsY = dimsTmp; + } + + std::vector dimVec; + // unknown rank + if (IsUnknownRankShape(dimsX) || IsUnknownRankShape(dimsY)) { + tensordesc_output->SetShape(ge::GeShape(UNKNOWN_RANK)); + OP_LOGI(op.GetName().c_str(), "output shape is: %s, output dtype is:%d.", to_string(ge::Shape(UNKNOWN_RANK)).c_str(), + input_dtype); + is_dynamic = false; + return true; + } + + // pad 1 for small shape + if (dimsX.size() != dimsY.size()) { + int dec = dimsX.size() - dimsY.size(); + for (int i = 0; i < dec; i++) { + dimsY.insert(dimsY.begin(), (int64_t)1); + } + } + + // when not dynamic case, do infer shape only + if (!IsUnknown(dimsY) && !IsUnknown(dimsX)) { + for (size_t i = 0; i < dimsX.size(); i++) { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dims = (dimsY[i] == 0 || dimsX[i] == 0) ? 0 : dims; + dimVec.push_back(dims); + } + tensordesc_output->SetShape(ge::GeShape(dimVec)); + is_dynamic = false; + return true; + } + + // dynamic case + for (size_t i = 0; i < dimsX.size(); i++) { + CHECK((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1) && (dimsX[i] != -1) && (dimsY[i] != -1), + OpsInputShapeBroadcastErrReport(op.GetName(), input_name1, input_name2, ConcatString(dimsX[i]), + ConcatString(dimsY[i])); + OP_LOGE(op.GetName().c_str(), "The %s's dimensions does not match the broadcast rule(%lu %lu).", + op.GetName().c_str(), dimsX[i], dimsY[i]), + return false); + + if ((dimsX[i] == -1) && (dimsY[i] != -1)) { + if (dimsY[i] > 1) { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + } else if (dimsY[i] == 1) { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + dimVec[i] = -1; + } else if ((dimsY[i] == 0) || (dimsX[i] == 0)) { + dimVec.push_back(0); + } + } else if ((dimsX[i] != -1) && (dimsY[i] == -1)) { + if (dimsX[i] > 1) { + int64_t dims = dimsX[i] > dimsY[i] ? 
dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + } else if (dimsX[i] == 0) { + dimVec.push_back(0); + } else if (dimsX[i] == 1) { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + dimVec[i] = -1; + } + } else { + if ((dimsX[i] == -1) && (dimsY[i] == -1)) { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + dimVec[i] = -1; + } else { + if (dimsY[i] == 0 || dimsX[i] == 0) { + dimVec.push_back(0); + } else { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + } + } + } + } + ge::GeShape outputShape = ge::GeShape(dimVec); + tensordesc_output->SetShape(outputShape); + + OP_LOGI(op.GetName().c_str(), "output shape is: %s, output dtype is:%s.", to_string(outputShape).c_str(), + ge::TypeUtils::DataTypeToSerialString(input_dtype).c_str()); + is_dynamic = IsUnknown(dimVec); + + if (is_dynamic) { + if (!InferShapeRangeTwoInOneOutBroadcase(op, input_name1, input_name2, output_name)) { + return false; + } + } + return true; +} + +bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2, + const string& output_name) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + CHECK(op_desc == nullptr || op_desc->MutableInputDesc(input_name1) == nullptr || + op_desc->MutableOutputDesc(output_name) == nullptr || op_desc->MutableInputDesc(input_name2) == nullptr, + OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false); + + DataType input_dtype = op_desc->MutableInputDesc(input_name1)->GetDataType(); + + GeTensorDescPtr tensordesc_output = op_desc->MutableOutputDesc(output_name); + + ge::GeShape shapeX = op_desc->MutableInputDesc(input_name1)->GetShape(); + ge::GeShape shapeY = op_desc->MutableInputDesc(input_name2)->GetShape(); + OP_LOGI(op.GetName().c_str(), "shape %s: %s, shape %s: %s.", input_name1.c_str(), to_string(shapeX).c_str(), + input_name2.c_str(), to_string(shapeY).c_str()); + std::vector dimsX = 
shapeX.GetDims(); + std::vector dimsY = shapeY.GetDims(); + // swap based on shape size + if (dimsX.size() < dimsY.size()) { + std::vector dimsTmp = dimsX; + dimsX = dimsY; + dimsY = dimsTmp; + } + + std::vector dimVec; + + // unknown rank + if (IsUnknownRankShape(dimsX) || IsUnknownRankShape(dimsY)) { + tensordesc_output->SetShape(ge::GeShape(UNKNOWN_RANK)); + tensordesc_output->SetDataType(input_dtype); + OP_LOGI(op.GetName().c_str(), "output shape is: %s, output dtype is:%d.", to_string(ge::Shape(UNKNOWN_RANK)).c_str(), + input_dtype); + return true; + } + + // pad 1 for small shape + if (dimsX.size() != dimsY.size()) { + int dec = dimsX.size() - dimsY.size(); + for (int i = 0; i < dec; i++) { + dimsY.insert(dimsY.begin(), (int64_t)1); + } + } + + for (size_t i = 0; i < dimsX.size(); i++) { + CHECK((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1) && (dimsX[i] != -1) && (dimsY[i] != -1), + OpsInputShapeBroadcastErrReport(op.GetName(), input_name1, input_name2, ConcatString(dimsX[i]), + ConcatString(dimsY[i])); + OP_LOGE(op.GetName().c_str(), "The %s's dimensions does not match the broadcast rule(%lu %lu).", + op.GetName().c_str(), dimsX[i], dimsY[i]), + return false); + + if ((dimsX[i] == -1) && (dimsY[i] != -1)) { + if (dimsY[i] > 1) { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + } else if (dimsY[i] == 1) { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + dimVec[i] = -1; + } else if ((dimsY[i] == 0) || (dimsX[i] == 0)) { + dimVec.push_back(0); + } + } else if ((dimsX[i] != -1) && (dimsY[i] == -1)) { + if (dimsX[i] > 1) { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + } else if (dimsX[i] == 0) { + dimVec.push_back(0); + } else if (dimsX[i] == 1) { + int64_t dims = dimsX[i] > dimsY[i] ? 
dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + dimVec[i] = -1; + } + } else { + if ((dimsX[i] == -1) && (dimsY[i] == -1)) { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + dimVec[i] = -1; + } else { + if (dimsY[i] == 0 || dimsX[i] == 0) { + dimVec.push_back(0); + } else { + int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i]; + dimVec.push_back(dims); + } + } + } + } + ge::GeShape outputShape = ge::GeShape(dimVec); + + tensordesc_output->SetShape(outputShape); + tensordesc_output->SetDataType(input_dtype); + OP_LOGI(op.GetName().c_str(), "output shape is: %s, output dtype is:%s.", to_string(outputShape).c_str(), + ge::TypeUtils::DataTypeToSerialString(input_dtype).c_str()); + + return true; +} + +bool InferShapeRangeTwoInOneOutBroadcase(Operator& op, const string& input_name1, const string& input_name2, + const string& output_name) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + CHECK(op_desc == nullptr || op_desc->MutableInputDesc(input_name1) == nullptr || + op_desc->MutableOutputDesc(output_name) == nullptr || op_desc->MutableInputDesc(input_name2) == nullptr, + OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false); + + ge::GeShape shape_x = op_desc->MutableInputDesc(input_name1)->GetShape(); + ge::GeShape shape_y = op_desc->MutableInputDesc(input_name2)->GetShape(); + + std::vector dims_x = shape_x.GetDims(); + std::vector dims_y = shape_y.GetDims(); + + std::vector> shape_range_x; + op_desc->MutableInputDesc(input_name1)->GetShapeRange(shape_range_x); + std::vector> shape_range_y; + op_desc->MutableInputDesc(input_name2)->GetShapeRange(shape_range_y); + + MakeUpShapeRange(dims_x, shape_range_x); + MakeUpShapeRange(dims_y, shape_range_y); + + ge::GeShape shape_out = op_desc->MutableOutputDesc(output_name)->GetShape(); + std::vector dims_out = shape_out.GetDims(); + size_t size_shape_out = dims_out.size(); + + std::vector> out_range; + + if (!IsUnknownRankShape(dims_out)) { + // shape switch 
by shape dim size + if (dims_x.size() < dims_y.size()) { + std::vector dims_tmp = dims_x; + dims_x = dims_y; + dims_y = dims_tmp; + + std::vector> range_temp = shape_range_x; + shape_range_x = shape_range_y; + shape_range_y = range_temp; + } + + while (dims_x.size() > shape_range_y.size()) { + shape_range_y.insert(shape_range_y.begin(), std::pair(1, 1)); + } + + for (size_t i = 0; i < size_shape_out; i++) { + if (dims_out[i] != -1) { + out_range.push_back(std::pair(dims_out[i], dims_out[i])); + continue; + } + if (i < shape_range_x.size() && i < shape_range_y.size()) { + if (shape_range_x[i].second == -1 && shape_range_y[i].second == 1) { + out_range.push_back(std::pair(1, -1)); + } else if (shape_range_x[i].second == 1 && shape_range_y[i].second == -1) { + out_range.push_back(std::pair(1, -1)); + } else if (shape_range_x[i].first == 1 || shape_range_y[i].first == 1) { + // one shape size maybe 1, so will support boardcast + // first_range == max first + int64_t first_range = std::max(shape_range_x[i].first, shape_range_y[i].first); + int64_t second_range = shape_range_x[i].first == 1 ? shape_range_y[i].second : shape_range_x[i].second; + if (shape_range_x[i].first == 1 && shape_range_y[i].first == 1) { + second_range = std::max(shape_range_x[i].second, shape_range_y[i].second); + second_range = (shape_range_x[i].second == -1 || shape_range_y[i].second == -1) ? -1 : second_range; + } + out_range.push_back(std::pair(first_range, second_range)); + } else { + // no 1 in range.first, mean no boardcast for range + // get intersect range + int64_t first_range = std::max(shape_range_x[i].first, shape_range_y[i].first); + int64_t second_range = std::min(shape_range_x[i].second, shape_range_y[i].second); + second_range = (shape_range_x[i].second == -1 || shape_range_y[i].second == -1) + ? 
std::max(shape_range_x[i].second, shape_range_y[i].second) + : second_range; + out_range.push_back(std::pair(first_range, second_range)); + } + } + } + } + + GeTensorDescPtr tensor_out = op_desc->MutableOutputDesc(output_name); + tensor_out->SetShapeRange(out_range); + + return true; +} + +bool GetInputDataType(const ge::DataType& dataType, const std::vector& supportList, std::string& dType) { + std::vector::const_iterator supportIter = find(supportList.begin(), supportList.end(), dataType); + if (supportIter == supportList.end()) { + return false; + } + + std::map::const_iterator totalIter = DTYPE_STR_MAP.find(dataType); + if (totalIter == DTYPE_STR_MAP.end()) { + return false; + } + + dType = totalIter->second; + return true; +} + +bool CheckInputDataType(const Operator& op, std::string* data_type, const std::string& input_name, + const std::vector& supportList) { + DataType input_type = op.GetInputDesc(input_name).GetDataType(); + if (false == GetInputDataType(input_type, supportList, *data_type)) { + LOG_ERROR("[ERROR]op [%s] [%s] do not supported dtype [%s]!\n", op.GetName().c_str(), input_name.c_str(), + data_type->c_str()); + return false; + } + return true; +} + +bool GetConstValue(const ge::Operator& op, const std::string& key_name, float& attr_value) { + if (ge::GRAPH_SUCCESS != op.GetAttr(key_name, attr_value)) { + LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", op.GetName().c_str(), key_name.c_str()); + return false; + } + return true; +} + +bool GetConstValue(const ge::Operator& op, const std::string& key_name, int64_t& attr_value) { + if (ge::GRAPH_SUCCESS != op.GetAttr(key_name, attr_value)) { + LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", op.GetName().c_str(), key_name.c_str()); + return false; + } + return true; +} + +bool GetConstValue(const ge::Operator& op, const std::string& key_name, bool& attr_value) { + if (ge::GRAPH_SUCCESS != op.GetAttr(key_name, attr_value)) { + LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", 
op.GetName().c_str(), key_name.c_str()); + return false; + } + return true; +} + +bool GetConstValue(const ge::Operator& op, const std::string& key_name, std::vector& attr_value) { + if (ge::GRAPH_SUCCESS != op.GetAttr(key_name, attr_value)) { + LOG_ERROR("[ERROR]op [%s] GetOpAttr [%s] failed!\n", op.GetName().c_str(), key_name.c_str()); + return false; + } + return true; +} + +template +static std::vector GetConstIntData(const uint8_t* const_data, size_t data_size) { + size_t size = data_size / sizeof(T); + std::vector result(size); + T* data = (T*)const_data; + for (size_t i = 0; i < size; i++) { + result[i] = *(data + i); + } + + return result; +} + +bool GetConstIntData(const Tensor& data, DataType data_type, std::vector& const_values) { + using namespace std::placeholders; + const std::map(const uint8_t*, size_t)>> type_call_map = { + {DT_INT8, std::bind(GetConstIntData, _1, _2)}, + {DT_INT16, std::bind(GetConstIntData, _1, _2)}, + {DT_INT32, std::bind(GetConstIntData, _1, _2)}, + {DT_INT64, std::bind(GetConstIntData, _1, _2)}, + }; + + auto found = type_call_map.find(data_type); + if (found == type_call_map.end()) { + USER_GE_LOGE("[ERROR]GetConstIntData is not support data_type[%s]!", + ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); + return false; + } + + const_values = found->second(data.GetData(), data.GetSize()); + + return true; +} + +bool GetConstValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, + std::vector& const_data) { + size_t size = 0; + CHECK(dtype != ge::DT_INT32 && dtype != ge::DT_INT64, + OP_LOGE(op.GetName().c_str(), "not support this type"), return false); + if (dtype == ge::DT_INT32) { + int32_t* const_data_ptr = (int32_t*)const_tensor.GetData(); + size = const_tensor.GetSize() / sizeof(int32_t); + for (size_t i = 0; i < size; ++i) { + const_data.push_back((int32_t)((*(const_data_ptr + i)))); + OP_LOGD(op.GetName().c_str(), "const data int32 fusion pass ====== %d", (int32_t)(*(const_data_ptr + i))); 
+ } + } else if (dtype == ge::DT_INT64) { + int64_t* const_data_ptr = (int64_t*)const_tensor.GetData(); + size = const_tensor.GetSize() / sizeof(int64_t); + for (size_t i = 0; i < size; ++i) { + const_data.push_back(((int64_t)(*(const_data_ptr + i)))); + OP_LOGD(op.GetName().c_str(), "const data int64 fusion pass ====== %d", (int64_t)(*(const_data_ptr + i))); + } + } + return true; +} + +bool GetConstValue(const Operator& op, const GeTensorPtr& const_tensor, + const DataType& dtype, std::vector& const_data) { + size_t size = const_tensor->GetData().GetSize(); + void* data_ptr = (void*)const_tensor->GetData().GetData(); + CHECK(data_ptr == nullptr, OP_LOGE(op.GetName().c_str(), "data is null."), return false); + + CHECK(dtype != ge::DT_INT32 && dtype != ge::DT_INT64, + OP_LOGE(op.GetName().c_str(), "const not support this type"), return false); + if (dtype == ge::DT_INT32){ + int32_t* const_data_ptr = reinterpret_cast(data_ptr); + size = size / sizeof(int32_t); + for (size_t i=0; i < size; i++) { + const_data.push_back((int64_t)((int32_t) ((*(const_data_ptr + i))))); + } + } else if (dtype == ge::DT_INT64) { + int64_t* const_data_ptr = reinterpret_cast(data_ptr); + size = size / sizeof(int64_t); + for (size_t i=0; i < size; i++) { + const_data.push_back((int64_t)((int64_t) ((*(const_data_ptr + i))))); + } + } + return true; +} + +bool GetScalerValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, std::int64_t& const_data) { + if (dtype == ge::DT_INT32) { + int32_t* const_data_ptr = (int32_t*)const_tensor.GetData(); + const_data = (int32_t)(*const_data_ptr); + } else if (dtype == ge::DT_INT64) { + int64_t* const_data_ptr = (int64_t*)const_tensor.GetData(); + const_data = (int64_t)(*const_data_ptr); + } else { + OP_LOGE(op.GetName().c_str(), "not support this type"); + return false; + } + return true; +} + +string to_string(const vector& shape) { + return ops::to_string(shape); +} + +std::string to_string(const ge::Shape& shape) { + return 
to_string(shape.GetDims()); +} + +std::string to_string(const ge::GeShape& shape) { + return to_string(shape.GetDims()); +} + +std::string to_string(const vector>& ranges) { + return ops::to_string(ranges); +} + +bool DynamicShapeInfer::CatchFormatAndShape() { + inputs = op_desc->GetAllInputName(); + outputs = op_desc->GetAllOutputName(); + GeTensorDescPtr tensor_desc_input, tensor_desc_output; + + // get and save current input shape&format, and assign origin ones to them + std::string input_name; + for (map::iterator it = inputs.begin(); it != inputs.end(); it++) { + input_name = it->first; + tensor_desc_input = op_desc->MutableInputDesc(input_name); + if (tensor_desc_input == nullptr) { + continue; + } + Format curr_format = tensor_desc_input->GetFormat(); + + map_format.insert(std::pair(input_name, curr_format)); + map_dtype.insert(std::pair(input_name, tensor_desc_input->GetDataType())); + + if (tensor_desc_input->GetOriginFormat() == curr_format) { + continue; + } + tensor_desc_input->SetFormat(tensor_desc_input->GetOriginFormat()); + tensor_desc_input->SetShape(tensor_desc_input->GetOriginShape()); + } + + // get and save current output shape&format, and assign origin ones to them + std::string output_name; + for (map::iterator it = outputs.begin(); it != outputs.end(); it++) { + output_name = it->first; + tensor_desc_output = op_desc->MutableOutputDesc(output_name); + if (tensor_desc_output == nullptr) { + continue; + } + Format curr_format = tensor_desc_output->GetFormat(); + + map_format.insert(std::pair(output_name, curr_format)); + map_dtype.insert(std::pair(output_name, tensor_desc_output->GetDataType())); + + if (tensor_desc_output->GetOriginFormat() == curr_format) { + continue; + } + tensor_desc_output->SetFormat(tensor_desc_output->GetOriginFormat()); + } + + return true; +} + +bool DynamicShapeInfer::UpdateFormatAndShape() { + const int64_t opImplType = EN_IMPL_CUSTOM_TBE; + GeTensorDescPtr tensor_desc_input, tensor_desc_output; + // assign 
output's after infershape to origin shape + for (map::iterator it = outputs.begin(); it != outputs.end(); it++) { + tensor_desc_output = op_desc->MutableOutputDesc(it->first); + if (tensor_desc_output == nullptr) { + continue; + } + tensor_desc_output->SetOriginShape(tensor_desc_output->GetShape()); + } + + // transfer input's origin shape to current shape + Format ori_input_format, cur_input_format; + GeShape ori_infer_shape, current_shape; + std::string input_name; + for (map::iterator it = inputs.begin(); it != inputs.end(); it++) { + input_name = it->first; + tensor_desc_input = op_desc->MutableInputDesc(input_name); + if (tensor_desc_input == nullptr) { + continue; + } + ori_input_format = tensor_desc_input->GetFormat(); + ori_infer_shape = tensor_desc_input->GetShape(); + cur_input_format = map_format[input_name]; + + // print some info + OP_LOGI(op.GetName().c_str(), "origin input shape %s is %s", input_name.c_str(), + to_string(ori_infer_shape).c_str()); + + ShapeAndFormat shapeAndFormatInfoInput = {ori_infer_shape, current_shape, ori_input_format, + cur_input_format, map_dtype[input_name], opImplType}; + if (ori_input_format == cur_input_format) { + // no need to transfer shape + continue; + } else { + ShapeTransferAccordingToFormat* global_object = new ShapeTransferAccordingToFormat(); + CHECK(global_object == nullptr, OP_LOGE(op.GetName().c_str(), "new ShapeTransferAccordingToFormat failed."), + return false); + global_object->GetShapeAccordingToFormat(shapeAndFormatInfoInput); + + // print some info + OP_LOGI(op.GetName().c_str(), "current input shape %s is %s", input_name.c_str(), + to_string(current_shape).c_str()); + + tensor_desc_input->SetFormat(cur_input_format); + tensor_desc_input->SetShape(current_shape); + delete global_object; + } + } + + // transfer output's origin shape to current shape + Format ori_output_format, cur_output_format; + GeShape ori_infer_out_shape, current_out_shape; + std::string output_name; + for (map::iterator it = 
outputs.begin(); it != outputs.end(); it++) { + output_name = it->first; + tensor_desc_output = op_desc->MutableOutputDesc(output_name); + if (tensor_desc_output == nullptr) { + continue; + } + ori_output_format = tensor_desc_output->GetFormat(); + ori_infer_out_shape = tensor_desc_output->GetShape(); + cur_output_format = map_format[output_name]; + + // print some info + OP_LOGI(op.GetName().c_str(), "origin output shape %s is %s", output_name.c_str(), + to_string(ori_infer_out_shape).c_str()); + + ShapeAndFormat shapeAndFormatInfoOutput = {ori_infer_out_shape, current_out_shape, ori_output_format, + cur_output_format, map_dtype[output_name], opImplType}; + if (ori_output_format == cur_output_format) { + // no need to transfer shape + continue; + } else { + ShapeTransferAccordingToFormat* global_object = new ShapeTransferAccordingToFormat(); + CHECK(global_object == nullptr, OP_LOGE(op.GetName().c_str(), "new ShapeTransferAccordingToFormat failed."), + return false); + global_object->GetShapeAccordingToFormat(shapeAndFormatInfoOutput); + + // print some info + OP_LOGI(op.GetName().c_str(), "current output shape %s is %s", output_name.c_str(), + to_string(current_out_shape).c_str()); + + tensor_desc_output->SetFormat(cur_output_format); + tensor_desc_output->SetShape(current_out_shape); + delete global_object; + } + } + + return true; +} + +bool IsEmptyTensor(const std::vector& dims) { + if (dims.size() == 1 && dims[0] == 0) { + return true; + } else { + return false; + } +} + +bool IsUnknownRank(const Operator& op, const std::string& tensor_name, const std::string& types) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + CHECK(op_desc == nullptr, OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false); + GeTensorDescPtr tensor_desc; + if (types == "input") { + tensor_desc = op_desc->MutableInputDesc(tensor_name); + } else if (types == "output") { + tensor_desc = op_desc->MutableOutputDesc(tensor_name); + } else { + OP_LOGE(op.GetName().c_str(), 
"invalid params of types to judge."); + return false; + } + + std::vector shape_vec = tensor_desc->GetShape().GetDims(); + if (shape_vec.size() == 1 && shape_vec[0] == -2) { + return true; + } + return false; +} + +bool IsUnknownRankShape(const std::vector& shape_vec) { + if (shape_vec.size() == 1 && shape_vec[0] == -2) { + return true; + } + return false; +} + +bool IsUnKnownShape(const std::vector& shape_vec) { + auto found = find(shape_vec.begin(), shape_vec.end(), -1); + return found != shape_vec.end(); +} + +bool IsUnknown(const std::vector& shape_vec) { + return (IsUnKnownShape(shape_vec) || IsUnknownRankShape(shape_vec)); +} + +bool IsUnknownShape(const Operator& op, const std::string& tensor_name, const std::string& types) { + auto op_desc = OpDescUtils::GetOpDescFromOperator(op); + CHECK(op_desc == nullptr, OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false); + GeTensorDescPtr tensor_desc; + if (types == "input") { + tensor_desc = op_desc->MutableInputDesc(tensor_name); + } else if (types == "output") { + tensor_desc = op_desc->MutableOutputDesc(tensor_name); + } else { + OP_LOGE(op.GetName().c_str(), "invalid params of types to judge."); + return false; + } + + std::vector shape_vec = tensor_desc->GetShape().GetDims(); + std::vector::iterator it_shape; + it_shape = find(shape_vec.begin(), shape_vec.end(), -1); + if (it_shape == shape_vec.end()) { + return false; + } else { + return true; + } +} + +bool IsUnknownVec(std::vector& shape_vec) { + std::vector::iterator it_shape; + it_shape = find(shape_vec.begin(), shape_vec.end(), -1); + if (it_shape == shape_vec.end()) { + return false; + } else { + return true; + } +} + +void MakeUpShapeRange(const std::vector& shape, std::vector>& range) { + if (IsUnknownRankShape(shape)) { + return; + } + + if (range.empty()) { + for (size_t i = 0; i < shape.size(); i++) { + if (shape[i] == -1) { + range.push_back(std::pair(1, -1)); + } else { + range.push_back(std::pair(shape[i], shape[i])); + } + } + } +} + 
+std::string DataTypeToStringDesc(const ge::DataType& dataType) { + std::map::const_iterator totalIter = DTYPE_STR_MAP.find(dataType); + if (totalIter == DTYPE_STR_MAP.end()) { + return "UNDEFINED"; + } + return totalIter->second; +} + +bool OneInOneOutDynamicInfer(const Operator& op, + const std::string& input_name, + const std::vector& output_name_list) { + // get input desc + auto op_info = OpDescUtils::GetOpDescFromOperator(op); + CHECK(op_info == nullptr, OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), return false); + auto input_desc = op_info->MutableInputDesc(input_name); + vector input_shape = input_desc->MutableShape().GetDims(); + DataType input_dtype = input_desc->GetDataType(); + + if (IsUnknown(input_shape)) { + std::vector> input_range; + input_desc->GetShapeRange(input_range); + MakeUpShapeRange(input_shape, input_range); + + auto output_desc = op_info->MutableOutputDesc(0); + for (const string& output_name : output_name_list) { + output_desc = op_info->MutableOutputDesc(output_name); + output_desc->SetShape(GeShape(input_shape)); + output_desc->SetOriginShape(GeShape(input_shape)); + output_desc->SetShapeRange(input_range); + output_desc->SetDataType(input_dtype); + } + } else { + auto output_desc = op_info->MutableOutputDesc(0); + for (const string& output_name : output_name_list) { + output_desc = op_info->MutableOutputDesc(output_name); + output_desc->SetShape(GeShape(input_shape)); + output_desc->SetDataType(input_dtype); + } + } + return true; +} + +void FixShapeRangeWithDims(const std::vector& dims, + std::vector& shape_1, + std::vector& shape_2, + std::vector>& range_1, + std::vector>& range_2) { + MakeUpShapeRange(shape_1, range_1); + MakeUpShapeRange(shape_2, range_2); + bool is_all_fix = dims.empty(); + + if (shape_1 == UNKNOWN_RANK && shape_2 == UNKNOWN_RANK) { + return; + } + if (shape_1 == UNKNOWN_RANK) { + shape_1 = shape_2; + range_1 = range_2; + return; + } + if (shape_2 == UNKNOWN_RANK) { + shape_2 = shape_1; + range_2 = range_1; 
+ return; + } + if ((shape_1.size() != shape_2.size()) || (range_1.size() != range_2.size())) { + return; + } + auto loop_size = is_all_fix ? shape_1.size() : dims.size(); + for (size_t i = 0; i < loop_size; i ++) { + auto dim_num = is_all_fix ? i : dims[i]; + if (shape_1[dim_num] != -1) { + shape_2[dim_num] = shape_1[dim_num]; + range_1[dim_num] = std::pair(shape_1[dim_num], shape_1[dim_num]); + range_2[dim_num] = std::pair(shape_1[dim_num], shape_1[dim_num]); + continue; + } + if (shape_2[dim_num] != -1) { + shape_1[dim_num] = shape_2[dim_num]; + range_1[dim_num] = std::pair(shape_2[dim_num], shape_2[dim_num]); + range_2[dim_num] = std::pair(shape_2[dim_num], shape_2[dim_num]); + continue; + } + // both the dim in shape1 and shape2 are -1 + auto range_1_min = range_1[dim_num].first; + auto range_2_min = range_2[dim_num].first; + auto range_1_max = range_1[dim_num].second; + auto range_2_max = range_2[dim_num].second; + auto range_fisrt = range_1_min > range_2_min ? range_1_min : range_2_min; + auto range_second_min = range_1_max > range_2_max ? range_2_max : range_1_max; + auto range_second_max = range_1_max > range_2_max ? range_1_max : range_2_max; + range_second_min = range_second_min == -1 ? 
range_second_max : range_second_min; + range_1[dim_num] = std::pair(range_fisrt, range_second_min); + range_2[dim_num] = std::pair(range_fisrt, range_second_min); + } +} + +bool TwoInOneOutDynamicInferNoBroadcast(Operator& op, + const string& input1_name, + const string& input2_name, + const std::vector& output_name_list) { + // get input1 desc + auto op_info = OpDescUtils::GetOpDescFromOperator(op); + CHECK(op_info == nullptr || op_info->MutableInputDesc(input1_name) == nullptr || + op_info->MutableInputDesc(input2_name) == nullptr, OP_LOGE(op.GetName().c_str(), "invalid OpDesc."), + return false); + auto input1_desc = op_info->MutableInputDesc(input1_name); + vector input1_shape = input1_desc->MutableShape().GetDims(); + DataType input_dtype = input1_desc->GetDataType(); + + // get input2 desc + auto input2_desc = op_info->MutableInputDesc(input2_name); + vector input2_shape = input2_desc->MutableShape().GetDims(); + + if (IsUnknown(input1_shape) || IsUnknown(input2_shape)) { + std::vector> input1_range; + input1_desc->GetShapeRange(input1_range); + std::vector> input2_range; + input2_desc->GetShapeRange(input2_range); + + vector dim_size = {}; + FixShapeRangeWithDims(dim_size, input1_shape, input2_shape, input1_range, input2_range); + + // update output desc + auto output_desc = op_info->MutableOutputDesc(0); + for (const string& output_name : output_name_list) { + output_desc = op_info->MutableOutputDesc(output_name); + output_desc->SetShape(GeShape(input1_shape)); + output_desc->SetOriginShape(GeShape(input1_shape)); + output_desc->SetShapeRange(input1_range); + output_desc->SetDataType(input_dtype); + } + } else { + auto output_desc = op_info->MutableOutputDesc(0); + for (const string& output_name : output_name_list) { + output_desc = op_info->MutableOutputDesc(output_name); + output_desc->SetShape(GeShape(input1_shape)); + output_desc->SetDataType(input_dtype); + } + } + return true; +} + +bool SetScalarOutputDesc(const string& input, const string& output, 
OpDescPtr op_desc, GeShape& output_shape) { + if (output_shape.IsScalar()) { + auto td = op_desc->MutableOutputDesc(output); + td->SetShape(output_shape); + td->SetOriginShape(output_shape); + td->SetDataType(op_desc->MutableInputDesc(input)->GetDataType()); + td->SetOriginDataType(op_desc->MutableInputDesc(input)->GetDataType()); + return true; + } else { + return false; + } +} + +namespace array_ops { + +bool CheckInt64MulOverflow(int64_t a, int64_t b) { + if (a > 0) { + if (b > 0) { + if (a >(INT64_MAX / b)) { + return false; + } + } else { + if (b < (INT64_MIN / a)) { + return false; + } + } + } else { + if (b > 0) { + if (a < (INT64_MIN / b)) { + return false; + } + } else { + if ((a != 0) && (b < (INT64_MAX / a))) { + return false; + } + } + } + + return true; +} + +void ReshapeRangeInfer(const Operator &op, const std::vector>& x_range, + int64_t& range_max) { + for (const auto& ele : x_range) { + if (ele.second < 0) { + range_max = -1; + return; + } + + if (array_ops::CheckInt64MulOverflow(range_max, ele.second)) { + range_max *= ele.second; + } else { + range_max = INT64_MAX; + GE_OP_LOGW(op.GetName().c_str(), "Range Infer out of int64 max!Do set int64max!"); + return; + } + } +} + +void ReshapeRangeInfer(const Operator &op, const std::vector>& x_range, + std::vector>& y_range, GeShape& output_shape) { + int64_t max_input_dims = 1; + for (const auto& pair : x_range) { + if (pair.second < 0) { + max_input_dims = -1; + break; + } + if (array_ops::CheckInt64MulOverflow(max_input_dims, pair.second)) { + max_input_dims *= pair.second; + } else { + max_input_dims = INT64_MAX; + GE_OP_LOGW(op.GetName().c_str(), "Range Infer out of int64 max!Do set int64max!"); + break; + } + } + + if (max_input_dims < 0) { + for (const auto dim : output_shape.GetDims()) { + if (dim < 0) { + y_range.emplace_back(std::pair(1, -1)); + } else { + y_range.emplace_back(std::pair(dim, dim)); + } + } + } else { + int64_t left = max_input_dims; + left = (left > INT32_MAX) ? 
INT32_MAX : left; + for (const auto dim : output_shape.GetDims()) { + if (dim < 0) { + y_range.emplace_back(std::pair(1, left)); + } else { + y_range.emplace_back(std::pair(dim, dim)); + if (dim != 0) { + left = static_cast((static_cast(left) + 0.5) / dim); + } + } + } + } +} + +} + +} // namespace ge + diff --git a/tests/st/framework/stub_op_proto/util/util.h b/tests/st/framework/stub_op_proto/util/util.h new file mode 100644 index 00000000..c1bf775d --- /dev/null +++ b/tests/st/framework/stub_op_proto/util/util.h @@ -0,0 +1,363 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file util.h + * \brief + */ +#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_ +#define OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_ + +#include +#include +#include +#include +#include + +#include "framework/omg/omg_inner_types.h" +#include "operator.h" +#include "graph/operator_reg.h" +#include "graph/operator_reg.h" +#include "transfer_shape_according_to_format.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/tensor.h" +#include "graph/node.h" +#include "graph/ge_tensor.h" + +#include "op_log.h" + +#define LOG_ERROR(format, args...) 
printf(format, ##args) +namespace ge { + +// enum type and string type mapping +static const std::map DTYPE_STR_MAP{ + {ge::DT_FLOAT16, "float16"}, {ge::DT_FLOAT, "float32"}, {ge::DT_INT8, "int8"}, {ge::DT_INT16, "int16"}, + {ge::DT_INT32, "int32"}, {ge::DT_INT64, "int64"}, {ge::DT_UINT8, "uint8"}, {ge::DT_UINT16, "uint16"}, + {ge::DT_UINT32, "uint32"}, {ge::DT_UINT64, "uint64"}, {ge::DT_BOOL, "bool"}}; + +// define the input num of shape +const size_t INPUT_NUM0 = 0; +const size_t INPUT_NUM1 = 1; +const size_t INPUT_NUM2 = 2; +const size_t INPUT_NUM3 = 3; +const size_t INPUT_NUM4 = 4; +const size_t INPUT_NUM5 = 5; +const size_t INPUT_NUM6 = 6; +const size_t INPUT_NUM7 = 7; +const size_t INPUT_NUM8 = 8; +const size_t INPUT_NUM9 = 9; + +// define the dims size of shape +const size_t DIM_SIZE0 = 0; +const size_t DIM_SIZE1 = 1; +const size_t DIM_SIZE2 = 2; +const size_t DIM_SIZE3 = 3; +const size_t DIM_SIZE4 = 4; +const size_t DIM_SIZE5 = 5; +const size_t DIM_SIZE6 = 6; +const size_t DIM_SIZE7 = 7; +const size_t DIM_SIZE8 = 8; + +// define the index of shape dim +const size_t DIM_INDEX0 = 0; +const size_t DIM_INDEX1 = 1; +const size_t DIM_INDEX2 = 2; +const size_t DIM_INDEX3 = 3; +const size_t DIM_INDEX4 = 4; +const size_t DIM_INDEX5 = 5; +const size_t DIM_INDEX6 = 6; +const size_t DIM_INDEX7 = 7; +const size_t DIM_INDEX8 = 8; + +/* + * get the datatype of input + * param[in] dataType input datatype of enum value + * param[in] supportList the support range of op + * return true :get type success + * false:get type failed + */ +bool GetInputDataType(const ge::DataType& data_type, const std::vector& supportList); + +bool GetInputDataType(const ge::DataType& dataType, const std::vector& supportList, std::string& dType); + +/* infer shape of two input and on output with broadcast + * param[in] op op desc supply by ge + * param[in] inputName1 first input name + * param[in] inputName2 second input name + * param[in] outputName output name + * return SUCCESS:infer success + 
* FAILED:infer failed like unsupported broadcast input shape + */ +bool CheckInputDataType(const Operator& op, const std::string& input_name, + const std::vector& support_list); + +/* + * check the datatype and shape of input + * param[in] op the operator + * param[in] inputTensorMap the map of input name and support datatype + * param[in] paramType the mode of input param, tensor or scalar + * return true + * false + */ +bool CheckInputDtypeAndShape(const Operator& op, const std::map>& inputTensorMap); + +/* + * infer shape of two input and on output with broadcast + * param[in] op op desc supply by ge + * param[in] inputName1 first input name + * param[in] inputName2 second input name + * param[in] outputName output name + * return SUCCESS:infer success + * FAILED:infer failed like unsupported broadcast input shape + */ +bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2, + const string& output_name); + +/* + * infer shape of two input and on output with broadcast + * param[in] op op desc supply by ge + * param[in] inputName1 first input name + * param[in] inputName2 second input name + * param[in] outputName output name + * param[in] is_dynamic whether the shape of output is dynamic shape + * return SUCCESS:infer success + * FAILED:infer failed like unsupported broadcast input shape + */ + +bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2, + const string& output_name, bool& is_dynamic); + +bool InferShapeRangeTwoInOneOutBroadcase(Operator& op, const string& input_name1, const string& input_name2, + const string& output_name); + +bool CheckInputDataType(const Operator& op, std::string* data_type, const std::string& input_name, + const std::vector& supportList); + +bool CheckTwoInputDtypeSame(const Operator& op, const string& input_name1, const string& input_name2); + +bool CheckInputDtypeSame(const Operator& op, std::vector& input_tensors); + 
+bool CheckInputsShapeDtypeSame(const Operator& op, const std::vector& input_names); + +bool GetConstValue(const ge::Operator& op, const std::string& key_name, float& attr_value); + +bool GetConstValue(const ge::Operator& op, const std::string& key_name, int64_t& attr_value); + +bool GetConstValue(const ge::Operator& op, const std::string& key_name, bool& attr_value); + +bool GetConstValue(const ge::Operator& op, const std::string& key_name, std::vector& attr_value); + +/** + * Get int type const value from tensor data + * @param [in] data const tensor data + * @param [in] data_type DT_INT8, DT_INT16, DT_INT32, DT_INT64 + * @param [out] const_values const int values + * @return true:success, false:failed. + */ +bool GetConstIntData(const Tensor& data, DataType data_type, std::vector& const_values); + +bool GetConstValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, + std::vector& const_data); +bool GetConstValue(const Operator& op, const GeTensorPtr& const_tensor, const DataType& dtype, + std::vector& const_data); +bool GetScalerValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, std::int64_t& const_data); +bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2, + const string& output_name); + +/* + * Check input dtype and format is supported in supportList from inputNumBeg to inputNumEnd + * param[in] op op desc supply by ge + * param[in] inputNumBeg input index begin, [0, N] + * param[in] inputNumEnd input index end need to be checked + * param[in] supportList, support type of ge::DataType and ge::Format + * return true: check pass + * false: check failed + */ +template +bool CheckSimilarInputDtypeAndFormat(const Operator& op, std::size_t inputNumBeg, std::size_t inputNumEnd, + const std::vector& supportList) { + for (std::size_t i = inputNumBeg; i < inputNumEnd; i++) { + if (std::is_same::type, ge::DataType>::value) { + ge::DataType inType = 
op.GetInputDesc(i).GetDataType(); + const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType); + if (findDtype == supportList.end()) { + return false; + } + } else if (std::is_same::type, ge::Format>::value) { + ge::Format inType = op.GetInputDesc(i).GetFormat(); + const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType); + if (findDtype == supportList.end()) { + return false; + } + } + } + return true; +} + +/* + * Check input dtype and format is supported in supportList from inputNumBeg to inputNumEnd + * param[in] op op desc supply by ge + * param[in] indexNeedCheck input index need to be checked + * param[in] supportList, support type of ge::DataType and ge::Format + * return true: check pass + * false: check failed + */ +template +bool CheckSimilarInputDtypeAndFormat(const Operator& op, const std::vector& indexNeedCheck, + const std::vector& supportList) { + for (auto i : indexNeedCheck) { + if (std::is_same::type, ge::DataType>::value) { + ge::DataType inType = op.GetInputDesc(i).GetDataType(); + const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType); + if (findDtype == supportList.end()) { + return false; + } + } else if (std::is_same::type, ge::Format>::value) { + ge::Format inType = op.GetInputDesc(i).GetFormat(); + const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType); + if (findDtype == supportList.end()) { + return false; + } + } + } + return true; +} + +/* + * get const attr + * param[in] op op desc supply by ge + * param[in] attrName list need to be get + * param[out] attr vector + * return true: get success + * false: get failed + */ +template +bool GetConstAttr(const Operator& op, const std::vector& attrNameList, std::vector& attrVec) { + T value; + for (auto name : attrNameList) { + if (op.GetAttr(name, value) != ge::GRAPH_SUCCESS) { + return false; + } + attrVec.push_back(value); + } + return true; +} + +/* + * get const attr list + * param[in] 
op op desc supply by ge + * param[in] attrName list need to be get + * param[out] attr vector + * return true: get success + * false: get failed + */ +template +bool GetConstAttr(const Operator& op, const std::vector& attrNameList, + std::vector>& attrListVec) { + for (auto name : attrNameList) { + std::vector valueList; + if (op.GetAttr(name, valueList) != ge::GRAPH_SUCCESS) { + return false; + } + attrListVec.push_back(valueList); + } + return true; +} + +std::string to_string(const vector& shape); +std::string to_string(const ge::Shape& shape); +std::string to_string(const ge::GeShape& shape); +std::string to_string(const vector>& ranges); + +class DynamicShapeInfer { + public: + std::map map_format; + std::map map_dtype; + std::map inputs; + std::map outputs; + Operator& op; + OpDescPtr& op_desc; + + DynamicShapeInfer(Operator& op_v, OpDescPtr& opDesc_v) : op(op_v), op_desc(opDesc_v) { + } + bool CatchFormatAndShape(); + bool UpdateFormatAndShape(); + + ~DynamicShapeInfer() { + UpdateFormatAndShape(); + } +}; + +#define PREPARE_DYNAMIC_SHAPE(depends_names) auto op_desc = OpDescUtils::GetOpDescFromOperator(op);\ + do { \ + if (!depends_names.empty()) { \ + op_desc->SetOpInferDepends(depends_names); \ + } \ + } while(0) + + +bool IsEmptyTensor(const std::vector& dims); + +bool IsUnknownRank(const Operator& op, const std::string& tensor_name, const std::string& types = "input"); + +bool IsUnknownRankShape(const std::vector& shape_vec); + +bool IsUnKnownShape(const std::vector& shape_vec); + +bool IsUnknownShape(const Operator& op, const std::string& tensor_name, const std::string& types = "input"); + +bool IsUnknownVec(std::vector& shape_vec); + +bool IsUnknown(const std::vector& shape_vec); + +void MakeUpShapeRange(const std::vector& shape, std::vector>& range); + +std::string DataTypeToStringDesc(const ge::DataType& dataType); + +bool OneInOneOutDynamicInfer(const Operator& op, + const std::string& input_name, + const std::vector& output_name_list); + +bool 
TwoInOneOutDynamicInferNoBroadcast(Operator& op, + const string& input1_name, + const string& input2_name, + const std::vector& output_name_list); + +void FixShapeRangeWithDims(const std::vector& dims, + std::vector& shape_1, + std::vector& shape_2, + std::vector>& range_1, + std::vector>& range_2); + +bool SetScalarOutputDesc(const string& input, + const string& output, + OpDescPtr op_desc, + GeShape& output_shape); + +namespace array_ops { +bool CheckInt64MulOverflow(int64_t a, int64_t b); + +void ReshapeRangeInfer(const Operator &op, const std::vector>& x_range, + int64_t& range_max); + +void ReshapeRangeInfer(const Operator &op, const std::vector>& x_range, + std::vector>& y_range, GeShape& output_shape); +} +} // namespace ge + +#endif // OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_ diff --git a/tests/st/framework/utils/assertion/graph_assertion.cc b/tests/st/framework/utils/assertion/graph_assertion.cc new file mode 100644 index 00000000..52c49971 --- /dev/null +++ b/tests/st/framework/utils/assertion/graph_assertion.cc @@ -0,0 +1,17 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph_assertion.h" diff --git a/tests/st/framework/utils/assertion/graph_assertion.h b/tests/st/framework/utils/assertion/graph_assertion.h new file mode 100644 index 00000000..ffdceaf9 --- /dev/null +++ b/tests/st/framework/utils/assertion/graph_assertion.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H +#define GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H + +/* + * Compare graph node size, node_attr + */ +#define ASSERT_GRAPH_EQUAL(g1,g2) \ + do { \ + } while (0) + +#define ASSERT_GRAPH_CORRECT(g) \ + do { \ + } while (0) + +#define ASSERT_GRAPH_SHAPE_CONTINOUS(g) \ + do { \ + } while (0) +#endif // GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H diff --git a/tests/st/framework/utils/builder/graph_builder_utils.cc b/tests/st/framework/utils/builder/graph_builder_utils.cc new file mode 100644 index 00000000..cab78284 --- /dev/null +++ b/tests/st/framework/utils/builder/graph_builder_utils.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph_builder_utils.h" +#include "inc/external/graph/operator.h" +#include "inc/external/graph/operator_factory.h" +#include "graph/utils/graph_utils.h" + +namespace ge { +namespace st { +NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, + DataType data_type, std::vector shape) { + auto tensor_desc = std::make_shared(); + tensor_desc->SetShape(GeShape(std::move(shape))); + tensor_desc->SetFormat(format); + tensor_desc->SetDataType(data_type); + + auto op_desc = std::make_shared(name, type); + for (int i = 0; i < in_cnt; ++i) { + op_desc->AddInputDesc(tensor_desc->Clone()); + } + for (int i = 0; i < out_cnt; ++i) { + op_desc->AddOutputDesc(tensor_desc->Clone()); + } + + return graph_->AddNode(op_desc); +} +void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { + GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); +} +void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { + GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); +} +} // namespace st +} // namespace ge diff --git a/tests/st/framework/utils/builder/graph_builder_utils.h b/tests/st/framework/utils/builder/graph_builder_utils.h new file mode 100644 index 00000000..cf1cff2e --- /dev/null +++ b/tests/st/framework/utils/builder/graph_builder_utils.h @@ -0,0 +1,53 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H +#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H + +#include +#include + +#include "graph/compute_graph.h" +#include "graph/utils/graph_utils.h" +#include "graph/graph.h" +#include "graph/node.h" + +namespace ge { +namespace st { +class ComputeGraphBuilder { + public: + explicit ComputeGraphBuilder(const std::string &name) { graph_ = std::make_shared(name); } + NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, + Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, + std::vector shape = {1, 1, 224, 224}); + void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); + void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); + ComputeGraphPtr GetComputeGraph() { + graph_->TopologicalSorting(); + return graph_; + } + Graph GetGraph() { + graph_->TopologicalSorting(); + return GraphUtils::CreateGraphFromComputeGraph(graph_); + } + + private: + ComputeGraphPtr graph_; +}; +} // namespace st +} // namespace ge + +#endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H diff --git a/tests/st/framework/utils/builder/tensor_builder_utils.cc b/tests/st/framework/utils/builder/tensor_builder_utils.cc new file mode 100644 index 00000000..f99b9107 --- /dev/null +++ b/tests/st/framework/utils/builder/tensor_builder_utils.cc @@ -0,0 +1,17 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensor_builder_utils.h" diff --git a/tests/st/framework/utils/builder/tensor_builder_utils.h b/tests/st/framework/utils/builder/tensor_builder_utils.h new file mode 100644 index 00000000..73656e4a --- /dev/null +++ b/tests/st/framework/utils/builder/tensor_builder_utils.h @@ -0,0 +1,22 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H +#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H + +class tensor_builder_utils {}; + +#endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H diff --git a/tests/st/testcase/CMakeLists.txt b/tests/st/testcase/CMakeLists.txt new file mode 100644 index 00000000..748e740f --- /dev/null +++ b/tests/st/testcase/CMakeLists.txt @@ -0,0 +1,15 @@ +file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++") + +add_executable(graph_engine_test ${SOURCES}) + +target_include_directories(graph_engine_test + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 11) + +target_link_libraries(graph_engine_test PRIVATE gtest gtest_main framework) + +include(CTest) +enable_testing() +add_test(NAME test COMMAND graph_engine_test) \ No newline at end of file diff --git a/tests/st/testcase/test_framework_dummy.cc b/tests/st/testcase/test_framework_dummy.cc new file mode 100644 index 00000000..46485030 --- /dev/null +++ b/tests/st/testcase/test_framework_dummy.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include "external/ge/ge_api.h" +#include "framework/common/types.h" +#include "framework.h" +#include "framework/utils/builder/graph_builder_utils.h" + +using namespace std; +using namespace ge; +class FrameworkTest : public testing::Test { + protected: + void SetUp() { + // ge initialize + map options; + auto ret = ge::GEInitialize(options); + EXPECT_EQ(ret, SUCCESS); + } + void TearDown() {} +}; + +TEST_F(FrameworkTest, test_framework_dummy) { + // build graph + st::ComputeGraphBuilder graphBuilder("g1"); + auto data1 = graphBuilder.AddNode("data1",DATA,1,1); + auto data2 = graphBuilder.AddNode("data2",DATA,1,1); + auto add = graphBuilder.AddNode("add",ADD,2,1); + graphBuilder.AddDataEdge(data1, 0, add,0); + graphBuilder.AddDataEdge(data2, 0, add,1); + Graph graph = graphBuilder.GetGraph(); + + // new session & add graph + map options; + Session session(options); + auto ret = session.AddGraph(1, graph, options); + EXPECT_EQ(ret, SUCCESS); + // build input tensor + std::vector inputs; + // build_graph through session + ret = session.BuildGraph(1, inputs); + + // TODO check result +} \ No newline at end of file