@@ -39,7 +39,7 @@ set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR}
 option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE)
-if (ENABLE_OPEN_SRC)
+if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST)
     set(HI_PYTHON python3)
     include(cmake/external_libs/protobuf_shared.cmake)
@@ -51,118 +51,132 @@ if (ENABLE_OPEN_SRC)
     include(cmake/external_libs/json.cmake)
     include(cmake/FindModule.cmake)
     include(cmake/intf_pub_linux.cmake)
+    add_subdirectory(tests)
+else ()
+    if (ENABLE_OPEN_SRC)
+        set(HI_PYTHON python3)
+        include(cmake/external_libs/protobuf_shared.cmake)
+        include(cmake/external_libs/protobuf_static.cmake)
+        include(cmake/external_libs/protoc.cmake)
+        include(cmake/external_libs/gflags.cmake)
+        include(cmake/external_libs/gtest.cmake)
+        include(cmake/external_libs/securec.cmake)
+        include(cmake/external_libs/json.cmake)
+        include(cmake/FindModule.cmake)
+        include(cmake/intf_pub_linux.cmake)
         # if D_LINK_PATH is set in environment variables, search libraries in given path
         if(DEFINED ENV{D_LINK_PATH})
             # D_LINK_PATH is set
             set(GE_LIB_PATH $ENV{D_LINK_PATH})
             set(GE_SYS_ARCH "")
             if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
                 # x86 ubuntu
                 set(GE_SYS_ARCH "x86_64")
             elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
                 # arm euleros
                 set(GE_SYS_ARCH "aarch64")
             else()
                 message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
             endif()
             set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
             set(STATIC_ACL_LIB ${GE_LIB_PATH})
             find_module(slog libalog.so ${GE_LIB_PATH})
             find_module(static_mmpa libmmpa.a ${GE_LIB_PATH})
             find_module(msprofiler_ext libmsprofiler.a ${GE_LIB_PATH})
             find_module(hccl libhccl.so ${GE_LIB_PATH})
             find_module(adump_server libadump_server.a ${GE_LIB_PATH})
             find_module(runtime libruntime.so ${GE_LIB_PATH})
             find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH})
             find_module(resource libresource.so ${GE_LIB_PATH})
             find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH})
             find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH})
             #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
-    elseif(ENABLE_GE_COV OR ENABLE_GE_UT)
-        add_subdirectory(tests)
         else()
             find_module(slog libalog.so ${ASCEND_ATC_DIR})
             find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
             if(PLATFORM STREQUAL "train")
                 find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
                 find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
                 find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
                 find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
                 if(PRODUCT STREQUAL "flr3")
                     message(FATAL_ERROR "This platform is not supported in train mode, build terminated")
                 endif()
             elseif(PLATFORM STREQUAL "inference")
                 find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
                 find_module(runtime libruntime.so ${ASCEND_ACL_DIR})
                 find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
                 find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
                 if(PRODUCT STREQUAL "flr3")
                 elseif(PRODUCT STREQUAL "flr1")
                     find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
                 elseif(PRODUCT STREQUAL "flr2")
                     # flr2 ascend_hal_stub limsprof ?
                 else()
                     find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
                 endif()
             elseif(PLATFORM STREQUAL "all")
                 find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
                 find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
                 find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
                 find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
                 find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
                 find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
             else()
                 message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
             endif()
         endif()
         set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
         set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser)
         set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
         add_subdirectory(metadef)
         add_subdirectory(parser)
         #add_subdirectory(metadef/graph)
         #add_subdirectory(metadef/register)
     elseif (ENABLE_D OR ENABLE_ACL)
         # compiling with MindSpore
         include(cmake/external_libs/protobuf_static.cmake)
         include(cmake/external_libs/protoc.cmake)
         include(cmake/external_libs/securec.cmake)
         include(cmake/external_libs/json.cmake)
         include(cmake/FindModule.cmake)
         include(cmake/intf_pub_linux.cmake)
         # common libraries
         find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
         find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
         if (ENABLE_D)
             # training
             find_module(runtime libruntime.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
             find_module(register libregister.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
         endif ()
         set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
         add_subdirectory(metadef)
     elseif(ENABLE_MS_TESTCASES)
         include(cmake/external_libs/protobuf_static.cmake)
         include(cmake/external_libs/protoc.cmake)
         include(cmake/external_libs/securec.cmake)
         include(cmake/FindModule.cmake)
         include(cmake/intf_pub_linux.cmake)
         # common libraries
         find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
         find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
         set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
         add_subdirectory(metadef)
     else()
         set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef)
         set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser)
         set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
     endif()
     add_subdirectory(ge)
+endif ()
@@ -177,6 +177,9 @@ build_graphengine()
   elif [ "X$ENABLE_GE_UT" = "Xon" ]
   then
     TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest"
+  elif [ "X$ENABLE_GE_ST" = "Xon" ]
+  then
+    TARGET="graph_engine_test"
   elif [ "X$MINDSPORE_MODE" = "Xon" ]
   then
     TARGET="ge_common graph"
@@ -234,6 +237,27 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then
     genhtml coverage.info
 fi
if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then | |||||
#prepare engine & opskernel so | |||||
mkdir -p ${OUTPUT_PATH}/plugin/nnengine | |||||
mkdir -p ${OUTPUT_PATH}/plugin/nnengine/ge_config | |||||
mkdir -p ${OUTPUT_PATH}/plugin/opskernel | |||||
cp ${BUILD_PATH}/tests/st/libnnengine.so ${OUTPUT_PATH}/plugin/nnengine | |||||
cp ${BUILD_PATH}/engine_conf.json ${OUTPUT_PATH}/plugin/nnengine/ge_config | |||||
cp ${BUILD_PATH}/tests/st/libhost_cpu_engine.so ${OUTPUT_PATH}/plugin/opskernel | |||||
#prepare st execution bin | |||||
cp ${BUILD_PATH}/tests/st/testcase/graph_engine_test ${OUTPUT_PATH} | |||||
#execute st testcase | |||||
RUN_TEST_CASE=${OUTPUT_PATH}/graph_engine_test && ${RUN_TEST_CASE} | |||||
if [[ "$?" -ne 0 ]]; then | |||||
echo "!!! ST FAILED, PLEASE CHECK YOUR CHANGES !!!" | |||||
echo -e "\033[31m${RUN_TEST_CASE}\033[0m" | |||||
exit 1; | |||||
fi | |||||
# remove plugin | |||||
rm -rf ${OUTPUT_PATH}/plugin | |||||
fi | |||||
 # generate output package in tar form, including ut/st libraries/executables
 generate_package()
 {
@@ -337,7 +361,7 @@ generate_package()
   fi
 }
-if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then
+if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$ENABLE_GE_ST" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then
     generate_package
 elif [ "X$MINDSPORE_MODE" = "Xon" ]
 then
@@ -25,6 +25,7 @@
 #include "framework/common/op/op_parser_util.h"
 #include "graph/types.h"
 #include "task/task_factory.h"
+#include "ge/common/math/math_util.h"
 namespace ge {
 namespace model_runner {
@@ -500,7 +501,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
   }
   uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data()));
   uint32_t head_len = kOffsetUnit * kStringHeadElems;
-  if (ge::CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) {
+  if (CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) {
     GELOGE(FAILED, "Shape size is invalid");
     return false;
   }
@@ -83,7 +83,7 @@ bool AicpuTask::Distribute() {
     return false;
   }
-  GELOGI("ext info size:", ext_size);
+  GELOGI("ext info size: %u", ext_size);
   aicpu_param_head.extInfoLength = ext_size;
   aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
 }
@@ -130,7 +130,7 @@ bool HcclTask::SetSecondaryStream() {
   Status ret;
   std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_);
   if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) {
-    GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %ld.", rt_model_handle_, master_stream_id);
+    GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %u.", rt_model_handle_, master_stream_id);
     ret = CreateStream(hccl_secondary_stream_num, master_stream_id);
     if (!ret) {
       GELOGE(RT_FAILED, "Create hccl stream failed.");
@@ -189,7 +189,7 @@ bool HcclTask::SetSecondaryStream() {
     }
     GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num);
   } else {
-    GELOGI("Need to create secondary stream for %s with new mainstream %ld.", task_info_->op_name().c_str(),
+    GELOGI("Need to create secondary stream for %s with new mainstream %u.", task_info_->op_name().c_str(),
            master_stream_id);
     ret = CreateStream(hccl_secondary_stream_num, master_stream_id);
     if (!ret) {
@@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() {
     return false;
   }
-  rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
+  rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size);
   if (rt_ret != RT_ERROR_NONE) {
     GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
     return false;
@@ -69,7 +69,7 @@ bool LabelSwitchTask::Distribute() {
       return false;
     }
     label_list[i] = all_label_resource_[label_index];
-    GELOGI("Case %zu: label id %zu.", i, label_index);
+    GELOGI("Case %zu: label id %zu.", i, (size_t)label_index);
   }
   uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size();
@@ -0,0 +1,6 @@
project(graphengine_st)

include(cmake/graphengine.cmake)

add_subdirectory(framework)
add_subdirectory(testcase)
@@ -0,0 +1,249 @@
# ---- Test coverage ----
if (ENABLE_GE_COV)
    set(COVERAGE_COMPILER_FLAGS "-g --coverage -fprofile-arcs -fPIC -O0 -ftest-coverage")
    set(CMAKE_CXX_FLAGS "${COVERAGE_COMPILER_FLAGS}")
endif()

# ---- Proto generate ----
file(GLOB_RECURSE PROTO_FILES CONFIGURE_DEPENDS "${GE_CODE_DIR}/metadef/proto/*.proto")
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_FILES})
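# protobuf_generate() is the protoc helper bundled with the repo's cmake scripts
# (assumed); the generated C++ lands under ${CMAKE_BINARY_DIR}/proto/ge, which is
# why that directory appears on the include list below.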
# ---- File glob by group ----
file(GLOB_RECURSE METADEF_SRCS CONFIGURE_DEPENDS
    "${GE_CODE_DIR}/metadef/graph/*.cc"
    "${GE_CODE_DIR}/metadef/register/*.cc"
    "${GE_CODE_DIR}/metadef/register/*.cpp"
    "${GE_CODE_DIR}/metadef/ops/*.cc"
    "${GE_CODE_DIR}/metadef/third_party/transformer/src/*.cc"
)
file(GLOB_RECURSE METADEF_REGISTER_SRCS CONFIGURE_DEPENDS
    "${GE_CODE_DIR}/metadef/register/*.cc"
    "${GE_CODE_DIR}/metadef/register/*.cpp"
)
file(GLOB_RECURSE PARSER_SRCS CONFIGURE_DEPENDS
    "${GE_CODE_DIR}/parser/parser/common/*.cc"
)
file(GLOB_RECURSE LOCAL_ENGINE_SRC CONFIGURE_DEPENDS
    "${GE_CODE_DIR}/ge/ge_local_engine/*.cc"
)
file(GLOB_RECURSE HOST_ENGINE_SRC CONFIGURE_DEPENDS
    "${GE_CODE_DIR}/ge/host_cpu_engine/*.cc"
)
file(GLOB_RECURSE NN_ENGINE_SRC CONFIGURE_DEPENDS
    "${GE_CODE_DIR}/ge/plugin/*.cc"
)
file(GLOB_RECURSE OFFLINE_SRC CONFIGURE_DEPENDS
    "${GE_CODE_DIR}/ge/offline/*.cc"
)
file(GLOB_RECURSE GE_SRCS CONFIGURE_DEPENDS
    "${GE_CODE_DIR}/ge/*.cc"
)
list(REMOVE_ITEM GE_SRCS ${LOCAL_ENGINE_SRC} ${HOST_ENGINE_SRC} ${NN_ENGINE_SRC} ${OFFLINE_SRC})
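# After the REMOVE_ITEM, GE_SRCS holds everything under ge/ except the engines and
# the offline tool, which are built as their own targets further down.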
list(APPEND INCLUDE_DIRECTORIES
    "${CMAKE_CURRENT_SOURCE_DIR}"
    "${GE_CODE_DIR}"
    "${GE_CODE_DIR}/inc"
    "${GE_CODE_DIR}/metadef/inc"
    "${GE_CODE_DIR}/ge"
    "${GE_CODE_DIR}/ge/inc"
    "${GE_CODE_DIR}/ge/ir_build"
    "${GE_CODE_DIR}/metadef"
    "${GE_CODE_DIR}/metadef/graph"
    "${GE_CODE_DIR}/inc/external"
    "${GE_CODE_DIR}/inc/framework/common"
    "${GE_CODE_DIR}/metadef/inc/external"
    "${GE_CODE_DIR}/metadef/inc/external/graph"
    "${GE_CODE_DIR}/metadef/inc/graph"
    "${GE_CODE_DIR}/inc/framework"
    "${GE_CODE_DIR}/metadef/inc/common"
    "${GE_CODE_DIR}/metadef/third_party"
    "${GE_CODE_DIR}/metadef/third_party/transformer/inc"
    "${GE_CODE_DIR}/parser"
    "${GE_CODE_DIR}/parser/parser"
    "${GE_CODE_DIR}/third_party/fwkacllib/inc"
    "${GE_CODE_DIR}/third_party/fwkacllib/inc/cce"
    "${GE_CODE_DIR}/third_party/fwkacllib/inc/ops"
    "${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain"
    "${GE_CODE_DIR}/tests/ut/ge"
    "${GE_CODE_DIR}/tests/ut/common"
    "${CMAKE_BINARY_DIR}"
    "${CMAKE_BINARY_DIR}/proto/ge"
    "${CMAKE_BINARY_DIR}/proto/ge/proto"
)

list(APPEND STUB_LIBS
    c_sec
    slog_stub
    cce_ge_stub
    runtime_stub
    profiler_stub
    #mmpa_stub
    hccl_stub
    error_manager_stub
    ascend_protobuf
    json
)
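# These stubs are the mocked platform libraries that stand in for the closed-source
# runtime/driver pieces during ST runs (assumed to be built under tests/depends);
# the mmpa stub stays commented out because mmpa is linked explicitly below.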
# ---- Target : Local engine ----
add_library(localengine STATIC ${LOCAL_ENGINE_SRC} ${METADEF_REGISTER_SRCS})

target_include_directories(localengine
    PUBLIC
    "${INCLUDE_DIRECTORIES}"
    "${GE_CODE_DIR}/ge/ge_local_engine"
)

target_compile_definitions(localengine PRIVATE
    google=ascend_private
)

target_compile_options(localengine PRIVATE
    -g --coverage -fprofile-arcs -ftest-coverage
    -Werror=format
)

target_link_libraries(localengine PUBLIC
    $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov
)

set_target_properties(localengine PROPERTIES CXX_STANDARD 11)

# ---- Target : metadef graph ----
add_library(metadef_graph STATIC ${METADEF_SRCS} ${PROTO_SRCS} ${PROTO_HDRS})

target_include_directories(metadef_graph
    PUBLIC
    "${INCLUDE_DIRECTORIES}"
)

target_compile_definitions(metadef_graph PRIVATE
    google=ascend_private
    FMK_SUPPORT_DUMP
)

target_compile_options(metadef_graph PRIVATE
    -g --coverage -fprofile-arcs -ftest-coverage
    -Werror=format
)

target_link_libraries(metadef_graph PUBLIC
    $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov
)

set_target_properties(metadef_graph PROPERTIES CXX_STANDARD 11)

# ---- Target : Host engine ----
add_library(host_cpu_engine SHARED ${HOST_ENGINE_SRC} ${PROTO_HDRS})

target_include_directories(host_cpu_engine
    PUBLIC
    "${INCLUDE_DIRECTORIES}"
    "${GE_CODE_DIR}/ge/host_cpu_engine"
)

target_compile_definitions(host_cpu_engine PRIVATE
    google=ascend_private
    FMK_SUPPORT_DUMP
)

target_compile_options(host_cpu_engine PRIVATE
    -g --coverage -fprofile-arcs -ftest-coverage
    -Werror=format
)
target_link_libraries(host_cpu_engine PUBLIC
    $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} metadef_graph
    # mmpa stub from the test build tree (assumes the default build directory layout)
    -lmmpa -L${GE_CODE_DIR}/build/tests/depends/mmpa
    -lrt -ldl -lpthread -lgcov
)
set_target_properties(host_cpu_engine PROPERTIES CXX_STANDARD 11)

# ---- Target : engine plugin ----
add_library(nnengine SHARED ${NN_ENGINE_SRC})

target_include_directories(nnengine
    PUBLIC
    "${INCLUDE_DIRECTORIES}"
    "${GE_CODE_DIR}/ge/plugin/engine"
)

target_compile_definitions(nnengine PRIVATE
    google=ascend_private
)

target_compile_options(nnengine PRIVATE
    -g --coverage -fprofile-arcs -ftest-coverage
    -Werror=format
)

target_link_libraries(nnengine PUBLIC
    $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov
)

set_target_properties(nnengine PROPERTIES CXX_STANDARD 11)

# Target: engine_conf
add_custom_target(
    engine_conf.json ALL
    DEPENDS ${CMAKE_BINARY_DIR}/engine_conf.json
)
add_custom_command(
    OUTPUT ${CMAKE_BINARY_DIR}/engine_conf.json
    COMMAND cp ${GE_CODE_DIR}/ge/engine_manager/engine_conf.json ${CMAKE_BINARY_DIR}/
)

# Target: optimizer priority
add_custom_target(
    optimizer_priority.pbtxt ALL
    DEPENDS ${CMAKE_BINARY_DIR}/optimizer_priority.pbtxt
)
add_custom_command(
    OUTPUT ${CMAKE_BINARY_DIR}/optimizer_priority.pbtxt
    COMMAND cp ${GE_CODE_DIR}/ge/opskernel_manager/optimizer_priority.pbtxt ${CMAKE_BINARY_DIR}/
)

# ---- Target : Graph engine ----
add_library(graphengine STATIC ${PARSER_SRCS} ${GE_SRCS} ${PROTO_HDRS})

target_include_directories(graphengine
    PUBLIC
    "${INCLUDE_DIRECTORIES}"
    "${GE_CODE_DIR}/ge/host_cpu_engine"
)

target_compile_definitions(graphengine PRIVATE
    google=ascend_private
    FMK_SUPPORT_DUMP
)

target_compile_options(graphengine PRIVATE
    -g --coverage -fprofile-arcs -ftest-coverage
    -Werror=format
)

target_link_libraries(graphengine PUBLIC
    $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS}
    metadef_graph
    localengine
    host_cpu_engine
    nnengine
    mmpa -L${GE_CODE_DIR}/third_party/prebuild/x86_64 -lrt -ldl -lpthread -lgcov
)

set_target_properties(graphengine PROPERTIES CXX_STANDARD 11)
add_dependencies(graphengine engine_conf.json optimizer_priority.pbtxt)
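# build.sh's ST step copies the libnnengine.so, engine_conf.json and
# libhost_cpu_engine.so produced here into ${OUTPUT_PATH}/plugin before it
# launches graph_engine_test.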
@@ -0,0 +1,16 @@
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++")

# TODO: the stub engine is built as its own targets; keep it out of the framework lib
file(GLOB_RECURSE stub_engine CONFIGURE_DEPENDS
    "stub_engine/*.cc"
)
list(REMOVE_ITEM SOURCES ${stub_engine})

add_library(framework STATIC ${SOURCES})

target_include_directories(framework
    PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}
)

set_target_properties(framework PROPERTIES CXX_STANDARD 11)
target_link_libraries(framework PUBLIC graphengine)
@@ -0,0 +1,26 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <stdlib.h>
#include "framework.h"

namespace ge {
namespace st {
Status Framework::SetUp() {
  return SUCCESS;
}
}  // namespace st
}  // namespace ge
@@ -0,0 +1,33 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GRAPHENGINE_LLT_ST_FRAMEWORK_H_
#define GRAPHENGINE_LLT_ST_FRAMEWORK_H_

#include <string>
#include "common/ge_inner_error_codes.h"

namespace ge {
namespace st {
class Framework {
 public:
  Framework() = default;
  Status SetUp();
  Status TearDown();
};
}  // namespace st
}  // namespace ge

#endif  // GRAPHENGINE_LLT_ST_FRAMEWORK_H_
@@ -0,0 +1,259 @@
set(PROTO_LIST
    "${METADEF_DIR}/proto/task.proto"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST})

set(SRC_LIST
    "engine/stub_engine.cc"
    "ops_kernel_store/host_cpu_ops_kernel_info.cc"
    "ops_kernel_store/op/op_factory.cc"
    "ops_kernel_store/op/host_op.cc"
)

set(CPU_OPS_KERNEL_LIST
    "ops_kernel_store/host_cpu_ops_kernel_builder.cc"
)
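# The stub sources are packaged below under the same shared-object names as the
# real Ascend engines (libfe.so, libhost_cpu_opskernel_builder.so), so GE's plugin
# loading path picks them up unchanged during the system tests.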
############ libfe.so ############
add_library(fe SHARED ${SRC_LIST} ${PROTO_HDRS})

target_compile_options(fe PRIVATE
    -Werror
    -fno-common
    -fvisibility=hidden
)

target_compile_definitions(fe PRIVATE
    google=ascend_private
    FUNC_VISIBILITY
)

target_include_directories(fe PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}
    ${GE_CODE_DIR}/ge
    ${GE_CODE_DIR}/inc
    ${GE_CODE_DIR}/inc/external
    ${GE_CODE_DIR}/inc/framework
    ${METADEF_DIR}/inc
    ${METADEF_DIR}/inc/external
    ${METADEF_DIR}/inc/external/graph
    ${CMAKE_BINARY_DIR}
    ${CMAKE_BINARY_DIR}/proto/ge
    #### yellow zone ####
    ${GE_CODE_DIR}/../inc
    #### blue zone ####
    ${GE_CODE_DIR}/third_party/fwkacllib/inc
)

target_link_options(fe PRIVATE
    -Wl,-Bsymbolic
)

target_link_libraries(fe PRIVATE
    $<BUILD_INTERFACE:intf_pub>
    -Wl,--no-as-needed
    ascend_protobuf
    c_sec
    graph
    slog
    -Wl,--as-needed
)

############ atcstub/libfe.so ############
add_library(atc_fe SHARED ${SRC_LIST} ${PROTO_ATCSTUB_HDRS})

target_compile_options(atc_fe PRIVATE
    -Werror
    -fno-common
    -fvisibility=hidden
)

target_compile_definitions(atc_fe PRIVATE
    google=ascend_private
    FUNC_VISIBILITY
)

target_include_directories(atc_fe PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}
    ${GE_CODE_DIR}/ge
    ${GE_CODE_DIR}/inc
    ${GE_CODE_DIR}/inc/external
    ${GE_CODE_DIR}/inc/framework
    ${METADEF_DIR}/inc
    ${METADEF_DIR}/inc/external
    ${METADEF_DIR}/inc/external/graph
    ${CMAKE_BINARY_DIR}
    ${CMAKE_BINARY_DIR}/proto/ge_atcstub
    #### yellow zone ####
    ${GE_CODE_DIR}/../inc
    #### blue zone ####
    ${GE_CODE_DIR}/third_party/fwkacllib/inc
)

target_link_options(atc_fe PRIVATE
    -Wl,-Bsymbolic
)

target_link_libraries(atc_fe PRIVATE
    $<BUILD_INTERFACE:intf_pub>
    -Wl,--no-as-needed
    ascend_protobuf
    c_sec
    graph
    slog
    -Wl,--as-needed
)

set_target_properties(atc_fe PROPERTIES
    OUTPUT_NAME fe
    LIBRARY_OUTPUT_DIRECTORY atclib
)

############ libhost_cpu_opskernel_builder.so ############
add_library(host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST})

target_compile_options(host_cpu_opskernel_builder PRIVATE
    -Werror
    -fno-common
    -fvisibility=hidden
)

target_compile_definitions(host_cpu_opskernel_builder PRIVATE
    google=ascend_private
    FUNC_VISIBILITY
)

target_include_directories(host_cpu_opskernel_builder PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}
    ${GE_CODE_DIR}/ge
    ${GE_CODE_DIR}/inc
    ${GE_CODE_DIR}/inc/external
    ${GE_CODE_DIR}/inc/framework
    ${METADEF_DIR}/inc
    ${METADEF_DIR}/inc/external
    ${METADEF_DIR}/inc/external/graph
    ${CMAKE_BINARY_DIR}
    ${CMAKE_BINARY_DIR}/proto/ge
    #### yellow zone ####
    ${GE_CODE_DIR}/../inc
    #### blue zone ####
    ${GE_CODE_DIR}/third_party/fwkacllib/inc
)

target_link_options(host_cpu_opskernel_builder PRIVATE
    -Wl,-Bsymbolic
)

target_link_libraries(host_cpu_opskernel_builder PRIVATE
    $<BUILD_INTERFACE:intf_pub>
    -Wl,--no-as-needed
    ascend_protobuf
    c_sec
    slog
    graph
    register
    -Wl,--as-needed
)

############ atclib/libhost_cpu_opskernel_builder.so ############
add_library(atc_host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST})

target_compile_options(atc_host_cpu_opskernel_builder PRIVATE
    -Werror
    -fno-common
    -fvisibility=hidden
)

target_compile_definitions(atc_host_cpu_opskernel_builder PRIVATE
    google=ascend_private
    FUNC_VISIBILITY
)

target_include_directories(atc_host_cpu_opskernel_builder PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}
    ${GE_CODE_DIR}/ge
    ${GE_CODE_DIR}/inc
    ${GE_CODE_DIR}/inc/external
    ${GE_CODE_DIR}/inc/framework
    ${METADEF_DIR}/inc
    ${METADEF_DIR}/inc/external
    ${METADEF_DIR}/inc/external/graph
    ${CMAKE_BINARY_DIR}
    ${CMAKE_BINARY_DIR}/proto/ge
    #### yellow zone ####
    ${GE_CODE_DIR}/../inc
    #### blue zone ####
    ${GE_CODE_DIR}/third_party/fwkacllib/inc
)

target_link_options(atc_host_cpu_opskernel_builder PRIVATE
    -Wl,-Bsymbolic
)

target_link_libraries(atc_host_cpu_opskernel_builder PRIVATE
    $<BUILD_INTERFACE:intf_pub>
    -Wl,--no-as-needed
    ascend_protobuf
    c_sec
    slog
    graph
    register
    -Wl,--as-needed
)

set_target_properties(atc_host_cpu_opskernel_builder PROPERTIES
    OUTPUT_NAME host_cpu_opskernel_builder
    LIBRARY_OUTPUT_DIRECTORY atclib
)

############ libhost_cpu_opskernel_builder.a ############
add_library(host_cpu_opskernel_builder_static STATIC ${CPU_OPS_KERNEL_LIST})

target_compile_options(host_cpu_opskernel_builder_static PRIVATE
    -Werror
    -fno-common
    -fvisibility=hidden
)

target_compile_definitions(host_cpu_opskernel_builder_static PRIVATE
    google=ascend_private
    LOG_CPP
    FUNC_VISIBILITY
)

target_include_directories(host_cpu_opskernel_builder_static PRIVATE
    ${CMAKE_CURRENT_LIST_DIR}
    ${GE_CODE_DIR}/ge
    ${GE_CODE_DIR}/inc
    ${GE_CODE_DIR}/inc/external
    ${GE_CODE_DIR}/inc/framework
    ${METADEF_DIR}/inc
    ${METADEF_DIR}/inc/external
    ${METADEF_DIR}/inc/external/graph
    ${CMAKE_BINARY_DIR}
    ${CMAKE_BINARY_DIR}/proto/ge
    #### yellow zone ####
    ${GE_CODE_DIR}/../inc
    #### blue zone ####
    ${GE_CODE_DIR}/third_party/fwkacllib/inc
)

target_link_libraries(host_cpu_opskernel_builder_static PRIVATE
    $<BUILD_INTERFACE:intf_pub>
    ascend_protobuf
    c_sec
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(TARGETS fe host_cpu_opskernel_builder OPTIONAL
    LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)

install(TARGETS atc_fe atc_host_cpu_opskernel_builder OPTIONAL
    LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/atclib
)
@@ -0,0 +1,30 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_
#define GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_

#include <string>

namespace ge {
namespace host_cpu {
// engine name
const char kHostCpuEngineName[] = "DNN_VM_HOST_CPU";
const char kHostCpuOpKernelLibName[] = "DNN_VM_HOST_CPU_OP_STORE";
}  // namespace host_cpu
}  // namespace ge

#endif  // GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_
@@ -0,0 +1,74 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "stub_engine.h"
#include <map>
#include <memory>
#include <string>
#include <securec.h>
#include "framework/common/debug/ge_log.h"
#include "common/ge/ge_util.h"
#include "host_cpu_engine/common/constant/constant.h"
#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h"

namespace ge {
StubEngine &StubEngine::Instance() {
  static StubEngine instance;
  return instance;
}

Status StubEngine::Initialize(const std::map<std::string, std::string> &options) {
  if (ops_kernel_store_ == nullptr) {
    ops_kernel_store_ = MakeShared<host_cpu::HostCpuOpsKernelInfoStore>();
    if (ops_kernel_store_ == nullptr) {
      GELOGE(FAILED, "[Create][StubEngine]Make HostCpuOpsKernelInfoStore failed.");
      REPORT_INNER_ERROR("E19999", "StubEngine::Initialize failed for new HostCpuOpsKernelInfoStore.");
      return FAILED;
    }
  }
  return SUCCESS;
}

void StubEngine::GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map) {
  if (ops_kernel_store_ != nullptr) {
    // add buildin opsKernel to opsKernelInfoMap
    ops_kernel_map[host_cpu::kHostCpuOpKernelLibName] = ops_kernel_store_;
  }
}

void StubEngine::GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &) {
  // no optimizer for host cpu engine
}

Status StubEngine::Finalize() {
  ops_kernel_store_ = nullptr;
  return SUCCESS;
}
}  // namespace ge

ge::Status Initialize(const std::map<std::string, std::string> &options) {
  return ge::StubEngine::Instance().Initialize(options);
}

void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map) {
  ge::StubEngine::Instance().GetOpsKernelInfoStores(ops_kernel_map);
}

void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers) {
  ge::StubEngine::Instance().GetGraphOptimizerObjs(graph_optimizers);
}

ge::Status Finalize() { return ge::StubEngine::Instance().Finalize(); }
@@ -0,0 +1,126 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_
#define GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <map>
#include <memory>
#include <string>
#include <vector>
#include "common/opskernel/ops_kernel_info_store.h"
#include "common/optimizer/graph_optimizer.h"

using OpsKernelInfoStorePtr = std::shared_ptr<ge::OpsKernelInfoStore>;
using GraphOptimizerPtr = std::shared_ptr<ge::GraphOptimizer>;

namespace ge {
namespace {
// engine libraries the stub may stand in for
std::vector<std::string> extern_engine_name_vec = {"fe", "rts_engine", "aicpu_ascend_engine", "aicpu_tf_engine"};
}  // namespace

/**
 * Stub engine for the system tests.
 * Used for the ops which execute on host.
 */
class GE_FUNC_VISIBILITY StubEngine {
 public:
  /**
   * get StubEngine instance.
   * @return StubEngine instance.
   */
  static StubEngine &Instance();

  virtual ~StubEngine() = default;

  /**
   * When GE starts, GE will invoke this interface
   * @return The status whether initialize successfully
   */
  Status Initialize(const std::map<std::string, std::string> &options);

  /**
   * After the initialize, GE will invoke this interface
   * to get the Ops kernel Store.
   * @param ops_kernel_map The host cpu's ops kernel info
   */
  void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map);

  /**
   * After the initialize, GE will invoke this interface
   * to get the Graph Optimizer.
   * @param graph_optimizers The host cpu's Graph Optimizer objs
   */
  void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers);

  /**
   * When the graph finishes, GE will invoke this interface
   * @return The status whether finalize successfully
   */
  Status Finalize();

  StubEngine(const StubEngine &) = delete;
  StubEngine(StubEngine &&) = delete;
  StubEngine &operator=(const StubEngine &) = delete;
  StubEngine &operator=(StubEngine &&) = delete;

 private:
  StubEngine() = default;
  OpsKernelInfoStorePtr ops_kernel_store_ = nullptr;
};
}  // namespace ge

extern "C" {
/**
 * When GE starts, GE will invoke this interface
 * @return The status whether initialize successfully
 */
GE_FUNC_VISIBILITY ge::Status Initialize(const std::map<std::string, std::string> &options);

/**
 * After the initialize, GE will invoke this interface to get the Ops kernel Store
 * @param ops_kernel_map The host cpu's ops kernel info
 */
GE_FUNC_VISIBILITY void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map);

/**
 * After the initialize, GE will invoke this interface to get the Graph Optimizer
 * @param graph_optimizers The host cpu's Graph Optimizer objs
 */
GE_FUNC_VISIBILITY void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers);

/**
 * When the graph finishes, GE will invoke this interface
 * @return The status whether finalize successfully
 */
GE_FUNC_VISIBILITY ge::Status Finalize();
}

#endif  // GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_
@@ -0,0 +1,114 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "host_cpu_ops_kernel_builder.h"
#include <memory>
#include <securec.h>
#include "common/ge_inner_error_codes.h"
#include "ge/ge_api_types.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "framework/common/debug/ge_log.h"
#include "host_cpu_engine/common/constant/constant.h"
#include "register/ops_kernel_builder_registry.h"

namespace ge {
namespace host_cpu {
REGISTER_OPS_KERNEL_BUILDER(kHostCpuOpKernelLibName, HostCpuOpsKernelBuilder);
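// REGISTER_OPS_KERNEL_BUILDER wires this builder into GE's OpsKernelBuilderRegistry
// under kHostCpuOpKernelLibName, so the core engine can locate it at init time.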
Status HostCpuOpsKernelBuilder::Finalize() {
  return SUCCESS;
}

Status HostCpuOpsKernelBuilder::Initialize(const std::map<std::string, std::string> &options) {
  return SUCCESS;
}

Status HostCpuOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) {
  OpDescPtr op_desc = ge_node.GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(FAILED, "[Get][OpDesc]CalcOpRunningParam failed, as op desc is null");
    REPORT_INNER_ERROR("E19999", "GetOpDesc failed.");
    return FAILED;
  }
  bool is_shape_unknown = false;
  if (NodeUtils::GetNodeUnknownShapeStatus(ge_node, is_shape_unknown) == GRAPH_SUCCESS) {
    if (is_shape_unknown) {
      GELOGI("op:%s is unknown shape, does not need to calc output size.", ge_node.GetName().c_str());
      return SUCCESS;
    }
  }
  const std::string name = ge_node.GetName();
  const std::string type = ge_node.GetType();
  GELOGD("Calc op[%s:%s] running param, output size=%zu.", name.c_str(), type.c_str(), op_desc->GetOutputsSize());
  for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) {
    GeTensorDesc output_tensor = op_desc->GetOutputDesc(static_cast<uint32_t>(i));
    Format format = output_tensor.GetFormat();
    DataType data_type = output_tensor.GetDataType();
    int64_t mem_size = 0;
    // If mem size has been set, no need reset.
    if ((TensorUtils::GetSize(output_tensor, mem_size) == GRAPH_SUCCESS) && (mem_size > 0)) {
      GELOGD("Op[%s:%s] out[%zu] mem size has been set, no need calc again, format=%s, data_type=%s, mem_size=%ld.",
             name.c_str(), type.c_str(), i, TypeUtils::FormatToSerialString(format).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str(), mem_size);
      continue;
    }
    int64_t output_mem_size = 0;
    GeShape output_shape = output_tensor.GetShape();
    if ((TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_mem_size) != GRAPH_SUCCESS) ||
        (output_mem_size < 0)) {
      GELOGE(FAILED,
             "[Calc][TensorMemSize] fail for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.",
             name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str());
      REPORT_CALL_ERROR("E19999",
                        "CalcTensorMemSize failed for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.",
                        name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(),
                        TypeUtils::DataTypeToSerialString(data_type).c_str());
      return FAILED;
    }
    GELOGI("Calc op[%s:%s] out[%zu] mem size is %ld, format=%s, data_type=%s.",
           name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(),
           TypeUtils::DataTypeToSerialString(data_type).c_str());
    TensorUtils::SetSize(output_tensor, output_mem_size);
    if (op_desc->UpdateOutputDesc(static_cast<uint32_t>(i), output_tensor) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "[Update][OutputDesc] fail for op[%s:%s] out[%zu] desc, format=%s, data_type=%s.",
             name.c_str(), type.c_str(), i,
             TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str());
      REPORT_CALL_ERROR("E19999", "UpdateOutputDesc failed for op[%s:%s] out[%zu] desc, format=%s, data_type=%s.",
                        name.c_str(), type.c_str(), i,
                        TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str());
      return FAILED;
    }
  }
  GELOGD("Calc op[%s:%s] running param success.", name.c_str(), type.c_str());
  return SUCCESS;
}

Status HostCpuOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) {
  // no need to generate device task
  return SUCCESS;
}
}  // namespace host_cpu
}  // namespace ge
@@ -0,0 +1,51 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include "common/opskernel/ops_kernel_builder.h"

namespace ge {
namespace host_cpu {
class GE_FUNC_VISIBILITY HostCpuOpsKernelBuilder : public OpsKernelBuilder {
 public:
  Status Initialize(const std::map<std::string, std::string> &options) override;
  Status Finalize() override;
  Status CalcOpRunningParam(Node &node) override;
  Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override;
};
}  // namespace host_cpu
}  // namespace ge

#endif  // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_
@@ -0,0 +1,67 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h"
#include <memory>
#include "common/constant/constant.h"
#include "ge/ge_api_types.h"
#include "framework/common/debug/ge_log.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "op/op_factory.h"

namespace ge {
namespace host_cpu {
using domi::TaskDef;
using std::map;
using std::string;
using std::vector;

Status HostCpuOpsKernelInfoStore::Initialize(const map<string, string> &options) {
  GELOGI("HostCpuOpsKernelInfoStore init start.");
  OpInfo default_op_info = {.engine = kHostCpuEngineName,
                            .opKernelLib = kHostCpuOpKernelLibName,
                            .computeCost = 0,
                            .flagPartial = false,
                            .flagAsync = false,
                            .isAtomic = false};
  // Init op_info_map_
  auto all_ops = OpFactory::Instance().GetAllOps();
  for (auto &op : all_ops) {
    op_info_map_[op] = default_op_info;
  }
  GELOGI("HostCpuOpsKernelInfoStore inited success. op num=%zu", op_info_map_.size());
  return SUCCESS;
}

Status HostCpuOpsKernelInfoStore::Finalize() {
  op_info_map_.clear();
  return SUCCESS;
}

void HostCpuOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, OpInfo> &infos) const { infos = op_info_map_; }

bool HostCpuOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const {
  if (op_desc == nullptr) {
    return false;
  }
  return op_info_map_.count(op_desc->GetType()) > 0;
}
}  // namespace host_cpu
}  // namespace ge
@@ -0,0 +1,86 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ | |||||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ | |||||
#if defined(_MSC_VER) | |||||
#ifdef FUNC_VISIBILITY | |||||
#define GE_FUNC_VISIBILITY _declspec(dllexport) | |||||
#else | |||||
#define GE_FUNC_VISIBILITY | |||||
#endif | |||||
#else | |||||
#ifdef FUNC_VISIBILITY | |||||
#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||||
#else | |||||
#define GE_FUNC_VISIBILITY | |||||
#endif | |||||
#endif | |||||
#include <map> | |||||
#include <string> | |||||
#include <vector> | |||||
#include "common/opskernel/ops_kernel_info_store.h" | |||||
namespace ge { | |||||
namespace host_cpu { | |||||
class GE_FUNC_VISIBILITY HostCpuOpsKernelInfoStore : public OpsKernelInfoStore { | |||||
public: | |||||
HostCpuOpsKernelInfoStore() {} | |||||
~HostCpuOpsKernelInfoStore() override = default; | |||||
/** | |||||
* Initialize related resources of the host cpu kernel info store
* @return status indicating whether this operation succeeded
*/ | |||||
Status Initialize(const std::map<std::string, std::string> &options) override; | |||||
/** | |||||
* Release related resources of the host cpu kernel info store | |||||
* @return status indicating whether this operation succeeded
*/ | |||||
Status Finalize() override; | |||||
/** | |||||
* Check to see if an operator is fully supported or partially supported. | |||||
* @param op_desc OpDesc information | |||||
* @param reason unsupported reason | |||||
* @return bool value indicating whether the operator is fully supported
*/ | |||||
bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override; | |||||
/** | |||||
* Returns the full operator information. | |||||
* @param infos reference of a map,
* containing each operator's name and detailed information
*/ | |||||
void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override; | |||||
HostCpuOpsKernelInfoStore(const HostCpuOpsKernelInfoStore &ops_kernel_store) = delete; | |||||
HostCpuOpsKernelInfoStore(const HostCpuOpsKernelInfoStore &&ops_kernel_store) = delete; | |||||
HostCpuOpsKernelInfoStore &operator=(const HostCpuOpsKernelInfoStore &ops_kernel_store) = delete; | |||||
HostCpuOpsKernelInfoStore &operator=(HostCpuOpsKernelInfoStore &&ops_kernel_store) = delete; | |||||
private: | |||||
// store op name and OpInfo key-value pair | |||||
std::map<std::string, ge::OpInfo> op_info_map_; | |||||
}; | |||||
} // namespace host_cpu | |||||
} // namespace ge | |||||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ |
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "host_cpu_engine/ops_kernel_store/op/host_op.h" | |||||
#include "framework/common/util.h" | |||||
#include "host_cpu_engine/ops_kernel_store/op/op_factory.h" | |||||
namespace ge { | |||||
namespace host_cpu { | |||||
Status HostOp::Run() { | |||||
// no need to generate device task | |||||
return SUCCESS; | |||||
} | |||||
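// Each REGISTER_OP_CREATOR line below instantiates a static OpRegistrar at
// library load time (see op_factory.h), binding the op type name to a creator
// function that returns a HostOp.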
REGISTER_OP_CREATOR(NoOp, HostOp); | |||||
REGISTER_OP_CREATOR(Variable, HostOp); | |||||
REGISTER_OP_CREATOR(Constant, HostOp); | |||||
REGISTER_OP_CREATOR(Assign, HostOp); | |||||
REGISTER_OP_CREATOR(RandomUniform, HostOp); | |||||
REGISTER_OP_CREATOR(Add, HostOp); | |||||
REGISTER_OP_CREATOR(Mul, HostOp); | |||||
REGISTER_OP_CREATOR(ConcatV2, HostOp); | |||||
REGISTER_OP_CREATOR(Data, HostOp); | |||||
REGISTER_OP_CREATOR(Fill, HostOp); | |||||
REGISTER_OP_CREATOR(NetOutput, HostOp); | |||||
} // namespace host_cpu | |||||
} // namespace ge |
@@ -0,0 +1,36 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ | |||||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ | |||||
#include "host_cpu_engine/ops_kernel_store/op/op.h" | |||||
namespace ge { | |||||
namespace host_cpu { | |||||
class GE_FUNC_VISIBILITY HostOp : public Op { | |||||
public: | |||||
HostOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} | |||||
~HostOp() override = default; | |||||
HostOp &operator=(const HostOp &op) = delete; | |||||
HostOp(const HostOp &op) = delete; | |||||
Status Run() override; | |||||
}; | |||||
} // namespace host_cpu | |||||
} // namespace ge | |||||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ |
@@ -0,0 +1,45 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ | |||||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ | |||||
#include <climits> | |||||
#include <string> | |||||
#include <vector> | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "common/opskernel/ops_kernel_info_types.h" | |||||
#include "graph/node.h" | |||||
namespace ge { | |||||
namespace host_cpu { | |||||
/** | |||||
* The base class for all ops.
*/ | |||||
class GE_FUNC_VISIBILITY Op { | |||||
public: | |||||
Op(const Node &node, RunContext &run_context) : run_context_(run_context), node_(node) {} | |||||
virtual ~Op() = default; | |||||
virtual Status Run() = 0; | |||||
protected: | |||||
const RunContext &run_context_; | |||||
const Node &node_; | |||||
}; | |||||
} // namespace host_cpu | |||||
} // namespace ge | |||||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ |
@@ -0,0 +1,55 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "host_cpu_engine/ops_kernel_store/op/op_factory.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "graph/op_desc.h" | |||||
namespace ge { | |||||
namespace host_cpu { | |||||
OpFactory &OpFactory::Instance() { | |||||
static OpFactory instance; | |||||
return instance; | |||||
} | |||||
std::shared_ptr<Op> OpFactory::CreateOp(const Node &node, RunContext &run_context) { | |||||
auto iter = op_creator_map_.find(node.GetType()); | |||||
if (iter != op_creator_map_.end()) { | |||||
return iter->second(node, run_context); | |||||
} | |||||
GELOGE(FAILED, "Not supported OP, type = %s, name = %s", node.GetType().c_str(), node.GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
void OpFactory::RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func) { | |||||
if (func == nullptr) { | |||||
GELOGW("Func is NULL."); | |||||
return; | |||||
} | |||||
auto iter = op_creator_map_.find(type); | |||||
if (iter != op_creator_map_.end()) { | |||||
GELOGW("%s creator already exist", type.c_str()); | |||||
return; | |||||
} | |||||
op_creator_map_[type] = func; | |||||
all_ops_.emplace_back(type); | |||||
} | |||||
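// Typical flow (a sketch): creators are registered during static
// initialization via REGISTER_OP_CREATOR (see op_factory.h); at run time the
// engine looks the node type up and runs the op:
//
//   auto op = OpFactory::Instance().CreateOp(node, run_context);
//   if (op != nullptr) {
//     (void)op->Run();
//   }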
} // namespace host_cpu | |||||
} // namespace ge |
@@ -0,0 +1,94 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ | |||||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ | |||||
#include <functional> | |||||
#include <map> | |||||
#include <memory> | |||||
#include <string> | |||||
#include <vector> | |||||
#include "common/ge/ge_util.h" | |||||
#include "host_cpu_engine/ops_kernel_store/op/op.h" | |||||
namespace ge { | |||||
namespace host_cpu { | |||||
using OP_CREATOR_FUNC = std::function<std::shared_ptr<Op>(const Node &, RunContext &)>; | |||||
/** | |||||
* Manages all ops and supports op creation.
*/ | |||||
class GE_FUNC_VISIBILITY OpFactory { | |||||
public: | |||||
static OpFactory &Instance(); | |||||
/** | |||||
* @brief Create an Op.
* @param [in] node reference of the node
* @param [in] run_context run context
* @return created Op on success
* @return nullptr on failure
*/ | |||||
std::shared_ptr<Op> CreateOp(const Node &node, RunContext &run_context); | |||||
/** | |||||
* @brief Register Op create function. | |||||
* @param [in] type Op type | |||||
* @param [in] func Op create func | |||||
*/ | |||||
void RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func); | |||||
const std::vector<std::string> &GetAllOps() const { return all_ops_; } | |||||
bool CheckSupported(const std::string &type) { return op_creator_map_.find(type) != op_creator_map_.end(); } | |||||
OpFactory(const OpFactory &) = delete; | |||||
OpFactory &operator=(const OpFactory &) = delete; | |||||
OpFactory(OpFactory &&) = delete; | |||||
OpFactory &operator=(OpFactory &&) = delete; | |||||
private: | |||||
OpFactory() = default; | |||||
~OpFactory() = default; | |||||
// the op creator function map | |||||
std::map<std::string, OP_CREATOR_FUNC> op_creator_map_; | |||||
std::vector<std::string> all_ops_; | |||||
}; | |||||
class GE_FUNC_VISIBILITY OpRegistrar { | |||||
public: | |||||
OpRegistrar(const std::string &type, const OP_CREATOR_FUNC &func) { | |||||
OpFactory::Instance().RegisterCreator(type, func); | |||||
} | |||||
~OpRegistrar() = default; | |||||
OpRegistrar(const OpRegistrar &) = delete; | |||||
OpRegistrar &operator=(const OpRegistrar &) = delete; | |||||
OpRegistrar(OpRegistrar &&) = delete; | |||||
OpRegistrar &operator=(OpRegistrar &&) = delete; | |||||
}; | |||||
#define REGISTER_OP_CREATOR(type, clazz) \ | |||||
std::shared_ptr<Op> Creator_##type##Op(const Node &node, RunContext &run_context) { \ | |||||
return MakeShared<clazz>(node, run_context); \ | |||||
} \ | |||||
OpRegistrar g_##type##Op_creator(#type, Creator_##type##Op) | |||||
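// For reference, REGISTER_OP_CREATOR(NoOp, HostOp) expands to roughly:
//   std::shared_ptr<Op> Creator_NoOpOp(const Node &node, RunContext &run_context) {
//     return MakeShared<HostOp>(node, run_context);
//   }
//   OpRegistrar g_NoOpOp_creator("NoOp", Creator_NoOpOp);
// so each registration is a file-scope OpRegistrar whose constructor fills
// the factory map before main() runs.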
} // namespace host_cpu | |||||
} // namespace ge | |||||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ |
@@ -0,0 +1,179 @@ | |||||
/* Copyright 2021. Huawei Technologies Co., Ltd. All rights reserved. | |||||
* | |||||
* This program is free software; you can redistribute it and/or modify | |||||
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
* | |||||
* This program is distributed in the hope that it will be useful, | |||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||||
* Apache License for more details at | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
*/ | |||||
syntax = "proto3"; | |||||
package domi; | |||||
message ModelTaskDef { | |||||
string version = 1; | |||||
map<string, string> attr = 9; // Extended field | |||||
repeated TaskDef task = 10; | |||||
uint64 memory_size = 11; | |||||
uint32 stream_num = 12; | |||||
uint32 event_num = 13; | |||||
uint64 weight_size = 14; | |||||
repeated bytes op = 15; // input/output opdef in bytes | |||||
uint64 base_addr = 16; // base addr | |||||
uint64 weight_addr = 17; // weight addr | |||||
uint32 batch_num = 18; | |||||
} | |||||
message TaskDef { | |||||
uint32 id = 1; | |||||
uint32 type = 2; | |||||
uint32 stream_id = 10; | |||||
uint32 event_id = 11; | |||||
KernelDef kernel = 20; | |||||
KernelExDef kernel_ex = 21; | |||||
KernelHcclDef kernel_hccl = 25; | |||||
EventExDef event_ex = 26; | |||||
LogTimeStampDef log_timestamp = 28; | |||||
uint32 label_id = 30; | |||||
MemcpyAsyncDef memcpy_async = 31; | |||||
StreamSwitchDef stream_switch = 32; | |||||
StreamActiveDef stream_active = 33; | |||||
bytes private_def = 34; | |||||
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||||
StreamSwitchNDef stream_switch_n = 36; | |||||
LabelSetDef label_set = 37; | |||||
LabelGotoExDef label_goto_ex = 38; | |||||
LabelSwitchByIndexDef label_switch_by_index = 39; | |||||
KernelDefWithHandle kernel_with_handle = 40; | |||||
} | |||||
message KernelDef { | |||||
KernelContext context = 1; | |||||
string stub_func = 10; | |||||
uint32 block_dim = 11; | |||||
uint32 args_size = 12; | |||||
bytes args = 13; | |||||
bytes sm_desc = 14; | |||||
bytes flowtable = 15; | |||||
string so_name = 16; | |||||
string kernel_name = 17; | |||||
bytes kernel_ext_info = 18; | |||||
uint32 kernel_ext_info_size = 19; | |||||
} | |||||
message KernelDefWithHandle { | |||||
KernelContext context = 1; | |||||
uint64 handle = 10; | |||||
string dev_func = 11; | |||||
uint32 block_dim = 12; | |||||
uint32 args_size = 13; | |||||
bytes args = 14; | |||||
bytes sm_desc = 15; | |||||
string original_kernel_key = 16; | |||||
string node_info = 17; | |||||
} | |||||
message KernelContext { | |||||
uint32 kernel_type = 1; | |||||
uint32 op_id = 2; // OP type in CCE | |||||
uint32 kernel_func_id = 3; | |||||
uint32 op_index = 4; // TE/Custom operator | |||||
bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||||
bytes args_offset = 6; // args offset information | |||||
uint32 args_count = 7; // args count | |||||
repeated uint32 origin_op_index = 8; | |||||
} | |||||
message KernelExDef { | |||||
uint32 flags = 1; | |||||
uint32 op_index = 4; | |||||
uint32 args_size = 12; | |||||
bytes args = 13; | |||||
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||||
uint32 task_info_size = 15; | |||||
bytes kernel_ext_info = 16; | |||||
uint32 kernel_ext_info_size = 17; | |||||
} | |||||
message KernelHcclDef { | |||||
uint32 op_index = 8; | |||||
string hccl_type = 9; | |||||
} | |||||
message EventExDef { | |||||
uint32 op_index = 1; | |||||
uint32 event_type = 2; | |||||
} | |||||
message LogTimeStampDef { | |||||
uint64 logid = 1; | |||||
bool notify = 2; | |||||
uint32 flat = 3; | |||||
} | |||||
message MemcpyAsyncDef { | |||||
uint64 dst = 1; | |||||
uint64 dst_max = 2; | |||||
uint64 src = 3; | |||||
uint64 count = 4; | |||||
uint32 kind = 5; | |||||
uint32 op_index = 6; | |||||
} | |||||
message StreamSwitchDef { | |||||
uint32 op_index = 1; | |||||
uint32 true_stream_id = 2; | |||||
int64 value = 3; | |||||
uint64 value_ptr = 4; | |||||
uint32 data_type = 5; | |||||
} | |||||
message StreamActiveDef { | |||||
uint32 op_index = 1; | |||||
uint32 active_stream_id = 2; | |||||
} | |||||
message StreamSwitchNDef { | |||||
uint32 op_index = 1; | |||||
uint32 size = 2; | |||||
repeated int64 target_value = 3; | |||||
repeated uint32 true_stream_id = 4; | |||||
uint32 element_size = 5; | |||||
uint32 data_type = 6; | |||||
} | |||||
message LabelSetDef { | |||||
uint32 op_index = 1; | |||||
uint32 label_id = 2; | |||||
uint32 model_id = 3; | |||||
} | |||||
message LabelGotoExDef { | |||||
uint32 op_index = 1; | |||||
uint32 label_id = 2; | |||||
uint32 model_id = 3; | |||||
} | |||||
message LabelSwitchByIndexDef { | |||||
uint32 op_index = 1; | |||||
uint32 label_max = 2; | |||||
} |
@@ -0,0 +1,711 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file array_ops.h | |||||
* \brief | |||||
*/ | |||||
#ifndef OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ | |||||
#define OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ | |||||
#include "graph/operator_reg.h" | |||||
#include "graph/operator.h" | |||||
namespace ge { | |||||
/** | |||||
*@brief Finds unique elements in a 1D tensor. \n | |||||
*@par Inputs: | |||||
*x: A 1D tensor. \n
*@par Attributes: | |||||
*out_idx: An optional DType from: "int32, int64". Defaults to "int32". \n | |||||
*@par Outputs: | |||||
*@li y: "x" in the unique output "y". | |||||
*@li idx: A tensor the same size as "x". The index of each value of "x". \n | |||||
*@attention Constraints: | |||||
*Unique runs on the Ascend AI CPU, which delivers poor performance. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Unique. | |||||
*/ | |||||
REG_OP(Unique) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ | |||||
DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ | |||||
DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) | |||||
.OUTPUT(idx, TensorType({DT_INT32, DT_INT64})) | |||||
.ATTR(out_idx, Type, DT_INT32) | |||||
.OP_END_FACTORY_REG(Unique) | |||||
/** | |||||
*@brief Creates a constant tensor from a tensor-like object. This operator is used for inference. | |||||
Operator Const has the same definition as operator Constant. \n | |||||
*@par Attributes: | |||||
*value: Required. The value and type of the resulting tensor; there are no restrictions on the type. \n
*@par Outputs: | |||||
*y: A constant tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Const. | |||||
*/ | |||||
REG_OP(Const) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||||
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.ATTR(value, Tensor, Tensor()) | |||||
.OP_END_FACTORY_REG(Const) | |||||
/** | |||||
*@brief Creates a constant tensor for training. \n | |||||
*@par Attributes: | |||||
*value: Required. The value and type of the resulting tensor; there are no restrictions on the type. \n
*@par Outputs: | |||||
*y: The constant tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Const. | |||||
*/ | |||||
REG_OP(Constant) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||||
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.ATTR(value, Tensor, Tensor()) | |||||
.OP_END_FACTORY_REG(Constant) | |||||
/** | |||||
*@brief Returns a copy of the input tensor. \n | |||||
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Outputs: | |||||
*y: A tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Snapshot. | |||||
*/ | |||||
REG_OP(Snapshot) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||||
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||||
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OP_END_FACTORY_REG(Snapshot) | |||||
/** | |||||
*@brief Gives a guarantee to the runtime that the input tensor is a constant. \n | |||||
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Outputs: | |||||
*y: The input tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator GuaranteeConst. | |||||
*/ | |||||
REG_OP(GuaranteeConst) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OP_END_FACTORY_REG(GuaranteeConst) | |||||
/** | |||||
*@brief Returns the target shape for broadcasting shapes "x1" and "x2". \n | |||||
*@par Inputs: | |||||
*@li x1: A tensor of type int32 or int64. A shape. | |||||
*@li x2: A tensor of the same type as "x1". The other shape. \n | |||||
*@par Outputs: | |||||
*y: A tensor. The broadcasted shape. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator BroadcastArgs. | |||||
*/ | |||||
REG_OP(BroadcastArgs) | |||||
.INPUT(x1, TensorType({DT_INT32, DT_INT64})) | |||||
.INPUT(x2, TensorType({DT_INT32, DT_INT64})) | |||||
.OUTPUT(y, TensorType({DT_INT32, DT_INT64})) | |||||
.OP_END_FACTORY_REG(BroadcastArgs) | |||||
/** | |||||
*@brief Outputs its input tensor as is and triggers an error if a gradient is requested. \n | |||||
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Attributes: | |||||
*message: Text to be printed in the error when a gradient is requested. \n
*@par Outputs: | |||||
*y: The input tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator PreventGradient. | |||||
*/ | |||||
REG_OP(PreventGradient) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.ATTR(message, String, "") | |||||
.OP_END_FACTORY_REG(PreventGradient) | |||||
/** | |||||
*@brief Returns the reduction indices for computing gradients of "x1" and "x2" with broadcast. \n | |||||
*@par Inputs: | |||||
*@li x1: A tensor of type int32 or int64. | |||||
*@li x2: A tensor of type int32 or int64. | |||||
"x2" has the same type as "x1". \n | |||||
*@par Outputs: | |||||
*@li y1: A tensor. Reduction indices of "x1". | |||||
*@li y2: A tensor. Reduction indices of "x2". \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator BroadcastGradientArgs. | |||||
*/ | |||||
REG_OP(BroadcastGradientArgs) | |||||
.INPUT(x1, TensorType({DT_INT32, DT_INT64})) | |||||
.INPUT(x2, TensorType({DT_INT32, DT_INT64})) | |||||
.OUTPUT(y1, TensorType({DT_INT32, DT_INT64})) | |||||
.OUTPUT(y2, TensorType({DT_INT32, DT_INT64})) | |||||
.OP_END_FACTORY_REG(BroadcastGradientArgs) | |||||
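/* Example (a sketch): for input shapes x1 = [2, 1, 3] and x2 = [1, 4, 3], the
 * broadcast shape is [2, 4, 3]; the reduction indices are y1 = [1] (x1 was
 * broadcast along dim 1) and y2 = [0] (x2 was broadcast along dim 0). */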
/** | |||||
*@brief Stops gradient computation. None is returned for the node where the gradient computation is stopped. | |||||
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Outputs: | |||||
*y: The input tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator StopGradient. | |||||
*/ | |||||
REG_OP(StopGradient) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OP_END_FACTORY_REG(StopGradient) | |||||
/** | |||||
*@brief Returns a tensor with the same shape and contents as the input. \n
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Outputs: | |||||
*y: A tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Identity. | |||||
*/ | |||||
REG_OP(Identity) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OP_END_FACTORY_REG(Identity) | |||||
/** | |||||
*@brief Returns a list of tensors with the same shapes and contents as the input tensors. \n | |||||
*@par Inputs: | |||||
*x: A list of input tensors. It's a dynamic input. \n
*@par Outputs: | |||||
*y: A list of Tensor objects, with the same length as the input tensor list. | |||||
It's a dynamic output. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator IdentityN. | |||||
*/ | |||||
REG_OP(IdentityN) | |||||
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OP_END_FACTORY_REG(IdentityN) | |||||
/** | |||||
*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n | |||||
*@par Inputs: | |||||
*@li x: A tensor. | |||||
*@li axis: The dimension index at which to expand. \n | |||||
*@par Outputs: | |||||
*y: A tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator ExpandDims. | |||||
*/ | |||||
REG_OP(ExpandDims) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, | |||||
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.INPUT(axis, TensorType({DT_INT32, DT_INT64})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, | |||||
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OP_END_FACTORY_REG(ExpandDims) | |||||
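/* Example (a sketch): expanding a tensor of shape [3, 4] with axis = 0 yields
 * shape [1, 3, 4]; with axis = -1 it yields [3, 4, 1]. */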
/** | |||||
*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n | |||||
*@par Inputs: | |||||
*@li x: Original tensor. | |||||
*@li axis: List of ints. \n | |||||
*@par Outputs: | |||||
*y: Reshaped tensor with the same data as the input. \n
*@par Third-party framework compatibility | |||||
*Compatible with the Onnx operator Unsqueeze. | |||||
*/ | |||||
REG_OP(Unsqueeze) | |||||
.INPUT(x, TensorType({DT_FLOAT32, DT_INT32, DT_UINT8, DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT32, DT_INT32, DT_UINT8, DT_BOOL})) | |||||
.ATTR(axes, ListInt, {}) | |||||
.OP_END_FACTORY_REG(Unsqueeze) | |||||
/** | |||||
*@brief Reshapes a tensor. Only the tensor shape is changed, without changing the data. \n | |||||
*@par Inputs: | |||||
*@li x: A tensor. | |||||
*@li shape: A tensor. Defines the shape of the output tensor. \n | |||||
*@par Attributes: | |||||
*@li axis: An optional int32 or int64. The first dimension to reshape. Defaults to "0". | |||||
*@li num_axes: An optional int32 or int64. The extent of the reshape. Defaults to "-1". \n | |||||
*@par Outputs: | |||||
*y: A tensor. \n | |||||
*@par Attention: | |||||
*This operator cannot be directly called by the aclopExecute API. \n
*@par Third-party framework compatibility | |||||
*@li Compatible with the TensorFlow operator Reshape. | |||||
*@li Compatible with the Caffe operator Reshape. | |||||
*/ | |||||
REG_OP(Reshape) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, | |||||
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.INPUT(shape, TensorType({DT_INT32, DT_INT64})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, | |||||
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.ATTR(axis, Int, 0) | |||||
.ATTR(num_axes, Int, -1) | |||||
.OP_END_FACTORY_REG(Reshape) | |||||
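/* Example (a hedged sketch, assuming the usual set_input_xxx and set_attr_xxx
 * setters generated by REG_OP):
 *   auto shape = op::Const("shape").set_attr_value(shape_tensor);  // e.g. {2, 2, 4}
 *   auto reshape = op::Reshape("reshape")
 *                      .set_input_x(x)           // x: a [2, 8] tensor
 *                      .set_input_shape(shape);  // y: reshaped to [2, 2, 4]
 */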
/** | |||||
*@brief Removes dimensions of size 1 from the shape of a tensor. \n | |||||
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Attributes: | |||||
*axis: An optional list of int32 or int64. If not specified, squeezes all
dimensions of size 1. If specified, only squeezes the dimensions listed.
It is an error to squeeze a dimension that is not 1. \n
*@par Outputs: | |||||
*y: A tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Squeeze. | |||||
*/ | |||||
REG_OP(Squeeze) | |||||
.INPUT(x, TensorType::ALL()) | |||||
.OUTPUT(y, TensorType::ALL()) | |||||
.ATTR(axis, ListInt, {}) | |||||
.OP_END_FACTORY_REG(Squeeze) | |||||
/** | |||||
*@brief Returns an integer representing the rank of the input tensor. The rank
of a tensor is the number of indices required to uniquely select each element
of the tensor, that is, the number of dimensions of the tensor. \n
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Outputs: | |||||
*y: A tensor. The rank of input tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Rank. | |||||
*/ | |||||
REG_OP(Rank) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_INT32})) | |||||
.OP_END_FACTORY_REG(Rank) | |||||
/** | |||||
*@brief Returns the size of a tensor, that is, an integer giving the number of elements of the tensor. \n
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Attributes: | |||||
*dtype: An optional int32 or int64. The output data type. Defaults to "int32". \n
*@par Outputs: | |||||
*y: A tensor. The size of the input tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Size. | |||||
*/ | |||||
REG_OP(Size) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
.ATTR(dtype, Int, DT_INT32) | |||||
.OP_END_FACTORY_REG(Size) | |||||
/** | |||||
*@brief Input data for other operators. \n | |||||
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Attributes: | |||||
*index: Index of the input tensor. The data type must be int32 or int64.
For example, if a net has three Data nodes, their indices should be set to
0, 1, and 2 respectively. \n
*@par Outputs: | |||||
*y: A tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the Caffe operator Data. | |||||
*/ | |||||
REG_OP(Data) | |||||
.INPUT(x, TensorType::ALL()) | |||||
.OUTPUT(y, TensorType::ALL()) | |||||
.ATTR(index, Int, 0) | |||||
.OP_END_FACTORY_REG(Data) | |||||
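/* Example (a hedged sketch, assuming the usual set_attr_xxx setters generated
 * by REG_OP): a net with three Data nodes indexes them 0 through 2:
 *   auto d0 = op::Data("data0").set_attr_index(0);
 *   auto d1 = op::Data("data1").set_attr_index(1);
 *   auto d2 = op::Data("data2").set_attr_index(2);
 */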
/** | |||||
*@brief Inserts a placeholder for a tensor that will always be fed. \n
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Attributes: | |||||
*@li peerIndex: An integer type. The index of the corresponding "end" node connected to. | |||||
*@li parentId: A string, used to check if the nodes are from the saved parent node. | |||||
*@li parentOpType: A string. Op type of the original node. | |||||
*@li anchorIndex: An integer, used to check if the node is from the saved anchor. \n | |||||
*@par Outputs: | |||||
*y: The created placeholder tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator PlaceHolder. | |||||
*/ | |||||
REG_OP(PlaceHolder) | |||||
.INPUT(x, TensorType::ALL()) | |||||
.OUTPUT(y, TensorType::ALL()) | |||||
.ATTR(peerIndex, Int, 0) // the index of the corresponding 'end' node it's connected to | |||||
.ATTR(parentId, String, "") // check if these node are from save parent node | |||||
.ATTR(parentOpType, String, "") // op type of original node | |||||
.ATTR(anchorIndex, Int, 0) // check if these node are from save anchor | |||||
.OP_END_FACTORY_REG(PlaceHolder) | |||||
/** | |||||
*@brief Inserts a placeholder with default value for a tensor. \n | |||||
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Attributes: | |||||
*shape: Required. The shape of the tensor. \n
*@par Outputs: | |||||
*y: The created placeholder tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator PlaceholderWithDefault. | |||||
*/ | |||||
REG_OP(PlaceholderWithDefault) | |||||
.INPUT(x, TensorType::ALL()) | |||||
.OUTPUT(y, TensorType::ALL()) | |||||
.REQUIRED_ATTR(shape, ListInt) | |||||
.OP_END_FACTORY_REG(PlaceholderWithDefault) | |||||
/** | |||||
*@brief Reads and returns the value of the input variable tensor. \n | |||||
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Attributes: | |||||
*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n | |||||
*@par Outputs: | |||||
*y: A tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator ReadVariableOp. | |||||
*/ | |||||
REG_OP(ReadVariableOp) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.ATTR(dtype, Int, DT_INT32) | |||||
.OP_END_FACTORY_REG(ReadVariableOp) | |||||
/** | |||||
*@brief Marks the outputs of a subgraph partitioned by engine type.
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Outputs: | |||||
*y: A tensor. \n | |||||
*@par Attributes: | |||||
*@li peerIndex: The index of the corresponding 'placeholder' node it's connected to. | |||||
*@li parentOpType: Op type of original node. | |||||
*@par Restrictions: | |||||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||||
*/ | |||||
REG_OP(End) | |||||
.INPUT(x, TensorType::ALL()) | |||||
.OUTPUT(y, TensorType::ALL()) | |||||
.ATTR(peerIndex, Int, 0) | |||||
.ATTR(parentOpType, String, "") | |||||
.OP_END_FACTORY_REG(End) | |||||
/** | |||||
*@brief Operations for writing summary data, for use in analysis and visualization. | |||||
*@par Inputs: | |||||
* One input: | |||||
*x: Collections of summary data. | |||||
*@par Restrictions: | |||||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||||
*/ | |||||
REG_OP(Summary) | |||||
.INPUT(x, TensorType::ALL()) | |||||
.OP_END_FACTORY_REG(Summary) | |||||
/** | |||||
*@brief Returns the shape of a tensor. \n | |||||
*@par Inputs: | |||||
*x: A tensor. \n | |||||
*@par Attributes: | |||||
*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n | |||||
*@par Outputs: | |||||
*y: A tensor. The shape of the input tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Shape.
*/ | |||||
REG_OP(Shape) | |||||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_INT32, DT_INT64})) | |||||
.ATTR(dtype, Int, DT_INT32) | |||||
.OP_END_FACTORY_REG(Shape) | |||||
/** | |||||
*@brief Returns the shapes of the input tensors. \n
*@par Inputs: | |||||
*x: A list of input tensors. It's a dynamic input. \n | |||||
*@par Attributes: | |||||
*dtype: An optional int32 or int64. The output data type. Defaults to "int32". \n | |||||
*@par Outputs: | |||||
*y: A list of tensors with the same length as the input list of tensors. | |||||
It's a dynamic output. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator ShapeN. | |||||
*/ | |||||
REG_OP(ShapeN) | |||||
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.DYNAMIC_OUTPUT(y, TensorType({DT_INT32, DT_INT64})) | |||||
.ATTR(dtype, Int, DT_INT32) | |||||
.OP_END_FACTORY_REG(ShapeN) | |||||
/** | |||||
*@brief Creates a tensor with the given "shape" and "dtype". \n | |||||
*@par Inputs: | |||||
*shape: The shape of the output tensor. \n | |||||
*@par Attributes: | |||||
*@li dtype: Optional. The data type of the output tensor. Defaults to "int32". | |||||
*@li init: An optional bool. If true, initializes the returned tensor with the default value of "dtype". Defaults to "false". \n | |||||
*@par Outputs: | |||||
*y: A tensor. \n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Empty. | |||||
*/ | |||||
REG_OP(Empty) | |||||
.INPUT(shape, TensorType({DT_INT32})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
.ATTR(dtype, Int, DT_INT32) | |||||
.ATTR(init, Bool, false)
.OP_END_FACTORY_REG(Empty) | |||||
/** | |||||
*@brief Returns locations of nonzero / true values in a tensor. \n | |||||
*@par Inputs: | |||||
*Including: | |||||
*x: A Tensor. Must be one of the following types: | |||||
DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, | |||||
DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL. \n | |||||
*@par Outputs: | |||||
*y: A Tensor of type DT_INT64. \n | |||||
*@attention Constraints: | |||||
*Where runs on the Ascend AI CPU, which delivers poor performance.\n | |||||
*@par Third-party framework compatibility | |||||
*Compatible with the TensorFlow operator Where. | |||||
*/ | |||||
REG_OP(Where) | |||||
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ | |||||
DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_INT64})) | |||||
.OP_END_FACTORY_REG(Where) | |||||
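/* Example (a sketch): for x = [[true, false], [false, true]], "y" holds the
 * coordinates of the true elements, [[0, 0], [1, 1]], as a DT_INT64 tensor of
 * shape [num_true, rank(x)]. */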
/** | |||||
*@brief Changes the shape of the output according to the attribute "outShape".
*
*@par Inputs:
*x: A Tensor. \n
*@par Outputs:
*y: A Tensor. Has the same type as "x", with the shape given by "outShape". \n
*@par Attributes:
*outShape: The target shape of the output.
*/ | |||||
REG_OP(TransShape) | |||||
.INPUT(x, TensorType::ALL()) | |||||
.OUTPUT(y, TensorType::ALL()) | |||||
.ATTR(outShape, ListInt, {})
.OP_END_FACTORY_REG(TransShape)
/** | |||||
* @brief sort_v2. | |||||
* @par Inputs: | |||||
* @li x: An ND tensor of type float16, float32, or double.
* @par Attributes: | |||||
* @li axis: An optional int. The dimension to sort along. This value defaults to -1. | |||||
* @li descending: An optional bool. Controls the sorting order (ascending or descending). This value defaults to False. | |||||
* @par Outputs: | |||||
* @li y: An ND tensor of type float16, float32, or double.
* @attention Constraints: | |||||
* @li Axis should select the last dim. | |||||
* @li When the data to be sorted contains fewer than 150K elements, this TBE op
is recommended; descending order performs better than ascending.
* @li The upper limit of data on Ascend910 is 2000K elements.
*/ | |||||
REG_OP(SortV2) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | |||||
.ATTR(axis, Int, -1) | |||||
.ATTR(descending, Bool, false) | |||||
.OP_END_FACTORY_REG(SortV2) | |||||
/** | |||||
* @brief Expand the input tensor to a compatible shape. \n | |||||
* @par Inputs: | |||||
* Two inputs, including:
* @li x: A Tensor. Must be one of the following types:
* float16, float32, int32, int8, uint8. \n
* @li shape: A Tensor specifying the shape that the input tensor is expanded to. \n
* @par Outputs:
* @li y: A Tensor. Has the same type as "x", with the shape specified by the "shape" input. \n
* @par Third-party framework compatibility | |||||
* Compatible with the ONNX operator Expand. | |||||
*/ | |||||
REG_OP(Expand) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||||
.INPUT(shape, TensorType({DT_INT16, DT_INT32, DT_INT64})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||||
.OP_END_FACTORY_REG(Expand) | |||||
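/* Example (a sketch): with x of shape [3, 1] and shape = {3, 4}, the output
 * "y" is broadcast to shape [3, 4], following ONNX Expand semantics. */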
/** | |||||
* @brief Expand the input tensor to a compatible shape. \n | |||||
* @par Inputs: | |||||
* One input, including:
* @li x: A Tensor. Must be one of the following types:
* float16, float32, int32, int8, uint8. \n
* @par Attributes:
* @li shape: A required ListInt specifying the shape that the input tensor is expanded to. \n
* @par Outputs:
* @li y: A Tensor. Has the same type as "x", with the shape specified by the "shape" attribute. \n
* @par Third-party framework compatibility | |||||
* Compatible with the ONNX operator Expand. | |||||
*/ | |||||
REG_OP(ExpandD) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||||
.REQUIRED_ATTR(shape, ListInt) | |||||
.OP_END_FACTORY_REG(ExpandD) | |||||
} // namespace ge | |||||
#endif // OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ |
@@ -0,0 +1,392 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file control_flow_ops.cpp | |||||
* \brief | |||||
*/ | |||||
#include "control_flow_ops.h" | |||||
#include "./util/common_shape_fns.h" | |||||
#include "./util/error_util.h" | |||||
#include "util/util.h" | |||||
namespace ge { | |||||
namespace { | |||||
graphStatus MergeInferImpl(Operator& op) { | |||||
TensorDesc td = op.GetOutputDesc("value_index"); | |||||
TensorDesc td_y = op.GetOutputDesc("y"); | |||||
td.SetShape(ge::Shape()); | |||||
td.SetDataType(DT_INT32); | |||||
auto ret = op.UpdateOutputDesc("value_index", td); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
// check N of "x" >= 1 | |||||
size_t in_num = op.GetInputsSize(); | |||||
if (in_num < 1) { | |||||
string reason = "inputs size[" + std::to_string(in_num) + "] must be greater than or equal to 1"; | |||||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "input", reason); | |||||
return GRAPH_FAILED; | |||||
} else if (in_num == 2) { | |||||
// Check whether this is a loop merge. InferShape order: Enter->Merge->NextIteration,
// so when InferShape runs on the Merge op, the shape & datatype of the NextIteration op are still the defaults.
// Therefore, the shape & datatype of the Merge op should be taken from the Enter op.
auto x0_type = op.GetDynamicInputDesc("x", 0).GetDataType(); | |||||
auto x0_dims = op.GetDynamicInputDesc("x", 0).GetShape().GetDims(); | |||||
bool not_handle_flag0 = (x0_type == DT_FLOAT) && (x0_dims.size() == 0); | |||||
auto x1_type = op.GetDynamicInputDesc("x", 1).GetDataType(); | |||||
auto x1_dims = op.GetDynamicInputDesc("x", 1).GetShape().GetDims(); | |||||
bool not_handle_flag1 = (x1_type == DT_FLOAT) && (x1_dims.size() == 0); | |||||
if ((x0_type != x1_type) && (not_handle_flag0 || not_handle_flag1)) { | |||||
if (not_handle_flag0) { | |||||
td_y.SetShape(ge::Shape(x1_dims)); | |||||
td_y.SetDataType(x1_type); | |||||
} else { | |||||
td_y.SetShape(ge::Shape(x0_dims)); | |||||
td_y.SetDataType(x0_type); | |||||
} | |||||
(void)op.UpdateOutputDesc("y", td_y); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} | |||||
// check "x" be same type | |||||
auto x0_type = op.GetDynamicInputDesc("x", 0).GetDataType(); | |||||
for (size_t i = 1; i < op.GetInputsSize(); i++) { | |||||
auto xi_type = op.GetDynamicInputDesc("x", i).GetDataType(); | |||||
if (xi_type != x0_type) { | |||||
string reason = "x[0]'s dtype[" + std::to_string(x0_type) + "] must be equal to x[" + std::to_string(i) + | |||||
"]'s dtype[" + std::to_string(xi_type) + "]"; | |||||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", reason); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
// infer "y" be unknown shape | |||||
auto x0_dims = op.GetDynamicInputDesc("x", 0).GetShape().GetDims(); | |||||
bool x0_unknown = (x0_dims.size() == 1) && (x0_dims[0] == 0); | |||||
if (x0_unknown) { | |||||
Shape unknown_shape(ge::UNKNOWN_SHAPE); | |||||
td_y.SetShape(unknown_shape); | |||||
td_y.SetDataType(x0_type); | |||||
(void)op.UpdateOutputDesc("y", td_y); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
// find the input with the max size from all inputs, and set its data type/shape to the output
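// std::map is ordered by key, so rbegin() below yields the input with the
// largest byte size; inputs with unknown or non-positive dims are skipped,
// and ties keep the lowest-index input.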
std::map<int64_t, size_t> size_to_index; | |||||
for (size_t i = 0; i < op.GetInputsSize(); i++) { | |||||
auto xi_dims = op.GetDynamicInputDesc("x", i).GetShape().GetDims(); | |||||
bool xi_unknown = (xi_dims.size() == 1) && (xi_dims[0] == 0); | |||||
if (xi_unknown) { | |||||
continue; | |||||
} | |||||
int64_t size = static_cast<int64_t>(GetSizeByDataType(op.GetDynamicInputDesc("x", i).GetDataType())); | |||||
if (size < 0) { | |||||
continue; | |||||
} | |||||
if (!xi_dims.empty()) { | |||||
for (auto& dim : xi_dims) { | |||||
if (dim <= 0) { | |||||
size = -1; | |||||
break; | |||||
} | |||||
if (size != 0 && INT64_MAX / size < dim) { | |||||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dim", "the dim size is overflow"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
size *= dim; | |||||
} | |||||
if (size < 0) { | |||||
continue; | |||||
} | |||||
} | |||||
if (size_to_index.count(size) == 0) { | |||||
size_to_index[size] = i; | |||||
} | |||||
} | |||||
if (size_to_index.empty()) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
auto index = size_to_index.rbegin()->second; | |||||
td_y.SetShape(ge::Shape(op.GetDynamicInputDesc("x", index).GetShape().GetDims())); | |||||
td_y.SetDataType(op.GetDynamicInputDesc("x", index).GetDataType()); | |||||
(void)op.UpdateOutputDesc("y", td_y); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus SwitchInferImpl(Operator& op) { | |||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(op); | |||||
auto data_desc = op_desc->MutableInputDesc("data"); | |||||
auto pred_desc = op_desc->MutableInputDesc("pred"); | |||||
auto output_false_desc = op_desc->MutableOutputDesc("output_false"); | |||||
auto output_true_desc = op_desc->MutableOutputDesc("output_true"); | |||||
std::vector<std::pair<int64_t, int64_t>> data_range; | |||||
data_desc->GetShapeRange(data_range); | |||||
// check "pred" scalar type be bool | |||||
auto pred_dims = pred_desc->GetShape().GetDims(); | |||||
if (pred_dims.size() != 0) { | |||||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "pred dims", "pred should be a scalar"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
DataType pred_type = pred_desc->GetDataType(); | |||||
if (pred_type != DT_BOOL) { | |||||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "pred should be bool type"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
DataType data_type = data_desc->GetDataType(); | |||||
auto data_dims = data_desc->GetShape().GetDims(); | |||||
output_false_desc->SetShapeRange(data_range); | |||||
output_true_desc->SetShapeRange(data_range); | |||||
output_false_desc->SetShape(GeShape(data_dims)); | |||||
output_false_desc->SetOriginShape(GeShape(data_dims)); | |||||
output_true_desc->SetShape(GeShape(data_dims)); | |||||
output_true_desc->SetOriginShape(GeShape(data_dims)); | |||||
output_false_desc->SetDataType(data_type); | |||||
output_true_desc->SetDataType(data_type); | |||||
auto context = op.GetInferenceContext(); | |||||
std::vector<std::vector<ShapeAndType>> in_shapes_and_types = context->GetInputHandleShapesAndTypes(); | |||||
if ((!in_shapes_and_types.empty()) && (!in_shapes_and_types.at(0).empty())) { | |||||
ShapeAndType shape_and_type = in_shapes_and_types.at(0).at(0); | |||||
std::vector<ShapeAndType> grad_handle_shape_and_type; | |||||
grad_handle_shape_and_type.reserve(1); | |||||
grad_handle_shape_and_type.emplace_back(shape_and_type); | |||||
std::vector<std::vector<ShapeAndType>> shapes_and_types(2); | |||||
shapes_and_types[0] = grad_handle_shape_and_type; | |||||
shapes_and_types[1] = grad_handle_shape_and_type; | |||||
context->SetOutputHandleShapesAndTypes(shapes_and_types); | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus EnterInferImpl(Operator& op) { | |||||
auto op_desc = OpDescUtils::GetOpDescFromOperator(op); | |||||
auto input_desc_x = op_desc->MutableInputDesc("x"); | |||||
auto output_desc_y = op_desc->MutableOutputDesc("y"); | |||||
std::vector<std::pair<int64_t, int64_t>> x_range; | |||||
std::vector<std::pair<int64_t, int64_t>> y_range; | |||||
input_desc_x->GetShapeRange(x_range); | |||||
auto input_dims = input_desc_x->MutableShape().GetDims(); | |||||
DataType input_type = input_desc_x->GetDataType(); | |||||
output_desc_y->SetShape(ge::GeShape(input_dims)); | |||||
output_desc_y->SetOriginShape(ge::GeShape(input_dims)); | |||||
output_desc_y->SetDataType(input_type); | |||||
if (!x_range.empty()) { | |||||
output_desc_y->SetShapeRange(x_range); | |||||
} | |||||
auto context = op.GetInferenceContext(); | |||||
std::vector<std::vector<ShapeAndType>> in_shapes_and_types = context->GetInputHandleShapesAndTypes(); | |||||
if ((!in_shapes_and_types.empty()) && (!in_shapes_and_types.at(0).empty())) { | |||||
ShapeAndType shape_and_type = in_shapes_and_types.at(0).at(0); | |||||
std::vector<ShapeAndType> grad_handle_shape_and_type; | |||||
grad_handle_shape_and_type.reserve(1); | |||||
grad_handle_shape_and_type.emplace_back(shape_and_type); | |||||
std::vector<std::vector<ShapeAndType>> shapes_and_types(1); | |||||
shapes_and_types[0] = grad_handle_shape_and_type; | |||||
context->SetOutputHandleShapesAndTypes(shapes_and_types); | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus PassThroughInferImpl(Operator& op, const std::string& in_name, const std::string& out_name) { | |||||
auto input_dims = op.GetInputDesc(in_name).GetShape().GetDims(); | |||||
DataType input_type = op.GetInputDesc(in_name).GetDataType(); | |||||
TensorDesc tensordesc_output = op.GetOutputDesc(out_name); | |||||
tensordesc_output.SetShape(ge::Shape(input_dims)); | |||||
tensordesc_output.SetDataType(input_type); | |||||
(void)op.UpdateOutputDesc(out_name, tensordesc_output); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
graphStatus LoopCondInferImpl(Operator& op) { | |||||
auto input_dims = op.GetInputDesc("x").GetShape().GetDims(); | |||||
if (input_dims.size() != 0) { | |||||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "x dims", "x should be a scalar"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
TensorDesc tensordesc_output = op.GetOutputDesc("y"); | |||||
tensordesc_output.SetShape(ge::Shape(input_dims)); | |||||
DataType input_type = op.GetInputDesc("x").GetDataType(); | |||||
if (input_type != DT_BOOL) { | |||||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "x should be bool type"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
tensordesc_output.SetDataType(input_type); | |||||
(void)op.UpdateOutputDesc("y", tensordesc_output); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} // namespace | |||||
IMPLEMT_INFERFUNC(Merge, MergeInfer) { | |||||
return MergeInferImpl(op); | |||||
} | |||||
INFER_FUNC_REG(Merge, MergeInfer); | |||||
IMPLEMT_INFERFUNC(RefMerge, RefMergeInfer) { | |||||
return MergeInferImpl(op); | |||||
} | |||||
INFER_FUNC_REG(RefMerge, RefMergeInfer); | |||||
IMPLEMT_INFERFUNC(Switch, SwitchInfer) { | |||||
return SwitchInferImpl(op); | |||||
} | |||||
INFER_FUNC_REG(Switch, SwitchInfer); | |||||
IMPLEMT_INFERFUNC(RefSwitch, RefSwitchInfer) { | |||||
return SwitchInferImpl(op); | |||||
} | |||||
INFER_FUNC_REG(RefSwitch, RefSwitchInfer); | |||||
IMPLEMT_INFERFUNC(SwitchN, SwitchNInfer) { | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
INFER_FUNC_REG(SwitchN, SwitchNInfer); | |||||
IMPLEMT_INFERFUNC(Enter, EnterInfer) { | |||||
return EnterInferImpl(op); | |||||
} | |||||
INFER_FUNC_REG(Enter, EnterInfer); | |||||
IMPLEMT_INFERFUNC(RefEnter, RefEnterInfer) { | |||||
return PassThroughInferImpl(op, "x", "y"); | |||||
} | |||||
INFER_FUNC_REG(RefEnter, RefEnterInfer); | |||||
IMPLEMT_INFERFUNC(LoopCond, LoopCondInfer) { | |||||
return LoopCondInferImpl(op); | |||||
} | |||||
INFER_FUNC_REG(LoopCond, LoopCondInfer); | |||||
IMPLEMT_INFERFUNC(NextIteration, NextIterationInfer) { | |||||
return PassThroughInferImpl(op, "x", "y"); | |||||
} | |||||
INFER_FUNC_REG(NextIteration, NextIterationInfer); | |||||
IMPLEMT_INFERFUNC(RefNextIteration, RefNextIterationInfer) { | |||||
return PassThroughInferImpl(op, "x", "y"); | |||||
} | |||||
INFER_FUNC_REG(RefNextIteration, RefNextIterationInfer); | |||||
IMPLEMT_INFERFUNC(Exit, ExitInfer) { | |||||
return PassThroughInferImpl(op, "x", "y"); | |||||
} | |||||
INFER_FUNC_REG(Exit, ExitInfer); | |||||
IMPLEMT_INFERFUNC(RefExit, RefExitInfer) { | |||||
return PassThroughInferImpl(op, "x", "y"); | |||||
} | |||||
INFER_FUNC_REG(RefExit, RefExitInfer); | |||||
// ----------------MapIndex------------------- | |||||
IMPLEMT_VERIFIER(MapIndex, MapIndexVerify) { | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
IMPLEMT_COMMON_INFERFUNC(MapIndexInferShape) { | |||||
OP_LOGI("MapIndex", "infer shape begin---"); | |||||
auto x_shape = op.GetInputDesc("x").GetShape().GetDims(); | |||||
if (x_shape.empty()) { | |||||
OP_LOGE(op.GetName().c_str(), "x_shape is empty"); | |||||
OpsOneInputShapeErrReport(op.GetName().c_str(), "x", "x_shape is empty"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
int64_t x_length = x_shape[0]; | |||||
auto data_seq_shape = op.GetInputDesc("data_seq").GetShape().GetDims(); | |||||
if (data_seq_shape.empty()) { | |||||
OP_LOGE(op.GetName().c_str(), "data_seq_shape is empty"); | |||||
OpsOneInputShapeErrReport(op.GetName().c_str(), "data_seq", "data_seq_shape is empty"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
int64_t data_seq_length = data_seq_shape[0]; | |||||
if (x_length > 8 || x_length == 0) { | |||||
OP_LOGE(op.GetName().c_str(), "the length of x should be less than or equal to 8"); | |||||
OpsOneInputShapeErrReport(op.GetName().c_str(), "x", "the length of x should be less than or equal to 8 and not 0"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (data_seq_length % x_length != 0) { | |||||
OP_LOGE(op.GetName().c_str(), "the length of data_seq must be multiple of the length of x"); | |||||
OpsTwoInputShapeErrReport(op.GetName().c_str(), "data_seq", "x", | |||||
"the length of data_seq must be multiple of the length of x"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (data_seq_length / x_length > 100) { | |||||
OP_LOGE(op.GetName().c_str(), "data_seq_length / x_length should be be less than or equal to 100"); | |||||
OpsTwoInputShapeErrReport(op.GetName().c_str(), "data_seq", "x", | |||||
"data_seq_length / x_length should be be less than or equal to 100"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
auto level_index_shape = op.GetInputDesc("level_index").GetShape().GetDims(); | |||||
if (!level_index_shape.empty()) { | |||||
int64_t level_index_length = level_index_shape[0]; | |||||
if (level_index_length != (data_seq_length / x_length)) { | |||||
OP_LOGE(op.GetName().c_str(), | |||||
"the length of level_index must be equal to " | |||||
"the length of data_seq divided by the length of x"); | |||||
OpsOneInputShapeErrReport(op.GetName().c_str(), "level_index", | |||||
"the length of level_index must be equal to " | |||||
"the length of data_seq divided by the length of x"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
TensorDesc y_desc = op.GetOutputDesc("y"); | |||||
y_desc.SetShape(ge::Shape()); | |||||
y_desc.SetDataType(ge::DT_INT32); | |||||
(void)op.UpdateOutputDesc("y", y_desc); | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
COMMON_INFER_FUNC_REG(MapIndex, MapIndexInferShape); | |||||
VERIFY_FUNC_REG(MapIndex, MapIndexVerify); | |||||
} // namespace ge |
@@ -0,0 +1,407 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file control_flow_ops.h | |||||
* \brief | |||||
*/ | |||||
#ifndef OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_ | |||||
#define OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_ | |||||
#include "graph/operator_reg.h" | |||||
#include "graph/operator.h" | |||||
namespace ge { | |||||
/** | |||||
*@brief Forwards the value of an available tensor from input "x" to output "y". | |||||
* Merge waits for at least one of the input tensors to become available. | |||||
* It is usually combined with Switch to implement branching. | |||||
* Merge forwards the first tensor to become available to output "y", | |||||
* and sets "value_index" the index of the tensor in inputs . \n | |||||
*@par Inputs: | |||||
*x: The input tensors, one of which will become available. | |||||
* Must be one of the following types: float16, float32, float64, int8, | |||||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n | |||||
*@par Outputs: | |||||
*@li y: The available tensor. Has the same type as "x". | |||||
*@li value_index: A scalar of type int32, for the index of the chosen input | |||||
* tensor . \n | |||||
*@see Switch() | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator Merge. | |||||
*/ | |||||
REG_OP(Merge) | |||||
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(value_index, TensorType({DT_INT32})) | |||||
.OP_END_FACTORY_REG(Merge) | |||||
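/*
 * Illustrative dataflow only (not part of the original header): a branch is
 * typically built by routing "data" through Switch and joining both branches
 * with Merge, which reports the taken branch via "value_index":
 *
 *   pred ----+
 *            v
 *   data -> Switch --output_false--> <false-branch ops> --+
 *                  --output_true---> <true-branch ops>  --+--> Merge --> y, value_index
 */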
/** | |||||
*@brief Forwards the value of an available tensor from input "x" to output "y". | |||||
* Merge waits for at least one of the input tensors to become available. | |||||
* It is usually combined with Switch to implement branching. | |||||
* Merge forwards the first tensor to become available to output "y", | |||||
* and sets "value_index" the index of the tensor in inputs . \n | |||||
*@par Inputs: | |||||
*x: The input tensors, one of which will become available. | |||||
* Must be one of the following types: float16, float32, float64, int8, | |||||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n | |||||
*@par Outputs: | |||||
*@li y: The available tensor. Has the same type as "x". | |||||
*@li value_index: A scalar of type int32, for the index of the chosen input | |||||
* tensor . \n | |||||
*@see Switch() | Merge() | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator RefMerge. | |||||
*/ | |||||
REG_OP(RefMerge) | |||||
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(value_index, TensorType({DT_INT32})) | |||||
.OP_END_FACTORY_REG(RefMerge) | |||||
/** | |||||
*@brief Forwards "data" to the output port determined by "pred". | |||||
* If "pred" is "true", the data input is forwarded to "output_true". | |||||
* Otherwise, the data is forwarded to "output_false" . \n | |||||
*@par Inputs: | |||||
*@li data: The tensor to be forwarded.
* Must be one of the following types: float16, float32, float64, | |||||
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool. | |||||
*@li pred: A boolean scalar. The output port that will receive data . \n | |||||
*@par Outputs: | |||||
*@li output_false: If "pred" is "false", data will be forwarded to this output. | |||||
* Has the same type as "data". | |||||
*@li output_true: If "pred" is "true", data will be forwarded to this output. | |||||
* Has the same type as "data" . \n | |||||
*@see Merge() | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator Switch. | |||||
*/ | |||||
REG_OP(Switch) | |||||
.INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.INPUT(pred, TensorType({DT_BOOL})) | |||||
.OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OP_END_FACTORY_REG(Switch) | |||||
/** | |||||
*@brief Forwards "data" to the output port determined by "pred". | |||||
* If "pred" is "true", the data input is forwarded to "output_true". | |||||
* Otherwise, the data is forwarded to "output_false" . \n | |||||
*@par Inputs: | |||||
*@li data: The ref tensor to be forwarded. | |||||
* Must be one of the following types: float16, float32, float64, | |||||
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool. | |||||
*@li pred: A boolean scalar. The output port that will receive data . \n | |||||
*@par Outputs: | |||||
*@li output_false: If "pred" is "false", data will be forwarded to this output. | |||||
* Has the same type as "data". | |||||
*@li output_true: If "pred" is "true", data will be forwarded to this output. | |||||
* Has the same type as "data" . \n | |||||
*@see Merge() | Switch() | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator RefSwitch. | |||||
*/ | |||||
REG_OP(RefSwitch) | |||||
.INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.INPUT(pred, TensorType({DT_BOOL})) | |||||
.OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OP_END_FACTORY_REG(RefSwitch) | |||||
/** | |||||
*@brief Forwards "data" to the output port determined by "pred_value" . \n | |||||
*@par Inputs: | |||||
*@li data: The tensor to be forwarded.
* Must be one of the following types: float16, float32, float64,
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred_value: An int64 tensor which determines the output port that will receive data . \n
*@par Outputs: | |||||
*output: The dynamic output tensors; "data" is forwarded to the output
* selected by "pred_value". Has the same type as "data".
*/ | |||||
REG_OP(SwitchN) | |||||
.INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.INPUT(pred_value, TensorType({DT_INT64})) | |||||
.DYNAMIC_OUTPUT(output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OP_END_FACTORY_REG(SwitchN) | |||||
/** | |||||
*@brief Creates or finds a child frame, and makes "x" available to the child | |||||
* frame. This op is used together with Exit to create loops in the graph. | |||||
* The Executor uses the unique "frame_name" to identify frames. | |||||
* If "is_constant" is "true", output "y" is a constant in the child | |||||
* frame; otherwise it may be changed in the child frame . \n | |||||
*@par Inputs: | |||||
*x: The tensor to be made available to the child frame. | |||||
* Must be one of the following types: float16, float32, float64, int8, | |||||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
*@par Attributes: | |||||
*@li frame_name: A required string. The name of the child frame. | |||||
*@li is_constant: A required bool. If true, the output is constant in | |||||
* the child frame . \n | |||||
*@par Outputs: | |||||
*y: A Tensor. Has the same type as "x" . \n | |||||
*@see Exit() | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator Enter. | |||||
*/ | |||||
REG_OP(Enter) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.REQUIRED_ATTR(frame_name, String) | |||||
.REQUIRED_ATTR(is_constant, Bool) | |||||
.OP_END_FACTORY_REG(Enter) | |||||
/** | |||||
*@brief Creates or finds a child frame, and makes "x" available to the child | |||||
* frame. This op is used together with Exit to create loops in the graph. | |||||
* The Executor uses the unique "frame_name" to identify frames. | |||||
* If "is_constant" is "true", output "y" is a constant in the child | |||||
* frame; otherwise it may be changed in the child frame . \n | |||||
*@par Inputs: | |||||
*x: The tensor to be made available to the child frame. | |||||
* Must be one of the following types: float16, float32, float64, int8, | |||||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
*@par Attributes: | |||||
*@li frame_name: A required string. The name of the child frame. | |||||
*@li is_constant: A required bool. If true, the output is constant in | |||||
* the child frame . \n | |||||
*@par Outputs: | |||||
*y: A tensor. Has the same type as "x" . \n | |||||
*@see Exit() | Enter() | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator RefEnter. | |||||
*/ | |||||
REG_OP(RefEnter) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.REQUIRED_ATTR(frame_name, String) | |||||
.REQUIRED_ATTR(is_constant, Bool) | |||||
.OP_END_FACTORY_REG(RefEnter) | |||||
/** | |||||
*@brief Forwards the input to the output. This op represents the loop | |||||
* termination condition . \n | |||||
*@par Inputs: | |||||
*x: A boolean scalar. The condition of the Switch op . \n | |||||
*@par Outputs: | |||||
*y: The tensor "x" . \n | |||||
*@see Switch() | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator LoopCond. | |||||
*/ | |||||
REG_OP(LoopCond) | |||||
.INPUT(x, TensorType({DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_BOOL})) | |||||
.OP_END_FACTORY_REG(LoopCond) | |||||
/** | |||||
*@brief Makes the input available to the next iteration . \n | |||||
*@par Inputs: | |||||
*x: The tensor to be made available to the next iteration. | |||||
* Must be one of the following types: float16, float32, float64, int8, | |||||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
*@par Outputs: | |||||
*y: A Tensor. Has the same type as "x" . \n | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator NextIteration. | |||||
*/ | |||||
REG_OP(NextIteration) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OP_END_FACTORY_REG(NextIteration) | |||||
/** | |||||
*@brief Makes the input available to the next iteration . \n | |||||
*@par Inputs: | |||||
*x: The tensor to be made available to the next iteration. | |||||
* Must be one of the following types: float16, float32, float64, int8, | |||||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
*@par Outputs: | |||||
*y: A tensor. Has the same type as "x" . \n | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator RefNextIteration. | |||||
*/ | |||||
REG_OP(RefNextIteration) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OP_END_FACTORY_REG(RefNextIteration) | |||||
/** | |||||
*@brief Exits the current frame to its parent frame . \n | |||||
*@par Inputs: | |||||
*x: The tensor to be made available to the parent frame. | |||||
* Must be one of the following types: float16, float32, float64, int8, | |||||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
*@par Outputs: | |||||
*y: A Tensor. Has the same type as "x" . \n | |||||
*@see Enter() | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator Exit. | |||||
*/ | |||||
REG_OP(Exit) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OP_END_FACTORY_REG(Exit) | |||||
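/*
 * Illustrative wiring only (not part of the original header). A classic
 * TensorFlow-style while loop combines these ops roughly as:
 *
 *   x --> Enter --> Merge --> LoopCond --> Switch --output_false--> Exit --> result
 *                    ^                        |
 *                    |                        +--output_true--> <loop body>
 *                    +------ NextIteration <--------------------------+
 *
 * Enter moves "x" into the loop frame, Merge joins the initial value with the
 * previous iteration's value, LoopCond and Switch decide whether to run
 * another iteration, and Exit returns the final value to the parent frame.
 */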
/** | |||||
*@brief Exits the current frame to its parent frame . \n | |||||
*@par Inputs: | |||||
*x: The tensor to be made available to the parent frame. | |||||
* Must be one of the following types: float16, float32, float64, int8, | |||||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||||
*@par Outputs: | |||||
*y: A tensor. Has the same type as "x" . \n | |||||
*@see Enter() | Exit() | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator RefExit. | |||||
*/ | |||||
REG_OP(RefExit) | |||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||||
DT_UINT64, DT_BOOL})) | |||||
.OP_END_FACTORY_REG(RefExit) | |||||
/** | |||||
*@brief Only useful as a placeholder for control edges. | |||||
* It is similar to a no-op that always produces a live control output | |||||
* even when some control inputs are dead . \n | |||||
*@par Third-party framework compatibility | |||||
*@Compatible with the TensorFlow operator ControlTrigger. | |||||
*/ | |||||
REG_OP(ControlTrigger) | |||||
.OP_END_FACTORY_REG(ControlTrigger) | |||||
/** | |||||
*@brief Returns index of shape in the map. | |||||
*@par Inputs: | |||||
* Three inputs, including: | |||||
*@li x: One dimensional tensor of type int32, specifying queried shape, max size is 8.
*@li data_seq: One dimensional tensor of type int32, specifying the mapped table to be queried.
*@li level_index: One dimensional tensor of type int32, specifying secondary index. \n
*@par Outputs: | |||||
*@li y: A Tensor with shape [batch, 8], of type int32, specifying index of shape in the map. | |||||
*@par Third-party framework compatibility | |||||
* It is a custom operator. It has no corresponding operator in Caffe. | |||||
*/ | |||||
REG_OP(MapIndex) | |||||
.INPUT(x, TensorType({DT_INT32})) | |||||
.INPUT(data_seq, TensorType({DT_INT32})) | |||||
.OPTIONAL_INPUT(level_index, TensorType({DT_INT32})) | |||||
.OUTPUT(y, TensorType({DT_INT32})) | |||||
.OP_END_FACTORY_REG(MapIndex) | |||||
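/*
 * Worked example of the shape constraints checked by MapIndexInferShape
 * (illustrative values, not from the original source): with x of shape [4]
 * and data_seq of shape [40], data_seq_length / x_length = 10, so data_seq
 * encodes 10 candidate shapes of rank 4, and an optional level_index must
 * then have shape [10]. x_length must stay in [1, 8] and the candidate
 * count in [1, 100].
 */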
} // namespace ge | |||||
#endif // OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_ |
@@ -0,0 +1,234 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file array_ops_shape_fns.cpp | |||||
* \brief | |||||
*/ | |||||
#include "array_ops_shape_fns.h" | |||||
#include "graph/types.h" | |||||
#include "op_log.h" | |||||
#include "error_util.h" | |||||
#include "common_shape_fns.h" | |||||
#include "axis_util.h" | |||||
namespace ge { | |||||
static graphStatus PadKnown(Operator& op, const Tensor& paddings_tensor, const int64_t input_dim_num) { | |||||
TensorDesc paddings_tensor_desc = paddings_tensor.GetTensorDesc(); | |||||
DataType data_type = paddings_tensor_desc.GetDataType(); | |||||
std::vector<int64_t> data; | |||||
  // every dim has 2 elements
int64_t element_num = input_dim_num * 2; | |||||
data.reserve(element_num); | |||||
if (data_type == DT_INT32) { | |||||
const int32_t* paddings_data = reinterpret_cast<const int32_t*>(paddings_tensor.GetData()); | |||||
CHECK(paddings_tensor.GetSize() / sizeof(int32_t) < element_num, | |||||
OP_LOGE(op.GetName().c_str(), "invalid padding data."), return GRAPH_FAILED); | |||||
for (int64_t i = 0; i < element_num; ++i) { | |||||
data.push_back(static_cast<int64_t>(paddings_data[i])); | |||||
} | |||||
} else if (data_type == DT_INT64) { | |||||
const int64_t* paddings_data = reinterpret_cast<const int64_t*>(paddings_tensor.GetData()); | |||||
CHECK(paddings_tensor.GetSize() / sizeof(int64_t) < element_num, | |||||
OP_LOGE(op.GetName().c_str(), "invalid padding data."), return GRAPH_FAILED); | |||||
for (int64_t i = 0; i < element_num; ++i) { | |||||
data.push_back(paddings_data[i]); | |||||
} | |||||
} else { | |||||
string err_msg = ConcatString("paddings data type invalid, ", "should be DT_INT32 or DT_INT64"); | |||||
InferShapeOtherErrReport(op.GetName(), err_msg); | |||||
OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
auto dims = op.GetInputDesc(0).GetShape().GetDims(); | |||||
std::vector<int64_t> output_dims(input_dim_num, UNKNOWN_DIM); | |||||
if (dims != UNKNOWN_SHAPE) { | |||||
output_dims.assign(dims.begin(), dims.end()); | |||||
} | |||||
for (size_t i = 0; i < data.size(); i += 2) { | |||||
if ((data[i] < 0) || (data[i + 1] < 0)) { | |||||
std::string err_msg = ConcatString("paddings", DebugString(data), " must be non-negative"); | |||||
InferShapeOtherErrReport(op.GetName(), err_msg); | |||||
OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
graphStatus status = Add(output_dims[i / 2], data[i] + data[i + 1], output_dims[i / 2]); | |||||
if (status != GRAPH_SUCCESS) { | |||||
std::string err_msg = ConcatString("the sum input[0] shape", DebugString(dims), " and input[1] value", | |||||
DebugString(data), " must be non-negative"); | |||||
OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
auto output_desc = op.GetOutputDesc("y"); | |||||
output_desc.SetShape(Shape(output_dims)); | |||||
return op.UpdateOutputDesc("y", output_desc); | |||||
} | |||||
graphStatus PadShapeFn(Operator& op) { | |||||
Shape paddings; | |||||
int64_t input_dim_num; | |||||
graphStatus status = WithRank(op.GetInputDesc(1), 2, paddings, op.GetName().c_str()); | |||||
if (status != GRAPH_SUCCESS) { | |||||
ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), "2D"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
status = WithValue(paddings.GetDim(1), 2, input_dim_num, op.GetName().c_str()); | |||||
if (status != GRAPH_SUCCESS) { | |||||
ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), | |||||
ConcatString(2, " of dim[1]")); | |||||
return GRAPH_FAILED; | |||||
} | |||||
Shape input; | |||||
int64_t dim0 = paddings.GetDim(0); | |||||
if (dim0 != UNKNOWN_DIM) { | |||||
status = WithRank(op.GetInputDesc(0), dim0, input, op.GetName().c_str()); | |||||
if (status != GRAPH_SUCCESS) { | |||||
ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(dim0, "D")); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} else if (op.GetInputDesc(0).GetShape().GetDim(0) != 0) { | |||||
status = WithValue(dim0, op.GetInputDesc(0).GetShape().GetDimNum(), input_dim_num, op.GetName().c_str()); | |||||
if (status != GRAPH_SUCCESS) { | |||||
ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(dim0, "D")); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
TensorDesc output_desc = op.GetOutputDesc("y"); | |||||
Tensor paddings_tensor; | |||||
status = op.GetInputConstData("paddings", paddings_tensor); | |||||
if (status != GRAPH_SUCCESS) { | |||||
if (dim0 != UNKNOWN_DIM) { | |||||
std::vector<int64_t> output_shape(dim0, UNKNOWN_DIM); | |||||
output_desc.SetShape(Shape(output_shape)); | |||||
} else { | |||||
output_desc.SetShape(Shape(UNKNOWN_SHAPE)); | |||||
} | |||||
return op.UpdateOutputDesc("y", output_desc); | |||||
} | |||||
input_dim_num = paddings_tensor.GetTensorDesc().GetShape().GetDim(0); | |||||
status = WithRank(op.GetInputDesc(0), input_dim_num, input, op.GetName().c_str()); | |||||
  if (status != GRAPH_SUCCESS) {
    OP_LOGE(op.GetName().c_str(), "WithRank failed");
    return GRAPH_FAILED;
  }
  status = WithValue(dim0, input_dim_num, dim0, op.GetName().c_str());
  if (status != GRAPH_SUCCESS) {
    OP_LOGE(op.GetName().c_str(), "WithValue failed");
    return GRAPH_FAILED;
  }
return PadKnown(op, paddings_tensor, input_dim_num); | |||||
} | |||||
static graphStatus CalcPadGradOutDims(const Shape& input_shape, const Tensor& paddings_tensor, | |||||
std::vector<int64_t>& output_dims, const char* op_name) { | |||||
graphStatus status; | |||||
size_t input_rank = input_shape.GetDimNum(); | |||||
if (output_dims.size() < input_rank) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
DataType padding_type = paddings_tensor.GetTensorDesc().GetDataType(); | |||||
if (padding_type == DT_INT32) { | |||||
const int32_t* paddings_data = reinterpret_cast<const int32_t*>(paddings_tensor.GetData()); | |||||
    // paddings carries 2 values per input dim, so 2 * input_rank elements are read below
    CHECK(paddings_tensor.GetSize() / sizeof(int32_t) < input_rank * 2,
          OP_LOGE(op_name, "invalid padding data."), return GRAPH_FAILED);
for (size_t i = 0; i < input_rank; ++i) { | |||||
const int64_t pad0 = static_cast<int64_t>(paddings_data[2 * i]); | |||||
const int64_t pad1 = static_cast<int64_t>(paddings_data[(2 * i) + 1]); | |||||
if ((pad0 < 0) || (pad1 < 0)) { | |||||
OP_LOGE(op_name, "Paddings must be non-negative, pad0= %lld, pad1=%lld.", pad0, pad1); | |||||
return GRAPH_FAILED; | |||||
} | |||||
status = Subtract(input_shape.GetDim(i), pad0 + pad1, output_dims[i], op_name); | |||||
if (status != GRAPH_SUCCESS) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
} else if (padding_type == DT_INT64) { | |||||
const int64_t* paddings_data = reinterpret_cast<const int64_t*>(paddings_tensor.GetData()); | |||||
    // paddings carries 2 values per input dim, so 2 * input_rank elements are read below
    CHECK(paddings_tensor.GetSize() / sizeof(int64_t) < input_rank * 2,
          OP_LOGE(op_name, "invalid padding data."), return GRAPH_FAILED);
for (size_t i = 0; i < input_rank; ++i) { | |||||
const int64_t pad0 = paddings_data[2 * i]; | |||||
const int64_t pad1 = paddings_data[(2 * i) + 1]; | |||||
if ((pad0 < 0) || (pad1 < 0)) { | |||||
OP_LOGE(op_name, "Paddings must be non-negative, pad0=%lld, pad1=%lld.", pad0, pad1); | |||||
return GRAPH_FAILED; | |||||
} | |||||
status = Subtract(input_shape.GetDim(i), pad0 + pad1, output_dims[i], op_name); | |||||
if (status != GRAPH_SUCCESS) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | |||||
} else { | |||||
OP_LOGE(op_name, "Data type invalid, should be DT_INT32 or DT_INT64"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
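// Worked example (illustrative values, not from the original source): for
// input_shape = [5, 7] and paddings = [[1, 1], [2, 2]], CalcPadGradOutDims
// yields output_dims[0] = 5 - (1 + 1) = 3 and output_dims[1] = 7 - (2 + 2) = 3,
// i.e. PadGrad strips the padding that Pad previously added.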
graphStatus PadGradShapeFn(Operator& op) { | |||||
Shape paddings; | |||||
graphStatus status = WithRank(op.GetInputDesc(1), 2, paddings, op.GetName().c_str()); | |||||
if (status != GRAPH_SUCCESS) { | |||||
ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), "2D"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
int64_t input_rank = paddings.GetDim(0); | |||||
TensorDesc output_desc = op.GetOutputDesc("y"); | |||||
output_desc.SetDataType(op.GetInputDesc(0).GetDataType()); | |||||
if (input_rank == UNKNOWN_DIM) { | |||||
OP_LOGE(op.GetName().c_str(), "paddings inputShape of 0 dims is unknown, set out shape unknown."); | |||||
output_desc.SetShape(Shape(UNKNOWN_SHAPE)); | |||||
return op.UpdateOutputDesc("y", output_desc); | |||||
} | |||||
Shape input_shape; | |||||
if (WithRank(op.GetInputDesc(0), input_rank, input_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { | |||||
ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(input_rank)); | |||||
return GRAPH_FAILED; | |||||
} | |||||
Shape check_shape({input_rank, 2}); | |||||
  if (Merge(paddings, check_shape, paddings, op.GetName().c_str()) != GRAPH_SUCCESS) {
    string err_msg = ConcatString("merge input[1] shape", DebugString(paddings.GetDims()), " and shape",
                                  DebugString(check_shape.GetDims()), " failed");
InferShapeOtherErrReport(op.GetName(), err_msg); | |||||
OP_LOGE(op.GetName().c_str(), "Input dimension mismatch, inputRank=%lld.", input_rank); | |||||
return GRAPH_FAILED; | |||||
} | |||||
Tensor paddings_tensor; | |||||
if (op.GetInputConstData("paddings", paddings_tensor) != GRAPH_SUCCESS) { | |||||
    std::vector<int64_t> unknown_dim_vec(input_rank, UNKNOWN_DIM);
    OP_LOGE(op.GetName().c_str(), "failed to get paddings const data, set output shape unknown.");
    output_desc.SetShape(Shape(unknown_dim_vec));
return op.UpdateOutputDesc("y", output_desc); | |||||
} | |||||
std::vector<int64_t> output_dims(input_rank); | |||||
auto result = CalcPadGradOutDims(input_shape, paddings_tensor, output_dims, op.GetName().c_str()); | |||||
if (result != GRAPH_SUCCESS) { | |||||
string err_msg = ConcatString("calculate out dims failed,", "please check the validity of input and attribute"); | |||||
InferShapeOtherErrReport(op.GetName(), err_msg); | |||||
OP_LOGE(op.GetName().c_str(), "Calculation PadGrad out dimensions failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
output_desc.SetShape(Shape(output_dims)); | |||||
return op.UpdateOutputDesc("y", output_desc); | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,42 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file array_ops_shape_fns.h | |||||
* \brief | |||||
*/ | |||||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ | |||||
#define OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ | |||||
#include "graph/operator.h" | |||||
namespace ge { | |||||
/**
 * infer pad op shape
 * @param op Operator which needs infershape
 * @return status indicating whether infershape succeeded
 */
graphStatus PadShapeFn(Operator& op);
/**
 * infer pad grad op shape
 * @param op Operator which needs infershape
 * @return status indicating whether infershape succeeded
 */
graphStatus PadGradShapeFn(Operator& op);
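/*
 * Worked example (illustrative values, not from the original source): for an
 * input of shape [2, 3] and constant paddings [[1, 1], [2, 2]], PadShapeFn
 * infers y shape [2 + 1 + 1, 3 + 2 + 2] = [4, 7]; PadGradShapeFn inverts
 * this, mapping [4, 7] back to [2, 3].
 */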
} // namespace ge | |||||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ |
@@ -0,0 +1,195 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file axis_util.cpp | |||||
* \brief get the axis value | |||||
*/ | |||||
#include "axis_util.h" | |||||
#include "framework/omg/omg_inner_types.h" | |||||
#include "framework/common/types.h" | |||||
namespace ge { | |||||
AxisUtil::AxisUtil() { | |||||
getAxisValueFuncMap = {{FORMAT_NCHW, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNCHW)}, | |||||
{FORMAT_NHWC, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNHWC)}, | |||||
{FORMAT_NC1HWC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNC1HWC0)}, | |||||
{FORMAT_HWCN, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByHWCN)}, | |||||
{FORMAT_ND, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByND)}, | |||||
{FORMAT_C1HWNCoC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByC1HWNCoC0)}}; | |||||
} | |||||
int64_t DivisionCeiling(int64_t dividend, int64_t divisor) { | |||||
if (divisor == 0) { | |||||
return 0; | |||||
} else { | |||||
return (dividend + divisor - 1) / divisor; | |||||
} | |||||
} | |||||
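// Example (illustrative): DivisionCeiling(17, 16) == 2 and
// DivisionCeiling(32, 16) == 2, which is how axisValue[AXIS_C1] =
// ceil(C / c0) is computed in the format handlers below; a zero divisor
// yields 0 by design.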
bool AxisUtil::GetAxisValueByOriginFormat(const Format& format, const vector<int64_t>& dimVec, const uint32_t& c0, | |||||
vector<int64_t>& axisValue, vector<int64_t>& ndValue) { | |||||
auto iterGetAxisFunc = getAxisValueFuncMap.find(format); | |||||
if (iterGetAxisFunc == getAxisValueFuncMap.end()) { | |||||
LOG_INFO("Can not get axis value of old format %u!", format); | |||||
return false; | |||||
} | |||||
GetAxisValueInfoByFormatPtr getAxisFunc = iterGetAxisFunc->second; | |||||
CHECK_NOTNULL(getAxisFunc); | |||||
return (*getAxisFunc)(dimVec, c0, axisValue, ndValue); | |||||
} | |||||
bool AxisUtil::HasAxisValueFunc(const Format& format) { | |||||
auto iterGetAxisFunc = getAxisValueFuncMap.find(format); | |||||
if (iterGetAxisFunc == getAxisValueFuncMap.end()) { | |||||
LOG_INFO("Can not get axis value of format %u!", format); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
bool AxisUtil::CheckParams(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||||
vector<int64_t>& ndValue) { | |||||
ndValue = originalDimVec; | |||||
auto dimSize = originalDimVec.size(); | |||||
if (dimSize < ge::DIM_DEFAULT_SIZE) { | |||||
    /* Before this function, the caller should have called PadDimensionTo4. */
LOG_INFO("Dimension size %zu is invalid.", dimSize); | |||||
return false; | |||||
} | |||||
if (c0 == 0) { | |||||
LOG_ERROR("[ERROR]c0 is zero!"); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByND(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||||
vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
ndValue = originalDimVec; | |||||
/* To differentiate the input datatype of int8 and others */ | |||||
axisValue[AXIS_C0] = c0; | |||||
if (originalDimVec.size() == NCHW_DIMENSION_NUM) { | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||||
axisValue[AXIS_Co] = c0; | |||||
} | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByNCHW(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||||
vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
/* C0 Must be set for case ND or 2D-NCHW to NZ */ | |||||
axisValue[AXIS_C0] = c0; | |||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||||
return false); | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||||
axisValue[AXIS_Co] = c0; | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByNHWC(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||||
vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
/* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||||
axisValue[AXIS_C0] = c0; | |||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||||
return false); | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_NHWC_DIM_N]; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_NHWC_DIM_C]; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_NHWC_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_NHWC_DIM_W]; | |||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NHWC_DIM_C], (int64_t)c0); | |||||
axisValue[AXIS_Co] = c0; | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByNC1HWC0(const vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
vector<int64_t>& axisValue, vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||||
return false); | |||||
auto dimSize = originalDimVec.size(); | |||||
if (dimSize == ge::DIM_DEFAULT_SIZE + 1) { | |||||
axisValue[AXIS_C1] = originalDimVec[AXIS_NC1HWC0_DIM_C1]; | |||||
axisValue[AXIS_C0] = originalDimVec[AXIS_NC1HWC0_DIM_C0]; | |||||
axisValue[AXIS_C] = axisValue[AXIS_C1] * axisValue[AXIS_C0]; | |||||
} else { | |||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||||
axisValue[AXIS_C0] = c0; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||||
} | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByHWCN(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||||
vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
/* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||||
axisValue[AXIS_C0] = c0; | |||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||||
return false); | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_HWCN_DIM_N]; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_HWCN_DIM_C]; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_HWCN_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_HWCN_DIM_W]; | |||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_HWCN_DIM_C], (int64_t)c0); | |||||
axisValue[AXIS_Co] = c0; | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByC1HWNCoC0(const vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
vector<int64_t>& axisValue, vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||||
/* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||||
axisValue[AXIS_C0] = c0; | |||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||||
return false); | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_C1HWNCoC0_DIM_N]; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1] * c0; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_C1HWNCoC0_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_C1HWNCoC0_DIM_W]; | |||||
axisValue[AXIS_C1] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1]; | |||||
axisValue[AXIS_Co] = originalDimVec[AXIS_C1HWNCoC0_DIM_Co]; | |||||
return true; | |||||
} | |||||
} // namespace ge
@@ -0,0 +1,144 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file axis_util.h | |||||
* \brief get the axis value | |||||
*/ | |||||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_ | |||||
#define OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_ | |||||
#include <map>
#include <memory>
#include <functional> | |||||
#include <vector> | |||||
#include "framework/omg/omg_inner_types.h" | |||||
#include "operator.h" | |||||
#include "graph/operator_reg.h" | |||||
#include "op_log.h" | |||||
#define LOG_ERROR(format, args...) printf(format "\n", ##args)
#define LOG_INFO(format, args...) printf(format "\n", ##args)
namespace ge { | |||||
const uint32_t NCHW_DIMENSION_NUM = 4; | |||||
const int32_t AXIS_NCHW_DIM_N = 0; | |||||
const int32_t AXIS_NCHW_DIM_C = 1; | |||||
const int32_t AXIS_NCHW_DIM_H = 2; | |||||
const int32_t AXIS_NCHW_DIM_W = 3; | |||||
const int32_t AXIS_NHWC_DIM_N = 0; | |||||
const int32_t AXIS_NHWC_DIM_H = 1; | |||||
const int32_t AXIS_NHWC_DIM_W = 2; | |||||
const int32_t AXIS_NHWC_DIM_C = 3; | |||||
const int32_t AXIS_NC1HWC0_DIM_N = 0; | |||||
const int32_t AXIS_NC1HWC0_DIM_C1 = 1; | |||||
const int32_t AXIS_NC1HWC0_DIM_C0 = 4; | |||||
const int32_t AXIS_NC1HWC0_DIM_H = 2; | |||||
const int32_t AXIS_NC1HWC0_DIM_W = 3; | |||||
const int32_t AXIS_HWCN_DIM_H = 0; | |||||
const int32_t AXIS_HWCN_DIM_W = 1; | |||||
const int32_t AXIS_HWCN_DIM_C = 2; | |||||
const int32_t AXIS_HWCN_DIM_N = 3; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_H = 1; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_W = 2; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_N = 3; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_Co = 4; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_C0 = 5; | |||||
#define CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if ((val) == nullptr) { \ | |||||
LOG_ERROR("[ERROR]Parameter[%s] must not be null.", #val); \ | |||||
return false; \ | |||||
} \ | |||||
} while (0) | |||||
#define CHECK(cond, log_func, return_expr) \ | |||||
do { \ | |||||
if (cond) { \ | |||||
log_func; \ | |||||
return_expr; \ | |||||
} \ | |||||
} while (0) | |||||
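/*
 * Illustrative usage of the macros above (not from the original source):
 *
 *   bool Demo(const int64_t* dims) {
 *     CHECK_NOTNULL(dims);  // logs and returns false on nullptr
 *     CHECK(dims[0] <= 0, LOG_ERROR("[ERROR]bad dim!"), return false);
 *     return true;
 *   }
 */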
enum AxisValueType { | |||||
AXIS_N = 0, | |||||
AXIS_C = 1, | |||||
AXIS_H = 2, | |||||
AXIS_W = 3, | |||||
AXIS_C1 = 4, | |||||
AXIS_C0 = 5, | |||||
AXIS_Co = 6, | |||||
AXIS_D = 7, | |||||
AXIS_BOTTOM = 8 | |||||
}; | |||||
int64_t DivisionCeiling(int64_t dividend, int64_t divisor); | |||||
/* Axis value is arranged as {N,C,H,W,C1,C0,...} */ | |||||
/* The first parameter is old shape's dimension, | |||||
* second is c0 and third is axis value. */ | |||||
using GetAxisValueInfoByFormat = | |||||
std::function<bool(const std::vector<int64_t>&, const uint32_t&, std::vector<int64_t>&, std::vector<int64_t>&)>; | |||||
using GetAxisValueInfoByFormatPtr = std::shared_ptr<GetAxisValueInfoByFormat>; | |||||
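/*
 * Worked example (illustrative values, not from the original source): for an
 * NCHW shape {8, 17, 5, 5} with c0 = 16, GetAxisValueByNCHW fills axisValue
 * as {N=8, C=17, H=5, W=5, C1=DivisionCeiling(17, 16)=2, C0=16, Co=16},
 * matching the {N,C,H,W,C1,C0,...} layout described above.
 */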
class AxisUtil { | |||||
public: | |||||
AxisUtil(); | |||||
  ~AxisUtil() {}
bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector<int64_t>& dimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
bool HasAxisValueFunc(const ge::Format& format); | |||||
private: | |||||
static bool CheckParams(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByNCHW(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByNHWC(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByNC1HWC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByFz(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByHWCN(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByND(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByC1HWNCoC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
/* map of GetAxisValueInfoByFormat, get axis value by different original | |||||
* formats. */ | |||||
std::map<ge::Format, GetAxisValueInfoByFormatPtr> getAxisValueFuncMap; | |||||
}; | |||||
} // namespace ge | |||||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_ |
@@ -0,0 +1,417 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file common_shape_fns.h | |||||
* \brief | |||||
*/ | |||||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ | |||||
#define OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ | |||||
#include <string> | |||||
#include <vector> | |||||
#include "graph/tensor.h" | |||||
#include "graph/operator.h" | |||||
#include "graph/op_desc.h" | |||||
#include "graph/ge_tensor.h" | |||||
#include "error_code.h" | |||||
namespace ge { | |||||
/**
 * Check whether Shape's rank is at least rank
 * @param tensor Input tensor
 * @param rank expected minimum rank of the shape
 * @param out Output Shape
 * @return status indicating whether the shape condition is satisfied
 */
graphStatus WithRankAtLeast(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name);
/**
 * Check whether Shape's rank is at least rank
 * @param tensorDesc Input tensor desc
 * @param rank expected minimum rank of the shape
 * @param out_shape Output GeShape
 * @return status indicating whether the shape condition is satisfied
 */
graphStatus WithRankAtLeast(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape);
/**
 * Check whether Shape's rank is equal to rank
 * @param tensor Input tensor
 * @param rank expected rank of the shape
 * @param out Output Shape
 * @return status indicating whether the shape condition is satisfied
 */
graphStatus WithRank(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name);
/**
 * Check whether Shape's rank is equal to rank
 * @param tensorDesc Input tensor desc
 * @param rank expected rank of the shape
 * @param out_shape Output GeShape
 * @return status indicating whether the shape condition is satisfied
 */
graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape);
/**
 * Check whether Shape's rank is equal to rank
 * @param tensorDesc Input tensor desc
 * @param rank expected rank of the shape
 * @param out_shape Output Shape
 * @return status indicating whether the shape condition is satisfied
 */
graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, Shape& out_shape);
/**
 * Check whether dim is equal to value
 * @param dim Input dim
 * @param value expected value of dim
 * @param out Output dim
 * @return status indicating whether dim is equal to value
 */
graphStatus WithValue(int64_t dim, int64_t value, int64_t& out, const char* op_name);
/**
 * Merge two dims of Shape
 * @param dim1 first dim val
 * @param dim2 second dim val
 * @param out merged dim val
 * @return status indicating whether this operation succeeded
 */
graphStatus Merge(int64_t dim1, int64_t dim2, int64_t& out);
/**
 * Merge two shapes
 * @param s0 first shape val
 * @param s1 second shape val
 * @param out merged shape val
 * @return status indicating whether this operation succeeded
 */
graphStatus Merge(const Shape& s0, const Shape& s1, Shape& out, const char* op_name);
/**
 * Merge two shapes
 * @param s0 first GeShape val
 * @param s1 second GeShape val
 * @param out merged GeShape val
 * @return status indicating whether this operation succeeded
 */
graphStatus Merge(const GeShape& s0, const GeShape& s1, GeShape& out, const char* op_name);
/**
 * Replace one dim in a given shape
 * @param s original shape
 * @param dim_index_in dim index
 * @param new_dim new dim value
 * @param out new shape
 * @return status indicating whether this operation succeeded
 */
graphStatus ReplaceDim(const Shape& s, int64_t dim_index_in, int64_t new_dim, Shape& out, const char* op_name);
/**
 * Replace one dim in a given shape
 * @param s original shape
 * @param dim_index_in dim index
 * @param new_dim new dim value
 * @param out new shape
 * @return status indicating whether this operation succeeded
 */
graphStatus ReplaceDim(const GeShape& s, int64_t dim_index_in, int64_t new_dim, GeShape& out, const char* op_name);
/**
 * Check if it satisfies 0 <= index < limit
 * @param index first input
 * @param limit second input
 * @return true if the bounds check passes
 */
template <typename Ta, typename Tb>
bool FastBoundsCheck(const Ta index, const Tb limit);
/**
 * Add two dims
 * @param dim1 first dim val
 * @param dim2 second dim val
 * @param out sum dim val
 * @return status indicating whether this operation succeeded
 */
graphStatus Add(int64_t dim1, int64_t dim2, int64_t& out);
/**
 * Subtract two dims
 * @param dim1 first dim val
 * @param dim2 second dim val
 * @param out difference dim val
 * @return status indicating whether this operation succeeded
 */
graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t& out, const char* op_name);
/**
 * Get sub-shape according to start index, end index and step size stride
 * @param s input Shape
 * @param start sub start index
 * @param end sub end index
 * @param stride sub step size
 * @param out sub shape output
 * @return status indicating whether this operation succeeded
 */
graphStatus SubShape(const Shape& s, int64_t start, int64_t end, int64_t stride, Shape& out, const char* op_name);
/**
 * Get sub-shape according to start index, end index and step size stride
 * @param s input Shape
 * @param start sub start index
 * @param end sub end index
 * @param stride sub step size
 * @param out sub shape output
 * @return status indicating whether this operation succeeded
 */
graphStatus SubShape(const GeShape& s, size_t start, size_t end, size_t stride, GeShape& out);
/**
 * Get sub-shape according to start index, end index and step size stride
 * @param s input Shape
 * @param start sub start index
 * @param end sub end index
 * @param stride sub step size
 * @param out sub shape output
 * @return status indicating whether this operation succeeded
 */
graphStatus SubShape(const GeShape& s, int64_t start, int64_t end, int64_t stride, GeShape& out, const char* op_name);
/**
 * Concatenate two shapes
 * @param s1 first shape
 * @param s2 second shape
 * @param out concatenated shape
 * @return status indicating whether this operation succeeded
 */
graphStatus Concatenate(const Shape& s1, const Shape& s2, Shape& out);
/**
 * Concatenate two shapes
 * @param s1 first shape
 * @param s2 second shape
 * @param out concatenated shape
 * @return status indicating whether this operation succeeded
 */
graphStatus Concatenate(const GeShape& s1, const GeShape& s2, GeShape& out);
/**
 * Generate matrix shape according to dim1 and dim2
 * @param dim1 first dim val
 * @param dim2 second dim val
 * @param out matrix shape
 * @return status indicating whether this operation succeeded
 */
graphStatus Matrix(int64_t dim1, int64_t dim2, Shape& out);
/**
 * Generate vector shape according to dim
 * @param dim dim val
 * @param out vector shape
 * @return status indicating whether this operation succeeded
 */
graphStatus Vector(int64_t dim, Shape& out);
/**
 * Make shape from shape tensor
 * @param tensor shape tensor
 * @param out shape
 * @return status indicating whether this operation succeeded
 */
graphStatus MakeShapeFromShapeTensor(const Tensor& tensor, Shape& out, const char* op_name);
/**
 * Make shape from shape tensor
 * @param op Operator
 * @param dst_name const string &
 * @param out GeShape
 * @param op_name const char *
 * @return status indicating whether this operation succeeded
 */
graphStatus MakeShapeFromShapeTensor(Operator& op, const string& dst_name, GeShape& out, const char* op_name);
/**
 * Make dim from scalar tensor
 * @param tensor shape tensor
 * @param out dim
 * @return status indicating whether this operation succeeded
 */
graphStatus MakeDimForScalarInput(const Tensor& tensor, int64_t& out, const char* op_name);
/**
 * Check whether Shape's rank is at most rank
 * @param tensor input tensor
 * @param rank expected maximum rank of the shape
 * @param out output Shape
 * @return status indicating whether the shape condition is satisfied
 */
graphStatus WithRankAtMost(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name);
/**
 * Check whether Shape's rank is at most rank
 * @param tensorDesc input tensor desc
 * @param rank expected maximum rank of the shape
 * @param out_shape output GeShape
 * @return status indicating whether the shape condition is satisfied
 */
graphStatus WithRankAtMost(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape);
/**
 * make an empty dim shape
 * @param out output Shape
 * @return status indicating whether this operation succeeded
 */
graphStatus Scalar(Shape& out);
/**
 * set input_name shape to output_name shape
 * @param op Operator which needs infershape
 * @param input_name input name of Operator
 * @param output_name output name of Operator
 * @return status indicating whether infershape succeeded
 */
graphStatus UnchangedShape(Operator& op, const string input_name, const string output_name);
/**
 * Divide dim
 * @param dividend dividend dim val
 * @param divisor divisor dim val
 * @param evenlyDivisible whether dividend must be evenly divisible by divisor
 * @param out quotient dim val
 * @return status indicating whether this operation succeeded
 */
graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t& out,
                   const char* op_name);
/**
 * check shape fully defined or not
 * @param shape Shape is checked
 * @return whether shape is fully defined
 */
bool ShapeFullDefined(const Shape& shape);
/** | |||||
* check shape fully defined or not | |||||
* @param shape Shape is checked | |||||
* @return whether shape is fully defined | |||||
*/ | |||||
bool ShapeFullyDefined(const GeShape& shape); | |||||
/** | |||||
* check shape known or not | |||||
* @param shape Shape is checked | |||||
* @return whether rank is known | |||||
*/ | |||||
bool RankKnown(const Shape& shape); | |||||
/** | |||||
* check ge_shape known or not | |||||
* @param shape GeShape is checked | |||||
* @return whether rank is known | |||||
*/ | |||||
bool RankKnown(const GeShape& shape); | |||||
/**
* Make an unknown shape with the given rank
* @param rank rank of the shape
* @return unknown shape
*/
Shape UnknownShapeOfRank(int64_t rank);
/**
* Check dim value known or not
* @param shape Shape whose dim value needs checking
* @param dim_index the index of dim
* @return whether dim value is known
*/
bool ValueKnown(const Shape& shape, const size_t& dim_index);
/**
* Validate that the 3 component tensors of a sparse tensor
* have the proper shapes.
* @param indices sparse indices tensor desc
* @param values sparse values tensor desc
* @param shape sparse shape tensor desc
* @param op_name op name for error report
* @return status whether this operation succeeded
*/
graphStatus ValidateSparseTensor(const TensorDesc& indices, const TensorDesc& values, const TensorDesc& shape,
const char* op_name);
/**
* DecodeWavShapeFn, infershape function of DecodeWav op
* @param op Operator
* @return status whether Shape's condition is satisfied
*/
graphStatus DecodeWavShapeFn(Operator& op);
/**
* EncodeWavShapeFn, infershape function of EncodeWav op
* @param op Operator
* @return status whether Shape's condition is satisfied
*/
graphStatus EncodeWavShapeFn(Operator& op);
/**
* Infershape function of SparseSegmentReduction op
* @param op Operator
* @return status whether Shape's condition is satisfied
*/
graphStatus SparseSegmentReductionShapeFn(Operator& op);
/**
* Infershape function of SparseSegmentReductionGrad op
* @param op Operator
* @return status whether Shape's condition is satisfied
*/
graphStatus SparseSegmentReductionGradShapeFn(Operator& op);
/** | |||||
* Validates variable resource handle | |||||
* @param op Operator | |||||
* @param shape_and_type ShapeAndType vector | |||||
* @return status whether this operation succeeded
*/ | |||||
graphStatus ValidateVariableResourceHandle(Operator& op, std::vector<ShapeAndType>& shape_and_type); | |||||
/**
* Fill op_desc with input shape
* @param op_desc tensor desc ptr to fill
* @param shape input tensor shape
* @param data_type input tensor datatype
*/
void FillOpDesc(GeTensorDescPtr& op_desc, const GeShape& shape, const DataType& data_type = DT_FLOAT);
/** | |||||
* Report infershape error info
* @param op_name Operator name | |||||
* @param op_type Operator type | |||||
* @param value Operator value | |||||
* @param reason error reason | |||||
*/ | |||||
void InferShapeErrorReport(const std::string& op_name, const std::string& op_type, | |||||
const std::string& value, const std::string& reason); | |||||
} // namespace ge | |||||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ |
@@ -0,0 +1,60 @@
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file error_code.h | |||||
* \brief | |||||
*/ | |||||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_ | |||||
#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_ | |||||
namespace ge { | |||||
// Error codes for report purposes:
// 30000~34999 for aicpu engine errors,
// 35000~39999 for infershape errors of aicpu ops.
enum ViewErrorCode { | |||||
INVALID_INFER_SHAPE = 14001, | |||||
INVALID_INPUT_SHAPE = 35000, | |||||
INVALID_ATTR_VALUE = 35001, | |||||
INVALID_ATTR_SIZE = 35002, | |||||
OTHER_ERROR = 35003, | |||||
INVALID_CONV_ATTR_VALUE = 50029, | |||||
INVALID_CONV_SET_ATTR = 50057, | |||||
INVALID_CONV_SHAPE = 50058, | |||||
INVALID_MISS_INPUT = 70001, | |||||
INVALID_INPUT_FORMAT = 70002, | |||||
INVALID_INPUT_DTYPE = 70003, | |||||
INVALID_INPUT_TYPE = 70004, | |||||
INVALID_GET_ATTR = 70005, | |||||
INVALID_SET_ATTR = 70006, | |||||
INVALID_OPS_ATTR_VALUE = 70007, | |||||
FAILED_UPDATE_OP = 70008, | |||||
INVALID_SHAPE = 70009, | |||||
INVALID_SHAPE_SIZE = 70010, | |||||
INVALID_SHAPE_DIM = 70011, | |||||
INVALID_BROADCAST_SHAPE = 70012, | |||||
INVALID_TWO_INPUT_DTYPE = 70013, | |||||
INVALID_AIPP_ERROR = 70014, | |||||
INVALID_ONE_INPUT_SHAPE = 70015, | |||||
INVALID_TWO_INPUT_SHAPE = 70016, | |||||
INVALID_ONE_OUTPUT_SHAPE = 70017, | |||||
FAILED_GET_COMPILIE_PARAMS = 70018, | |||||
}; | |||||
} // namespace ge | |||||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_ |
@@ -0,0 +1,318 @@
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file error_util.cpp | |||||
* \brief | |||||
*/ | |||||
#include <map> | |||||
#include "common/util/error_manager/error_manager.h" | |||||
#include "error_util.h" | |||||
#include "error_code.h" | |||||
#include "op_log.h" | |||||
using namespace std; | |||||
using namespace ge; | |||||
namespace ge { | |||||
inline static std::string GetViewErrorCodeStr(ge::ViewErrorCode errCode) { | |||||
return "E" + std::to_string(errCode); | |||||
} | |||||
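// Illustrative mapping (a sketch based on the helper above): for example,
// GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_SHAPE) yields the report
// key "E35000", which each reporter below passes to the ErrorManager together
// with its key/value map.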
void ShapeErrReport(uint32_t index, const std::string& opname, const std::string& wrong_shape, | |||||
const std::string& correct_shape) { | |||||
map<string, string> err_map; | |||||
err_map["index"] = std::to_string(index); | |||||
err_map["opname"] = opname; | |||||
err_map["wrong_shape"] = wrong_shape; | |||||
err_map["correct_shape"] = correct_shape; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_SHAPE); | |||||
(void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void AttrValueErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_value, | |||||
const std::string& correct_value) { | |||||
map<string, string> err_map; | |||||
err_map["attrname"] = attrName; | |||||
err_map["opname"] = opname; | |||||
err_map["wrong_value"] = wrong_value; | |||||
err_map["correct_value"] = correct_value; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ATTR_VALUE); | |||||
(void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void AttrSizeErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_size, | |||||
const std::string& correct_size) { | |||||
map<string, string> err_map; | |||||
err_map["attrname"] = attrName; | |||||
err_map["opname"] = opname; | |||||
err_map["wrong_size"] = wrong_size; | |||||
err_map["correct_size"] = correct_size; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ATTR_SIZE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void InferShapeOtherErrReport(const std::string& opname, const std::string& err_msg) { | |||||
map<string, string> err_map; | |||||
err_map["opname"] = opname; | |||||
err_map["err_msg"] = err_msg; | |||||
string report_error_code = GetViewErrorCodeStr(ViewErrorCode::OTHER_ERROR); | |||||
(void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsMissInputErrReport(const std::string& op_name, const std::string& param_name) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_MISS_INPUT); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsInputFormatErrReport(const std::string& op_name, const std::string& param_name, | |||||
const std::string& expected_format_list, const std::string& data_format) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
err_map["expected_format_list"] = expected_format_list; | |||||
err_map["format"] = data_format; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_FORMAT); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsInputDtypeErrReport(const std::string& op_name, const std::string& param_name, | |||||
const std::string& expected_data_type_list, const std::string& data_type) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
err_map["expected_data_type_list"] = expected_data_type_list; | |||||
err_map["data_type"] = data_type; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_DTYPE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsInputTypeErrReport(const std::string& op_name, const std::string& param_name, const std::string& param_type, | |||||
const std::string& actual_type) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
err_map["param_type"] = param_type; | |||||
err_map["actual_type"] = actual_type; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_TYPE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsGetAttrErrReport(const std::string& op_name, const std::string& param_name) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_GET_ATTR); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsSetAttrErrReport(const std::string& op_name, const std::string& param_name) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SET_ATTR); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& excepted_value, | |||||
const std::string& input_value) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
err_map["excepted_value"] = excepted_value; | |||||
err_map["input_value"] = input_value; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_OPS_ATTR_VALUE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsOPUpdateErrReport(const std::string& op_name, const std::string& param_name) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::FAILED_UPDATE_OP); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsInputShapeErrReport(const std::string& op_name, const std::string& rule_desc, const std::string& param_name, | |||||
const std::string& param_value) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["rule_desc"] = rule_desc; | |||||
err_map["param_name"] = param_name; | |||||
err_map["param_value"] = param_value; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsOneInputShapeErrReport(const std::string& op_name, const std::string& param_name, | |||||
const std::string& error_detail) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
err_map["error_detail"] = error_detail; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ONE_INPUT_SHAPE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsTwoInputShapeErrReport(const std::string& op_name, const std::string& param_name1, | |||||
const std::string& param_name2, const std::string& error_detail) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name1"] = param_name1; | |||||
err_map["param_name2"] = param_name2; | |||||
err_map["error_detail"] = error_detail; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_TWO_INPUT_SHAPE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsOneOutputShapeErrReport(const std::string& op_name, const std::string& param_name, | |||||
const std::string& error_detail) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
err_map["error_detail"] = error_detail; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ONE_OUTPUT_SHAPE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsGetCompileParamsErrReport(const std::string& op_name, const std::string& param_name) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::FAILED_GET_COMPILIE_PARAMS); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsInputShapeSizeErrReport(const std::string& op_name, const std::string& input_name, const std::string& max_value, | |||||
const std::string& real_value) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["input_name"] = input_name; | |||||
err_map["max_value"] = max_value; | |||||
err_map["real_value"] = real_value; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE_SIZE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsInputShapeDimErrReport(const std::string& op_name, const std::string& param_name, const std::string& max_value, | |||||
const std::string& min_value, const std::string& real_value) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
err_map["max_value"] = max_value; | |||||
err_map["min_value"] = min_value; | |||||
err_map["real_value"] = real_value; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE_DIM); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsInputShapeBroadcastErrReport(const std::string& op_name, const std::string& input1_name, | |||||
const std::string& input2_name, const std::string& input1_shape, | |||||
const std::string& input2_shape) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["input1_name"] = input1_name; | |||||
err_map["input2_name"] = input2_name; | |||||
err_map["input1_shape"] = input1_shape; | |||||
err_map["input2_shape"] = input2_shape; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_BROADCAST_SHAPE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void TbeInputDataTypeErrReport(const std::string& op_name, const std::string& param_name, | |||||
const std::string& expected_dtype_list, const std::string& dtype) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
err_map["expected_dtype_list"] = expected_dtype_list; | |||||
err_map["dtype"] = dtype; | |||||
std::string report_error_code = "E50034"; | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsTwoInputDtypeErrReport(const std::string& op_name, const std::string& input1_name, | |||||
const std::string& input2_name, const std::string& input1_dtype, | |||||
const std::string& input2_dtype) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["input1_name"] = input1_name; | |||||
err_map["input2_name"] = input2_name; | |||||
err_map["input1_dtype"] = input1_dtype; | |||||
err_map["input2_dtype"] = input2_dtype; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_TWO_INPUT_DTYPE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsAippErrReport(const std::string& aipp_output_H, const std::string& aipp_output_W, const std::string& data_H, | |||||
const std::string& data_W) { | |||||
map<string, string> err_map; | |||||
err_map["aipp_output_H"] = aipp_output_H; | |||||
err_map["aipp_output_W"] = aipp_output_W; | |||||
err_map["data_H"] = data_H; | |||||
err_map["data_W"] = data_W; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_AIPP_ERROR); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsConvAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& expected_value, | |||||
const std::string& input_value) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param_name"] = param_name; | |||||
err_map["expected_value"] = expected_value; | |||||
err_map["input_value"] = input_value; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_ATTR_VALUE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsConvSetAttrErrReport(const std::string& op_name, const std::string& param1_name, | |||||
const std::string& param2_name) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["param1_name"] = param1_name; | |||||
err_map["param2_name"] = param2_name; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_SET_ATTR); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void OpsConvShapeErrReport(const std::string& op_name, const std::string& description) { | |||||
map<string, string> err_map; | |||||
err_map["op_name"] = op_name; | |||||
err_map["description"] = description; | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_SHAPE); | |||||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||||
} | |||||
void GeInfershapeErrReport(const std::string& op_name, const std::string& op_type, const std::string& value, | |||||
const std::string& reason) { | |||||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INFER_SHAPE); | |||||
ErrorManager::GetInstance().ATCReportErrMessage(report_error_code, {"opname", "optype", "value", "reason"}, | |||||
{op_name, op_type, value, reason}); | |||||
} | |||||
void CommonRuntimeErrLog(const std::string& opname, const std::string& description) {
map<string, string> err_map;
err_map["op_name"] = opname;
err_map["description"] = description;
OP_LOGE(opname.c_str(), "%s", description.c_str());
(void)ErrorManager::GetInstance().ReportErrMessage("E50058", err_map);
}
} // namespace ge |
@@ -0,0 +1,184 @@
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file error_util.h | |||||
* \brief | |||||
*/ | |||||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_ | |||||
#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_ | |||||
#include <sstream> | |||||
#include <string> | |||||
#include <vector> | |||||
#include "operator.h" | |||||
namespace ge { | |||||
/* | |||||
* get debug string of vector | |||||
* param[in] v vector | |||||
* return vector's debug string | |||||
*/ | |||||
template <typename T> | |||||
std::string DebugString(const std::vector<T>& v) { | |||||
std::ostringstream oss; | |||||
oss << "["; | |||||
if (v.size() > 0) { | |||||
for (size_t i = 0; i < v.size() - 1; ++i) { | |||||
oss << v[i] << ", "; | |||||
} | |||||
oss << v[v.size() - 1]; | |||||
} | |||||
oss << "]"; | |||||
return oss.str(); | |||||
} | |||||
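// Illustrative use (not part of the original header):
//   std::vector<int64_t> dims{2, 3, 4};
//   DebugString(dims);  // returns "[2, 3, 4]"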
/*
* string concat util function
* param[in] params values to concatenate into one string
* return concatenated string
*/
template <typename T> | |||||
std::string ConcatString(T arg) { | |||||
std::ostringstream oss; | |||||
oss << arg; | |||||
return oss.str(); | |||||
} | |||||
template <typename T, typename... Ts> | |||||
std::string ConcatString(T arg, Ts... arg_left) { | |||||
std::ostringstream oss; | |||||
oss << arg; | |||||
oss << ConcatString(arg_left...); | |||||
return oss.str(); | |||||
} | |||||
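// Illustrative use (not part of the original header):
//   ConcatString("rank=", 4, ", dims=", DebugString(dims));  // "rank=4, dims=[2, 3, 4]"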
/* | |||||
* report input shape error of infer shape | |||||
* param[in] index the index of input | |||||
* param[in] opname op name | |||||
* param[in] wrong_shape wrong input shape | |||||
* param[in] correct_shape correct input shape | |||||
* return void | |||||
*/ | |||||
void ShapeErrReport(uint32_t index, const std::string& opname, const std::string& wrong_shape, | |||||
const std::string& correct_shape); | |||||
/* | |||||
* report attr value error of infer shape | |||||
* param[in] attrname the attr name | |||||
* param[in] opname op name | |||||
* param[in] wrong_value wrong attr value | |||||
* param[in] correct_value correct attr value | |||||
* return void | |||||
*/ | |||||
void AttrValueErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_value, | |||||
const std::string& correct_value); | |||||
/* | |||||
* report attr size error of infer shape | |||||
* param[in] attrname the attr name | |||||
* param[in] opname op name | |||||
* param[in] wrong_size wrong attr size | |||||
* param[in] correct_size correct attr size | |||||
* return void | |||||
*/ | |||||
void AttrSizeErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_size, | |||||
const std::string& correct_size); | |||||
/* | |||||
* report common error of infer shape | |||||
* param[in] opname op name | |||||
* param[in] err_msg error message | |||||
* return void | |||||
*/ | |||||
void InferShapeOtherErrReport(const std::string& opname, const std::string& err_msg); | |||||
void OpsMissInputErrReport(const std::string& op_name, const std::string& param_name); | |||||
void OpsInputFormatErrReport(const std::string& op_name, const std::string& param_name, | |||||
const std::string& expected_format_list, const std::string& data_format); | |||||
void OpsInputDtypeErrReport(const std::string& op_name, const std::string& param_name, | |||||
const std::string& expected_data_type_list, const std::string& data_type); | |||||
void OpsInputTypeErrReport(const std::string& op_name, const std::string& param_name, const std::string& param_type, | |||||
const std::string& actual_type); | |||||
void OpsGetAttrErrReport(const std::string& op_name, const std::string& param_name); | |||||
void OpsSetAttrErrReport(const std::string& op_name, const std::string& param_name); | |||||
void OpsAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& excepted_value, | |||||
const std::string& input_value); | |||||
void OpsOPUpdateErrReport(const std::string& op_name, const std::string& param_name); | |||||
void OpsInputShapeErrReport(const std::string& op_name, const std::string& rule_desc, const std::string& param_name, | |||||
const std::string& param_value); | |||||
void OpsOneInputShapeErrReport(const std::string& op_name, const std::string& param_name, | |||||
const std::string& error_detail); | |||||
void OpsTwoInputShapeErrReport(const std::string& op_name, const std::string& param_name1, | |||||
const std::string& param_name2, const std::string& error_detail); | |||||
void OpsOneOutputShapeErrReport(const std::string& op_name, const std::string& param_name, | |||||
const std::string& error_detail); | |||||
void OpsGetCompileParamsErrReport(const std::string& op_name, const std::string& param_name); | |||||
void OpsInputShapeSizeErrReport(const std::string& op_name, const std::string& input_name, const std::string& max_value, | |||||
const std::string& real_value); | |||||
void OpsInputShapeDimErrReport(const std::string& op_name, const std::string& param_name, const std::string& max_value, | |||||
const std::string& min_value, const std::string& real_value); | |||||
void OpsInputShapeBroadcastErrReport(const std::string& op_name, const std::string& input1_name, | |||||
const std::string& input2_name, const std::string& input1_shape, | |||||
const std::string& input2_shape); | |||||
void TbeInputDataTypeErrReport(const std::string& op_name, const std::string& param_name, | |||||
const std::string& expected_dtype_list, const std::string& dtype); | |||||
void OpsTwoInputDtypeErrReport(const std::string& op_name, const std::string& input1_name, | |||||
const std::string& input2_name, const std::string& input1_dtype, | |||||
const std::string& input2_dtype); | |||||
void OpsAippErrReport(const std::string& aipp_output_H, const std::string& aipp_output_W, const std::string& data_H, | |||||
const std::string& data_W); | |||||
void OpsConvAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& expected_value, | |||||
const std::string& input_value); | |||||
void OpsConvSetAttrErrReport(const std::string& op_name, const std::string& param1_name, | |||||
const std::string& param2_name); | |||||
void OpsConvShapeErrReport(const std::string& op_name, const std::string& description); | |||||
void GeInfershapeErrReport(const std::string& op_name, const std::string& op_type, const std::string& value, | |||||
const std::string& reason); | |||||
/* | |||||
* log common runtime error | |||||
* param[in] opname op name | |||||
* param[in] error description | |||||
* return void | |||||
*/ | |||||
void CommonRuntimeErrLog(const std::string& opname, const std::string& description); | |||||
} // namespace ge | |||||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_ |
@@ -0,0 +1,73 @@
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file op_common_util.h | |||||
* \brief common util for op, in this file only original type or class in C++ allowed | |||||
*/ | |||||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_ | |||||
#define OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_ | |||||
#include <set> | |||||
#include <string> | |||||
#include <vector> | |||||
#include <iostream> | |||||
#include <sstream> | |||||
template <typename T1, typename T2> | |||||
std::ostream& operator<<(std::ostream& os, const std::pair<T1, T2>& values) { | |||||
os << "[" << values.first << ", " << values.second << "]"; | |||||
return os; | |||||
} | |||||
template <typename T> | |||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) { | |||||
os << "["; | |||||
for (const auto& item : values) { | |||||
os << item << ", "; | |||||
} | |||||
os << "]"; | |||||
return os; | |||||
} | |||||
namespace ops { | |||||
template<typename T> | |||||
std::string to_string(const std::vector<T> &items) { | |||||
std::ostringstream oss; | |||||
oss << "["; | |||||
for (const auto &item: items) { | |||||
oss << item << ", "; | |||||
} | |||||
oss << "]"; | |||||
return oss.str(); | |||||
} | |||||
template<typename T> | |||||
std::string to_string(const std::set<T> &items) { | |||||
std::ostringstream oss; | |||||
oss << "["; | |||||
for (const auto &item: items) { | |||||
oss << item << ", "; | |||||
} | |||||
oss << "]"; | |||||
return oss.str(); | |||||
} | |||||
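// Illustrative use (not part of the original header); note that the helpers
// above keep a trailing separator before the closing bracket:
//   ops::to_string(std::vector<int>{1, 2});  // returns "[1, 2, ]"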
} // namespace ops | |||||
#endif //OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_ |
@@ -0,0 +1,89 @@
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file op_log.h | |||||
* \brief | |||||
*/ | |||||
#ifndef GE_OP_LOG_H
#define GE_OP_LOG_H
#include <type_traits>  // needed by OP_CHECK below (std::is_same, std::decay)
#if !defined(__ANDROID__) && !defined(ANDROID)
#include "toolchain/slog.h" | |||||
#else | |||||
#include <utils/Log.h> | |||||
#endif | |||||
#define OPPROTO_SUBMOD_NAME "OP_PROTO" | |||||
#if !defined( __ANDROID__) && !defined(ANDROID) | |||||
#define OP_LOGI(opname, ...) D_OP_LOGI(opname, __VA_ARGS__) | |||||
#define OP_LOGW(opname, ...) D_OP_LOGW(opname, __VA_ARGS__) | |||||
#define OP_LOGE(opname, ...) D_OP_LOGE(opname, __VA_ARGS__) | |||||
#define OP_LOGD(opname, ...) D_OP_LOGD(opname, __VA_ARGS__) | |||||
#define GE_OP_LOGI(opname, ...) GE_D_OP_LOGI(opname, __VA_ARGS__) | |||||
#define GE_OP_LOGW(opname, ...) GE_D_OP_LOGW(opname, __VA_ARGS__) | |||||
#define GE_OP_LOGE(opname, ...) GE_D_OP_LOGE(opname, __VA_ARGS__) | |||||
#define GE_OP_LOGD(opname, ...) GE_D_OP_LOGD(opname, __VA_ARGS__) | |||||
#define FUSION_PASS_LOGI(...) D_FUSION_PASS_LOGI(__VA_ARGS__) | |||||
#define FUSION_PASS_LOGW(...) D_FUSION_PASS_LOGW(__VA_ARGS__) | |||||
#define FUSION_PASS_LOGE(...) D_FUSION_PASS_LOGE(__VA_ARGS__) | |||||
#define FUSION_PASS_LOGD(...) D_FUSION_PASS_LOGD(__VA_ARGS__) | |||||
#else
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...)
#define OP_LOGE(opname, ...)
#define OP_LOGD(opname, ...)
#define GE_OP_LOGI(opname, ...)
#define GE_OP_LOGW(opname, ...)
#define GE_OP_LOGE(opname, ...)
#define GE_OP_LOGD(opname, ...)
#define FUSION_PASS_LOGI(...)
#define FUSION_PASS_LOGW(...)
#define FUSION_PASS_LOGE(...)
#define FUSION_PASS_LOGD(...)
#endif
#if !defined( __ANDROID__) && !defined(ANDROID) | |||||
#define D_OP_LOGI(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||||
#define D_OP_LOGW(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||||
#define D_OP_LOGE(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||||
#define D_OP_LOGD(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||||
#define GE_D_OP_LOGI(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||||
#define GE_D_OP_LOGW(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||||
#define GE_D_OP_LOGE(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||||
#define GE_D_OP_LOGD(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||||
#define D_FUSION_PASS_LOGI(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#define D_FUSION_PASS_LOGW(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#define D_FUSION_PASS_LOGE(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#define D_FUSION_PASS_LOGD(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#else | |||||
#define D_OP_LOGI(opname, fmt, ...) | |||||
#define D_OP_LOGW(opname, fmt, ...) | |||||
#define D_OP_LOGE(opname, fmt, ...) | |||||
#define D_OP_LOGD(opname, fmt, ...) | |||||
#define D_FUSION_PASS_LOGI(fmt, ...) | |||||
#define D_FUSION_PASS_LOGW(fmt, ...) | |||||
#define D_FUSION_PASS_LOGE(fmt, ...) | |||||
#define D_FUSION_PASS_LOGD(fmt, ...) | |||||
#endif | |||||
#define OP_CHECK(condition, log_func, do_expr) \ | |||||
static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \ | |||||
do { \ | |||||
if (condition) { \ | |||||
log_func; \ | |||||
do_expr; \ | |||||
} \ | |||||
} while (0) | |||||
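// Illustrative use (a sketch; the surrounding infershape function and its
// locals are assumed, not from this header):
//   OP_CHECK(dims.empty(),
//            OP_LOGE(op_name, "input shape is empty"),
//            return GRAPH_FAILED);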
#endif //GE_OP_LOG_H |
@@ -0,0 +1,258 @@
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file transfer_shape_according_to_format.cpp | |||||
* \brief set shape according to original format and current format | |||||
*/ | |||||
#include "transfer_shape_according_to_format.h" | |||||
#include "framework/omg/omg_inner_types.h" | |||||
namespace ge { | |||||
ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat() {
getNewShapeFuncMap = { | |||||
{ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)}, | |||||
{ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)}, | |||||
{ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)}, | |||||
{ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)}, | |||||
{ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)}, | |||||
{ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)}, | |||||
{ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}}; | |||||
mapOfDtypeAndC0 = { | |||||
{ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32}, | |||||
{ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16}, | |||||
{ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16}, | |||||
{ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}}; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
std::vector<int64_t> newDimVec; | |||||
newDimVec.push_back(axisValue[AXIS_N]); | |||||
newDimVec.push_back(axisValue[AXIS_C]); | |||||
newDimVec.push_back(axisValue[AXIS_H]); | |||||
newDimVec.push_back(axisValue[AXIS_W]); | |||||
newShape = ge::GeShape(newDimVec); | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetNHWCShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
std::vector<int64_t> newDimVec; | |||||
newDimVec.push_back(axisValue[AXIS_N]); | |||||
newDimVec.push_back(axisValue[AXIS_H]); | |||||
newDimVec.push_back(axisValue[AXIS_W]); | |||||
newDimVec.push_back(axisValue[AXIS_C]); | |||||
newShape = ge::GeShape(newDimVec); | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetNC1HWC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
std::vector<int64_t> newDimVec; | |||||
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) { | |||||
CHECK(axisValue.size() <= AXIS_C0, LOG_INFO("AxisValue is not correct!"), return true); | |||||
newDimVec.push_back(axisValue[AXIS_N]); | |||||
newDimVec.push_back(axisValue[AXIS_C1]); | |||||
newDimVec.push_back(axisValue[AXIS_H]); | |||||
newDimVec.push_back(axisValue[AXIS_W]); | |||||
newDimVec.push_back(axisValue[AXIS_C0]); | |||||
newShape = ge::GeShape(newDimVec); | |||||
} else { | |||||
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); | |||||
newDimVec.push_back(axisValue[AXIS_N]); | |||||
newDimVec.push_back(axisValue[AXIS_C]); | |||||
newDimVec.push_back(axisValue[AXIS_H]); | |||||
newDimVec.push_back(axisValue[AXIS_W]); | |||||
newShape = ge::GeShape(newDimVec); | |||||
} | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
std::vector<int64_t> newDimVec; | |||||
if (ndValue.size() == SIZE_OF_CN) { | |||||
CHECK(axisValue.size() <= AXIS_C0, LOG_INFO("AxisValue is not correct!"), return true); | |||||
auto sizeOfOriginalVec = ndValue.size(); | |||||
newDimVec = ndValue;
/* sizeOfOriginalVec - 1 means the last value of the original vec,
* sizeOfOriginalVec - 2 means the second last value of the original vec */
newDimVec[sizeOfOriginalVec - MINUS_VALUE_ONE] = | |||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16); | |||||
newDimVec[sizeOfOriginalVec - MINUS_VALUE_TWO] = | |||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]); | |||||
newDimVec.push_back(SHAPE_NUMBER_16); | |||||
newDimVec.push_back(axisValue[AXIS_C0]); | |||||
newShape = ge::GeShape(newDimVec); | |||||
} else { | |||||
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) { | |||||
CHECK(axisValue.size() <= AXIS_C1, LOG_INFO("AxisValue is not correct!"), return true); | |||||
int64_t hwc1 = axisValue[AXIS_C1] * axisValue[AXIS_H] * axisValue[AXIS_W]; | |||||
newDimVec.push_back(hwc1); | |||||
newDimVec.push_back(DivisionCeiling(axisValue[AXIS_N], NI)); | |||||
newDimVec.push_back(NI); | |||||
newDimVec.push_back(axisValue[AXIS_C0]); | |||||
newShape = ge::GeShape(newDimVec); | |||||
} else { | |||||
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); | |||||
newDimVec.push_back(axisValue[AXIS_N]); | |||||
newDimVec.push_back(axisValue[AXIS_C]); | |||||
newDimVec.push_back(axisValue[AXIS_H]); | |||||
newDimVec.push_back(axisValue[AXIS_W]); | |||||
newShape = ge::GeShape(newDimVec); | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
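/* Worked example (illustrative, assuming AxisUtil has filled axisValue from
* NCHW N=32, C=48, H=3, W=3 with C0=16, i.e. C1=ceil(48/16)=3): for a TBE
* impl the FRACTAL_Z shape is [C1*H*W, ceil(N/NI), NI, C0] = [27, 2, 16, 16]. */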
bool ShapeTransferAccordingToFormat::GetHWCNShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
std::vector<int64_t> newDimVec; | |||||
newDimVec.push_back(axisValue[AXIS_H]); | |||||
newDimVec.push_back(axisValue[AXIS_W]); | |||||
newDimVec.push_back(axisValue[AXIS_C]); | |||||
newDimVec.push_back(axisValue[AXIS_N]); | |||||
newShape = ge::GeShape(newDimVec); | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetC1HWNCoC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.size() <= AXIS_Co, LOG_INFO("AxisValue is not correct!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
std::vector<int64_t> newDimVec; | |||||
newDimVec.push_back(axisValue[AXIS_C1]); | |||||
newDimVec.push_back(axisValue[AXIS_H]); | |||||
newDimVec.push_back(axisValue[AXIS_W]); | |||||
newDimVec.push_back(axisValue[AXIS_N]); | |||||
newDimVec.push_back(axisValue[AXIS_Co]); | |||||
newDimVec.push_back(axisValue[AXIS_C0]); | |||||
newShape = ge::GeShape(newDimVec); | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(ndValue.empty(), LOG_INFO("ndValue is empty!"), return true); | |||||
CHECK(axisValue.empty() || axisValue.size() <= AXIS_C0, | |||||
LOG_INFO("AxisValue is empty or its size %zu <= AXIS_C0[%u]", axisValue.size(), AXIS_C0), return true); | |||||
uint32_t sizeOfOriginalVec = ndValue.size(); | |||||
if (sizeOfOriginalVec < MINIMUM_NZ_SHAPE_DIM_NUM) { | |||||
LOG_INFO("ndValue's dim num is less than 2!"); | |||||
return true; | |||||
} | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
std::vector<int64_t> newDimVec = ndValue; | |||||
/* sizeOfOriginalVec - 1 means the last value of the original vec,
* sizeOfOriginalVec - 2 means the second last value of the original vec */
newDimVec[sizeOfOriginalVec - MINUS_VALUE_ONE] = | |||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16); | |||||
newDimVec[sizeOfOriginalVec - MINUS_VALUE_TWO] = | |||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]); | |||||
newDimVec.push_back(SHAPE_NUMBER_16); | |||||
newDimVec.push_back(axisValue[AXIS_C0]); | |||||
newShape = ge::GeShape(newDimVec); | |||||
return true; | |||||
} | |||||
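/* Worked example (illustrative): an ND shape [30, 40] with C0=16 becomes
* [ceil(40/16), ceil(30/16), 16, 16] = [3, 2, 16, 16] in FRACTAL_NZ. */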
bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& shapeAndFormatInfo, int64_t* c) { | |||||
/* The default new shape is old shape */ | |||||
shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape; | |||||
if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) { | |||||
LOG_ERROR("Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat); | |||||
return false; | |||||
} | |||||
if (shapeAndFormatInfo.currentDataType >= ge::DT_UNDEFINED) { | |||||
LOG_ERROR("currentDataType %u is invalid!", shapeAndFormatInfo.currentDataType); | |||||
return false; | |||||
} | |||||
AxisUtil* axisutil_object = new AxisUtil(); | |||||
if (!axisutil_object->HasAxisValueFunc(shapeAndFormatInfo.oldFormat)) { | |||||
delete axisutil_object; | |||||
return true; | |||||
} | |||||
auto iterGetNewShapeFunc = getNewShapeFuncMap.find(shapeAndFormatInfo.newFormat); | |||||
if (iterGetNewShapeFunc == getNewShapeFuncMap.end()) { | |||||
LOG_INFO("Can not get new shape of new format %u!", shapeAndFormatInfo.newFormat); | |||||
delete axisutil_object; | |||||
return true; | |||||
} | |||||
LOG_INFO("Original format %u, new format %u", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat); | |||||
GetNewShapeByAxisValueAndFormatPtr getNewShapeFunc = iterGetNewShapeFunc->second; | |||||
CHECK_NOTNULL(getNewShapeFunc); | |||||
std::vector<int64_t> axisValue; | |||||
for (uint32_t i = 0; i < AXIS_BOTTOM; i++) { | |||||
axisValue.push_back(1); | |||||
} | |||||
std::vector<int64_t> ndValue; | |||||
uint32_t c0; | |||||
if (mapOfDtypeAndC0.empty()) { | |||||
c0 = SHAPE_NUMBER_16; | |||||
} else { | |||||
auto iterGetC0 = mapOfDtypeAndC0.find(shapeAndFormatInfo.currentDataType); | |||||
if (iterGetC0 == mapOfDtypeAndC0.end()) { | |||||
LOG_ERROR("Dtype is not support."); | |||||
delete axisutil_object; | |||||
return true; | |||||
} | |||||
c0 = iterGetC0->second; | |||||
} | |||||
// The value of C0 should be 4 while format is 5HD-4 or FRAZ-4 | |||||
if (shapeAndFormatInfo.newFormat == ge::FORMAT_NC1HWC0_C04) { | |||||
c0 = SHAPE_DIM_VALUE_C04; | |||||
} | |||||
bool status = axisutil_object->GetAxisValueByOriginFormat( | |||||
shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape.GetDims(), c0, axisValue, ndValue); | |||||
if (!status && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) {
delete axisutil_object; | |||||
return true; | |||||
} | |||||
delete axisutil_object; | |||||
(*getNewShapeFunc)(shapeAndFormatInfo.newShape, shapeAndFormatInfo.opImplType, axisValue, ndValue); | |||||
if (c != nullptr) { | |||||
*c = axisValue[AXIS_C]; | |||||
} | |||||
return true; | |||||
} | |||||
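/* Usage sketch (illustrative; the variable names are assumptions, not from
* this file). Converting an NCHW fp16 shape [32, 48, 3, 3] to NC1HWC0 with
* C0=16 (so C1=ceil(48/16)=3) is expected to give [32, 3, 3, 3, 16]:
*   ge::GeShape old_shape({32, 48, 3, 3});
*   ge::GeShape new_shape;
*   ge::Format old_fmt = ge::FORMAT_NCHW;
*   ge::Format new_fmt = ge::FORMAT_NC1HWC0;
*   ge::DataType dtype = ge::DT_FLOAT16;
*   int64_t impl = EN_IMPL_HW_TBE;
*   ShapeAndFormat info{old_shape, new_shape, old_fmt, new_fmt, dtype, impl};
*   ShapeTransferAccordingToFormat transfer;
*   bool ok = transfer.GetShapeAccordingToFormat(info);
*/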
}  // namespace ge
@@ -0,0 +1,129 @@
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file transfer_shape_according_to_format.h | |||||
* \brief set shape according to original format and current format | |||||
*/ | |||||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ | |||||
#define OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ | |||||
#include "axis_util.h" | |||||
#include <memory.h> | |||||
#include <functional> | |||||
#include <vector> | |||||
#include "framework/omg/omg_inner_types.h" | |||||
#include "operator.h" | |||||
#include "graph/operator_reg.h" | |||||
#include "graph/tensor.h" | |||||
#include "graph/utils/op_desc_utils.h" | |||||
#include "op_log.h" | |||||
#define LOG_ERROR(format, args...) printf(format, ##args) | |||||
#define LOG_INFO(format, args...) printf(format, ##args) | |||||
namespace ge { | |||||
enum OpImplType { | |||||
EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op | |||||
EN_IMPL_CUSTOM_TIK, // custom tik op | |||||
EN_IMPL_CUSTOM_TBE, // custom tbe op | |||||
EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op | |||||
EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op | |||||
EN_IMPL_HW_TIK, // Huawei built-in tik op | |||||
EN_IMPL_HW_TBE, // Huawei built-in tbe op | |||||
EN_IMPL_RL, // RL op | |||||
EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op | |||||
EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op | |||||
EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op | |||||
EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op | |||||
EN_RESERVED // reserved value | |||||
}; | |||||
const uint32_t SHAPE_NUMBER_16 = 16; | |||||
const uint32_t SHAPE_NUMBER_32 = 32; | |||||
const uint32_t SHAPE_DIM_VALUE_C04 = 4; | |||||
const uint32_t NI = 16; | |||||
const uint32_t MINUS_VALUE_ONE = 1; | |||||
const uint32_t MINUS_VALUE_TWO = 2; | |||||
const uint32_t SIZE_OF_CN = 2; | |||||
const uint32_t MINIMUM_NZ_SHAPE_DIM_NUM = 2; | |||||
/* The first parameter is the new shape, the second is the op implementation
* type, the third is the axis value vector and the fourth is the original
* ND shape vector. */
using GetNewShapeByAxisValueAndFormat =
std::function<bool(ge::GeShape&, const int64_t&, vector<int64_t>&, vector<int64_t>&)>;
using GetNewShapeByAxisValueAndFormatPtr = std::shared_ptr<GetNewShapeByAxisValueAndFormat>; | |||||
struct ShapeAndFormatInfo { | |||||
const ge::GeShape& oldShape; | |||||
ge::GeShape& newShape; | |||||
const ge::Format& oldFormat; | |||||
const ge::Format& newFormat; | |||||
const ge::DataType& currentDataType; | |||||
const int64_t& opImplType; | |||||
}; | |||||
using ShapeAndFormat = struct ShapeAndFormatInfo; | |||||
class ShapeTransferAccordingToFormat { | |||||
public: | |||||
ShapeTransferAccordingToFormat(); | |||||
~ShapeTransferAccordingToFormat(){}; | |||||
ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat&) = delete; | |||||
ShapeTransferAccordingToFormat& operator=(const ShapeTransferAccordingToFormat&) = delete; | |||||
bool GetShapeAccordingToFormat(ShapeAndFormat& inputAndOutputInfo, int64_t* c = nullptr); | |||||
/* ----------Below is the function of getting new shape---------------------- */ | |||||
static bool GetNCHWShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue); | |||||
static bool GetNHWCShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue); | |||||
static bool GetNC1HWC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, const vector<int64_t>& ndValue); | |||||
static bool GetFzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue); | |||||
static bool GetHWCNShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue); | |||||
static bool GetC1HWNCoC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, const vector<int64_t>& ndValue); | |||||
static bool GetNzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue); | |||||
private: | |||||
/* map from new format to the function that computes the new shape
* for that format. */
std::map<ge::Format, GetNewShapeByAxisValueAndFormatPtr> getNewShapeFuncMap; | |||||
std::map<ge::DataType, uint32_t> mapOfDtypeAndC0; | |||||
}; | |||||
} // namespace ge | |||||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ |
@@ -0,0 +1,363 @@
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file util.h | |||||
* \brief | |||||
*/ | |||||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_ | |||||
#define OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_ | |||||
#include <memory.h> | |||||
#include <string> | |||||
#include <vector> | |||||
#include <map> | |||||
#include <algorithm> | |||||
#include "framework/omg/omg_inner_types.h" | |||||
#include "operator.h" | |||||
#include "graph/operator_reg.h" | |||||
#include "graph/operator_reg.h" | |||||
#include "transfer_shape_according_to_format.h" | |||||
#include "graph/utils/op_desc_utils.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
#include "graph/tensor.h" | |||||
#include "graph/node.h" | |||||
#include "graph/ge_tensor.h" | |||||
#include "op_log.h" | |||||
#define LOG_ERROR(format, args...) printf(format, ##args) | |||||
namespace ge { | |||||
// enum type and string type mapping | |||||
static const std::map<ge::DataType, std::string> DTYPE_STR_MAP{ | |||||
{ge::DT_FLOAT16, "float16"}, {ge::DT_FLOAT, "float32"}, {ge::DT_INT8, "int8"}, {ge::DT_INT16, "int16"}, | |||||
{ge::DT_INT32, "int32"}, {ge::DT_INT64, "int64"}, {ge::DT_UINT8, "uint8"}, {ge::DT_UINT16, "uint16"}, | |||||
{ge::DT_UINT32, "uint32"}, {ge::DT_UINT64, "uint64"}, {ge::DT_BOOL, "bool"}}; | |||||
// define the input num of shape | |||||
const size_t INPUT_NUM0 = 0; | |||||
const size_t INPUT_NUM1 = 1; | |||||
const size_t INPUT_NUM2 = 2; | |||||
const size_t INPUT_NUM3 = 3; | |||||
const size_t INPUT_NUM4 = 4; | |||||
const size_t INPUT_NUM5 = 5; | |||||
const size_t INPUT_NUM6 = 6; | |||||
const size_t INPUT_NUM7 = 7; | |||||
const size_t INPUT_NUM8 = 8; | |||||
const size_t INPUT_NUM9 = 9; | |||||
// define the dims size of shape | |||||
const size_t DIM_SIZE0 = 0; | |||||
const size_t DIM_SIZE1 = 1; | |||||
const size_t DIM_SIZE2 = 2; | |||||
const size_t DIM_SIZE3 = 3; | |||||
const size_t DIM_SIZE4 = 4; | |||||
const size_t DIM_SIZE5 = 5; | |||||
const size_t DIM_SIZE6 = 6; | |||||
const size_t DIM_SIZE7 = 7; | |||||
const size_t DIM_SIZE8 = 8; | |||||
// define the index of shape dim | |||||
const size_t DIM_INDEX0 = 0; | |||||
const size_t DIM_INDEX1 = 1; | |||||
const size_t DIM_INDEX2 = 2; | |||||
const size_t DIM_INDEX3 = 3; | |||||
const size_t DIM_INDEX4 = 4; | |||||
const size_t DIM_INDEX5 = 5; | |||||
const size_t DIM_INDEX6 = 6; | |||||
const size_t DIM_INDEX7 = 7; | |||||
const size_t DIM_INDEX8 = 8; | |||||
/*
 * get the datatype of input
 * param[in] dataType input datatype enum value
 * param[in] supportList supported datatypes of the op
 * return true: get type success
 *        false: get type failed
 */
bool GetInputDataType(const ge::DataType& data_type, const std::vector<ge::DataType>& supportList); | |||||
bool GetInputDataType(const ge::DataType& dataType, const std::vector<ge::DataType>& supportList, std::string& dType); | |||||
/*
 * check whether the datatype of an input is in the supported list
 * param[in] op op desc supplied by ge
 * param[in] input_name name of the input to check
 * param[in] support_list supported datatypes of the op
 * return true: check pass
 *        false: check failed
 */
bool CheckInputDataType(const Operator& op, const std::string& input_name, | |||||
const std::vector<ge::DataType>& support_list); | |||||
/*
 * check the datatype and shape of inputs
 * param[in] op the operator
 * param[in] inputTensorMap the map of input name to supported datatypes
 * return true: check pass
 *        false: check failed
 */
bool CheckInputDtypeAndShape(const Operator& op, const std::map<std::string, std::vector<DataType>>& inputTensorMap); | |||||
/*
 * infer shape of two inputs and one output with broadcast
 * param[in] op op desc supplied by ge
 * param[in] inputName1 first input name
 * param[in] inputName2 second input name
 * param[in] outputName output name
 * return SUCCESS: infer success
 *        FAILED: infer failed, e.g. unsupported broadcast input shape
 */
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2, | |||||
const string& output_name); | |||||
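// Illustrative use inside an op's InferShape registration. This is only a
// sketch: the function name "AddInferShape" and the tensor names "x1"/"x2"/"y"
// are hypothetical, not part of this header.
//   IMPLEMT_COMMON_INFERFUNC(AddInferShape) {
//     if (InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) {
//       return GRAPH_SUCCESS;
//     }
//     return GRAPH_FAILED;
//   }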
/*
 * infer shape of two inputs and one output with broadcast
 * param[in] op op desc supplied by ge
 * param[in] inputName1 first input name
 * param[in] inputName2 second input name
 * param[in] outputName output name
 * param[out] is_dynamic whether the inferred output shape is dynamic
 * return SUCCESS: infer success
 *        FAILED: infer failed, e.g. unsupported broadcast input shape
 */
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2, | |||||
const string& output_name, bool& is_dynamic); | |||||
bool InferShapeRangeTwoInOneOutBroadcase(Operator& op, const string& input_name1, const string& input_name2, | |||||
const string& output_name); | |||||
bool CheckInputDataType(const Operator& op, std::string* data_type, const std::string& input_name, | |||||
const std::vector<ge::DataType>& supportList); | |||||
bool CheckTwoInputDtypeSame(const Operator& op, const string& input_name1, const string& input_name2); | |||||
bool CheckInputDtypeSame(const Operator& op, std::vector<std::string>& input_tensors); | |||||
bool CheckInputsShapeDtypeSame(const Operator& op, const std::vector<std::string>& input_names); | |||||
bool GetConstValue(const ge::Operator& op, const std::string& key_name, float& attr_value); | |||||
bool GetConstValue(const ge::Operator& op, const std::string& key_name, int64_t& attr_value); | |||||
bool GetConstValue(const ge::Operator& op, const std::string& key_name, bool& attr_value); | |||||
bool GetConstValue(const ge::Operator& op, const std::string& key_name, std::vector<int32_t>& attr_value); | |||||
/** | |||||
* Get int type const value from tensor data | |||||
* @param [in] data const tensor data | |||||
* @param [in] data_type DT_INT8, DT_INT16, DT_INT32, DT_INT64 | |||||
* @param [out] const_values const int values | |||||
* @return true:success, false:failed. | |||||
*/ | |||||
bool GetConstIntData(const Tensor& data, DataType data_type, std::vector<int64_t>& const_values); | |||||
bool GetConstValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, | |||||
std::vector<int64_t>& const_data); | |||||
bool GetConstValue(const Operator& op, const GeTensorPtr& const_tensor, const DataType& dtype, | |||||
std::vector<int64_t>& const_data); | |||||
bool GetScalerValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, std::int64_t& const_data); | |||||
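// Illustrative use when an op's shape inference consumes a const input.
// A sketch only: the input name "axes" is hypothetical.
//   Tensor axes_tensor;
//   if (op.GetInputConstData("axes", axes_tensor) == GRAPH_SUCCESS) {
//     DataType dtype = op.GetInputDesc("axes").GetDataType();
//     std::vector<int64_t> axes_values;
//     (void)GetConstValue(op, axes_tensor, dtype, axes_values);
//   }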
/*
 * Check whether the dtype and format of inputs [inputNumBeg, inputNumEnd) are supported in supportList
 * param[in] op op desc supplied by ge
 * param[in] inputNumBeg first input index to check, in [0, N]
 * param[in] inputNumEnd one past the last input index to check
 * param[in] supportList supported values, of type ge::DataType or ge::Format
 * return true: check pass
 *        false: check failed
 */
template <typename T> | |||||
bool CheckSimilarInputDtypeAndFormat(const Operator& op, std::size_t inputNumBeg, std::size_t inputNumEnd, | |||||
const std::vector<T>& supportList) { | |||||
for (std::size_t i = inputNumBeg; i < inputNumEnd; i++) { | |||||
if (std::is_same<typename std::decay<T>::type, ge::DataType>::value) { | |||||
ge::DataType inType = op.GetInputDesc(i).GetDataType(); | |||||
const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType); | |||||
if (findDtype == supportList.end()) { | |||||
return false; | |||||
} | |||||
} else if (std::is_same<typename std::decay<T>::type, ge::Format>::value) { | |||||
ge::Format inType = op.GetInputDesc(i).GetFormat(); | |||||
const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType); | |||||
if (findDtype == supportList.end()) { | |||||
return false; | |||||
} | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
/*
 * Check whether the dtype and format of the given inputs are supported in supportList
 * param[in] op op desc supplied by ge
 * param[in] indexNeedCheck input indices to check
 * param[in] supportList supported values, of type ge::DataType or ge::Format
 * return true: check pass
 *        false: check failed
 */
template <typename T> | |||||
bool CheckSimilarInputDtypeAndFormat(const Operator& op, const std::vector<std::size_t>& indexNeedCheck, | |||||
const std::vector<T>& supportList) { | |||||
for (auto i : indexNeedCheck) { | |||||
if (std::is_same<typename std::decay<T>::type, ge::DataType>::value) { | |||||
ge::DataType inType = op.GetInputDesc(i).GetDataType(); | |||||
const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType); | |||||
if (findDtype == supportList.end()) { | |||||
return false; | |||||
} | |||||
} else if (std::is_same<typename std::decay<T>::type, ge::Format>::value) { | |||||
ge::Format inType = op.GetInputDesc(i).GetFormat(); | |||||
const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType); | |||||
if (findDtype == supportList.end()) { | |||||
return false; | |||||
} | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
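// Illustrative call of the templates above. A sketch only; it assumes an op
// whose inputs 0..2 must all be float16 or float32:
//   std::vector<ge::DataType> support_list = {ge::DT_FLOAT16, ge::DT_FLOAT};
//   bool ok = CheckSimilarInputDtypeAndFormat(op, 0, 3, support_list);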
/*
 * get const attrs
 * param[in] op op desc supplied by ge
 * param[in] attrNameList names of the attrs to get
 * param[out] attrVec fetched attr values
 * return true: get success
 *        false: get failed
 */
template <typename T> | |||||
bool GetConstAttr(const Operator& op, const std::vector<std::string>& attrNameList, std::vector<T>& attrVec) { | |||||
T value; | |||||
for (auto name : attrNameList) { | |||||
if (op.GetAttr(name, value) != ge::GRAPH_SUCCESS) { | |||||
return false; | |||||
} | |||||
attrVec.push_back(value); | |||||
} | |||||
return true; | |||||
} | |||||
/*
 * get const attr lists
 * param[in] op op desc supplied by ge
 * param[in] attrNameList names of the attrs to get
 * param[out] attrListVec fetched attr value lists
 * return true: get success
 *        false: get failed
 */
template <typename T> | |||||
bool GetConstAttr(const Operator& op, const std::vector<std::string>& attrNameList, | |||||
std::vector<std::vector<T>>& attrListVec) { | |||||
for (auto name : attrNameList) { | |||||
std::vector<T> valueList; | |||||
if (op.GetAttr(name, valueList) != ge::GRAPH_SUCCESS) { | |||||
return false; | |||||
} | |||||
attrListVec.push_back(valueList); | |||||
} | |||||
return true; | |||||
} | |||||
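// Illustrative calls of the two GetConstAttr overloads. A sketch only; the
// attr names "stride_h"/"stride_w"/"ksize"/"pads" are hypothetical:
//   std::vector<int64_t> strides;
//   bool ok = GetConstAttr(op, {"stride_h", "stride_w"}, strides);
//   std::vector<std::vector<int64_t>> windows;
//   ok = ok && GetConstAttr(op, {"ksize", "pads"}, windows);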
std::string to_string(const vector<int64_t>& shape); | |||||
std::string to_string(const ge::Shape& shape); | |||||
std::string to_string(const ge::GeShape& shape); | |||||
std::string to_string(const vector<pair<int64_t, int64_t>>& ranges); | |||||
class DynamicShapeInfer { | |||||
public: | |||||
std::map<std::string, Format> map_format; | |||||
std::map<std::string, DataType> map_dtype; | |||||
std::map<std::string, uint32_t> inputs; | |||||
std::map<std::string, uint32_t> outputs; | |||||
Operator& op; | |||||
OpDescPtr& op_desc; | |||||
DynamicShapeInfer(Operator& op_v, OpDescPtr& opDesc_v) : op(op_v), op_desc(opDesc_v) { | |||||
} | |||||
bool CatchFormatAndShape(); | |||||
bool UpdateFormatAndShape(); | |||||
~DynamicShapeInfer() { | |||||
UpdateFormatAndShape(); | |||||
} | |||||
}; | |||||
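// Typical RAII use of DynamicShapeInfer, as a sketch: capture the runtime
// format/shape up front, run inference, and rely on the destructor to write
// the results back.
//   DynamicShapeInfer infer(op, op_desc);
//   infer.CatchFormatAndShape();
//   /* ... run the op's shape inference ... */
//   // UpdateFormatAndShape() is invoked again in ~DynamicShapeInfer()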
#define PREPARE_DYNAMIC_SHAPE(depends_names) auto op_desc = OpDescUtils::GetOpDescFromOperator(op);\ | |||||
do { \ | |||||
if (!depends_names.empty()) { \ | |||||
op_desc->SetOpInferDepends(depends_names); \ | |||||
} \ | |||||
} while(0) | |||||
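// Illustrative expansion site for the macro above. A sketch only; "axes" is a
// hypothetical const input that the op's shape inference depends on:
//   std::vector<std::string> depends = {"axes"};
//   PREPARE_DYNAMIC_SHAPE(depends);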
bool IsEmptyTensor(const std::vector<int64_t>& dims); | |||||
bool IsUnknownRank(const Operator& op, const std::string& tensor_name, const std::string& types = "input"); | |||||
bool IsUnknownRankShape(const std::vector<int64_t>& shape_vec); | |||||
bool IsUnKnownShape(const std::vector<int64_t>& shape_vec); | |||||
bool IsUnknownShape(const Operator& op, const std::string& tensor_name, const std::string& types = "input"); | |||||
bool IsUnknownVec(std::vector<int64_t>& shape_vec); | |||||
bool IsUnknown(const std::vector<int64_t>& shape_vec); | |||||
void MakeUpShapeRange(const std::vector<int64_t>& shape, std::vector<std::pair<int64_t, int64_t>>& range); | |||||
std::string DataTypeToStringDesc(const ge::DataType& dataType); | |||||
bool OneInOneOutDynamicInfer(const Operator& op, | |||||
const std::string& input_name, | |||||
const std::vector<std::string>& output_name_list); | |||||
bool TwoInOneOutDynamicInferNoBroadcast(Operator& op, | |||||
const string& input1_name, | |||||
const string& input2_name, | |||||
const std::vector<string>& output_name_list); | |||||
void FixShapeRangeWithDims(const std::vector<int64_t>& dims, | |||||
std::vector<int64_t>& shape_1, | |||||
std::vector<int64_t>& shape_2, | |||||
std::vector<std::pair<int64_t, int64_t>>& range_1, | |||||
std::vector<std::pair<int64_t, int64_t>>& range_2); | |||||
bool SetScalarOutputDesc(const string& input, | |||||
const string& output, | |||||
OpDescPtr op_desc, | |||||
GeShape& output_shape); | |||||
namespace array_ops { | |||||
bool CheckInt64MulOverflow(int64_t a, int64_t b); | |||||
void ReshapeRangeInfer(const Operator &op, const std::vector<std::pair<int64_t, int64_t>>& x_range, | |||||
int64_t& range_max); | |||||
void ReshapeRangeInfer(const Operator &op, const std::vector<std::pair<int64_t, int64_t>>& x_range, | |||||
std::vector<std::pair<int64_t, int64_t>>& y_range, GeShape& output_shape); | |||||
}  // namespace array_ops
} // namespace ge | |||||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_ |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph_assertion.h" |
@@ -0,0 +1,34 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H | |||||
#define GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H | |||||
/*
 * Graph assertion macros: compare graph node count and node attributes.
 * The bodies are currently empty placeholders.
 */
#define ASSERT_GRAPH_EQUAL(g1,g2) \ | |||||
do { \ | |||||
} while (0) | |||||
#define ASSERT_GRAPH_CORRECT(g) \ | |||||
do { \ | |||||
} while (0) | |||||
#define ASSERT_GRAPH_SHAPE_CONTINOUS(g) \ | |||||
do { \ | |||||
} while (0) | |||||
#endif // GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H |
@@ -0,0 +1,48 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph_builder_utils.h" | |||||
#include "inc/external/graph/operator.h" | |||||
#include "inc/external/graph/operator_factory.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
namespace ge { | |||||
namespace st { | |||||
NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, | |||||
DataType data_type, std::vector<int64_t> shape) { | |||||
auto tensor_desc = std::make_shared<GeTensorDesc>(); | |||||
tensor_desc->SetShape(GeShape(std::move(shape))); | |||||
tensor_desc->SetFormat(format); | |||||
tensor_desc->SetDataType(data_type); | |||||
auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
for (int i = 0; i < in_cnt; ++i) { | |||||
op_desc->AddInputDesc(tensor_desc->Clone()); | |||||
} | |||||
for (int i = 0; i < out_cnt; ++i) { | |||||
op_desc->AddOutputDesc(tensor_desc->Clone()); | |||||
} | |||||
return graph_->AddNode(op_desc); | |||||
} | |||||
void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { | |||||
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); | |||||
} | |||||
void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { | |||||
GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); | |||||
} | |||||
} // namespace st | |||||
} // namespace ge |
@@ -0,0 +1,53 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||||
#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||||
#include <string> | |||||
#include <vector> | |||||
#include "graph/compute_graph.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/graph.h" | |||||
#include "graph/node.h" | |||||
namespace ge { | |||||
namespace st { | |||||
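// Helper for STs that builds a ComputeGraph node-by-node. A minimal usage
// sketch (the node names and "Data"/"Relu" type strings are illustrative; see
// the framework test for a full example):
//   ComputeGraphBuilder builder("g");
//   auto data = builder.AddNode("data", "Data", 1, 1);
//   auto relu = builder.AddNode("relu", "Relu", 1, 1);
//   builder.AddDataEdge(data, 0, relu, 0);
//   Graph graph = builder.GetGraph();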
class ComputeGraphBuilder { | |||||
public: | |||||
explicit ComputeGraphBuilder(const std::string &name) { graph_ = std::make_shared<ComputeGraph>(name); } | |||||
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||||
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | |||||
std::vector<int64_t> shape = {1, 1, 224, 224}); | |||||
void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); | |||||
void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); | |||||
ComputeGraphPtr GetComputeGraph() { | |||||
graph_->TopologicalSorting(); | |||||
return graph_; | |||||
} | |||||
Graph GetGraph() { | |||||
graph_->TopologicalSorting(); | |||||
return GraphUtils::CreateGraphFromComputeGraph(graph_); | |||||
} | |||||
private: | |||||
ComputeGraphPtr graph_; | |||||
}; | |||||
} // namespace st | |||||
} // namespace ge | |||||
#endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H |
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "tensor_builder_utils.h" |
@@ -0,0 +1,22 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
class tensor_builder_utils {}; | |||||
#endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H |
@@ -0,0 +1,15 @@ | |||||
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++") | |||||
add_executable(graph_engine_test ${SOURCES}) | |||||
target_include_directories(graph_engine_test | |||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
) | |||||
set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 11) | |||||
target_link_libraries(graph_engine_test PRIVATE gtest gtest_main framework) | |||||
include(CTest) | |||||
enable_testing() | |||||
add_test(NAME test COMMAND graph_engine_test) |
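# Note: CTest is enabled above, so after building, the suite can be run with,
# e.g., `ctest --output-on-failure` from the build directory (invocation shown
# for illustration only).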
@@ -0,0 +1,58 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <map> | |||||
#include "external/ge/ge_api.h" | |||||
#include "framework/common/types.h" | |||||
#include "framework.h" | |||||
#include "framework/utils/builder/graph_builder_utils.h" | |||||
using namespace std; | |||||
using namespace ge; | |||||
class FrameworkTest : public testing::Test { | |||||
protected: | |||||
void SetUp() override {
// initialize GE before each test
map<AscendString, AscendString> options;
auto ret = ge::GEInitialize(options);
EXPECT_EQ(ret, SUCCESS);
}
void TearDown() override {}
}; | |||||
TEST_F(FrameworkTest, test_framework_dummy) { | |||||
// build graph | |||||
st::ComputeGraphBuilder graphBuilder("g1"); | |||||
auto data1 = graphBuilder.AddNode("data1",DATA,1,1); | |||||
auto data2 = graphBuilder.AddNode("data2",DATA,1,1); | |||||
auto add = graphBuilder.AddNode("add",ADD,2,1); | |||||
graphBuilder.AddDataEdge(data1, 0, add,0); | |||||
graphBuilder.AddDataEdge(data2, 0, add,1); | |||||
Graph graph = graphBuilder.GetGraph(); | |||||
// new session & add graph | |||||
map<AscendString, AscendString> options; | |||||
Session session(options); | |||||
auto ret = session.AddGraph(1, graph, options); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
// build input tensor | |||||
std::vector<InputTensorInfo> inputs; | |||||
// build graph through session
ret = session.BuildGraph(1, inputs); | |||||
// TODO check result | |||||
} |