modified: CMakeLists.txt
modified: build.sh
modified: ge/ge_runtime/runtime_model.cc
modified: ge/ge_runtime/task/aicpu_task.cc
modified: ge/ge_runtime/task/hccl_task.cc
modified: ge/ge_runtime/task/label_goto_task.cc
modified: ge/ge_runtime/task/label_switch_task.cc
new file: tests/st/CMakeLists.txt
new file: tests/st/cmake/graphengine.cmake
new file: tests/st/framework/CMakeLists.txt
new file: tests/st/framework/framework.cc
new file: tests/st/framework/framework.h
new file: tests/st/framework/stub_engine/CMakeLists.txt
new file: tests/st/framework/stub_engine/common/constant/constant.h
new file: tests/st/framework/stub_engine/engine/stub_engine.cc
new file: tests/st/framework/stub_engine/engine/stub_engine.h
new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc
new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h
new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc
new file: tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h
new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc
new file: tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h
new file: tests/st/framework/stub_engine/ops_kernel_store/op/op.h
new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc
new file: tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h
new file: tests/st/framework/stub_engine/proto/task.proto
new file: tests/st/framework/stub_op_proto/array_ops.cc
new file: tests/st/framework/stub_op_proto/array_ops.h
new file: tests/st/framework/stub_op_proto/control_flow_ops.cc
new file: tests/st/framework/stub_op_proto/control_flow_ops.h
new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.cc
new file: tests/st/framework/stub_op_proto/elewise_calculation_ops.h
new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc
new file: tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h
new file: tests/st/framework/stub_op_proto/util/axis_util.cc
new file: tests/st/framework/stub_op_proto/util/axis_util.h
new file: tests/st/framework/stub_op_proto/util/common_shape_fns.cc
new file: tests/st/framework/stub_op_proto/util/common_shape_fns.h
new file: tests/st/framework/stub_op_proto/util/error_code.h
new file: tests/st/framework/stub_op_proto/util/error_util.cc
new file: tests/st/framework/stub_op_proto/util/error_util.h
new file: tests/st/framework/stub_op_proto/util/op_common_util.h
new file: tests/st/framework/stub_op_proto/util/op_log.h
new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc
new file: tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h
new file: tests/st/framework/stub_op_proto/util/util.cc
new file: tests/st/framework/stub_op_proto/util/util.h
new file: tests/st/framework/utils/assertion/graph_assertion.cc
new file: tests/st/framework/utils/assertion/graph_assertion.h
new file: tests/st/framework/utils/builder/graph_builder_utils.cc
new file: tests/st/framework/utils/builder/graph_builder_utils.h
new file: tests/st/framework/utils/builder/tensor_builder_utils.cc
new file: tests/st/framework/utils/builder/tensor_builder_utils.h
new file: tests/st/test.cc
new file: tests/st/testcase/CMakeLists.txt
new file: tests/st/testcase/test_framework_dummy.cc
@@ -39,7 +39,7 @@ set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR}
option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE)
if (ENABLE_OPEN_SRC)
if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST)
set(HI_PYTHON python3)
include(cmake/external_libs/protobuf_shared.cmake)
@@ -51,118 +51,132 @@ if (ENABLE_OPEN_SRC)
include(cmake/external_libs/json.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)
# if D_LINK_PATH is set in environment variables, search libraries in given path
if(DEFINED ENV{D_LINK_PATH})
# D_LINK_PATH is set
set(GE_LIB_PATH $ENV{D_LINK_PATH})
set(GE_SYS_ARCH "")
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
# x86 ubuntu
set(GE_SYS_ARCH "x86_64")
elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
# arm euleros
set(GE_SYS_ARCH "aarch64")
add_subdirectory(tests)
else ()
if (ENABLE_OPEN_SRC)
set(HI_PYTHON python3)
include(cmake/external_libs/protobuf_shared.cmake)
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/gflags.cmake)
include(cmake/external_libs/gtest.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/external_libs/json.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)
# if D_LINK_PATH is set in environment variables, search libraries in given path
if(DEFINED ENV{D_LINK_PATH})
# D_LINK_PATH is set
set(GE_LIB_PATH $ENV{D_LINK_PATH})
set(GE_SYS_ARCH "")
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
# x86 ubuntu
set(GE_SYS_ARCH "x86_64")
elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
# arm euleros
set(GE_SYS_ARCH "aarch64")
else()
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
endif()
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
set(STATIC_ACL_LIB ${GE_LIB_PATH})
find_module(slog libalog.so ${GE_LIB_PATH})
find_module(static_mmpa libmmpa.a ${GE_LIB_PATH})
find_module(msprofiler_ext libmsprofiler.a ${GE_LIB_PATH})
find_module(hccl libhccl.so ${GE_LIB_PATH})
find_module(adump_server libadump_server.a ${GE_LIB_PATH})
find_module(runtime libruntime.so ${GE_LIB_PATH})
find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH})
find_module(resource libresource.so ${GE_LIB_PATH})
find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH})
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
else()
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
endif()
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
set(STATIC_ACL_LIB ${GE_LIB_PATH})
find_module(slog libalog.so ${GE_LIB_PATH})
find_module(static_mmpa libmmpa.a ${GE_LIB_PATH})
find_module(msprofiler_ext libmsprofiler.a ${GE_LIB_PATH})
find_module(hccl libhccl.so ${GE_LIB_PATH})
find_module(adump_server libadump_server.a ${GE_LIB_PATH})
find_module(runtime libruntime.so ${GE_LIB_PATH})
find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH})
find_module(resource libresource.so ${GE_LIB_PATH})
find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH})
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
elseif(ENABLE_GE_COV OR ENABLE_GE_UT)
add_subdirectory(tests)
else()
find_module(slog libalog.so ${ASCEND_ATC_DIR})
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
if(PLATFORM STREQUAL "train")
find_module(slog libalog.so ${ASCEND_ATC_DIR})
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
if(PLATFORM STREQUAL "train")
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
if(PRODUCT STREQUAL "flr3")
message(FATAL_ERROR "This platform is not supported in train mode, build terminated")
endif()
elseif(PLATFORM STREQUAL "inference")
find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
find_module(runtime libruntime.so ${ASCEND_ACL_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
if(PRODUCT STREQUAL "flr3")
elseif(PRODUCT STREQUAL "flr1")
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
elseif(PRODUCT STREQUAL "flr2")
# flr2 ascend_hal_stub limsprof ?
else()
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
endif()
elseif(PLATFORM STREQUAL "all")
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
if(PRODUCT STREQUAL "flr3")
message(FATAL_ERROR "This platform is not supported in train mode, build terminated")
endif()
elseif(PLATFORM STREQUAL "inference")
find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
find_module(runtime libruntime.so ${ASCEND_ACL_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
if(PRODUCT STREQUAL "flr3")
elseif(PRODUCT STREQUAL "flr1")
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
elseif(PRODUCT STREQUAL "flr2")
# flr2 ascend_hal_stub limsprof ?
else()
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
endif()
elseif(PLATFORM STREQUAL "all")
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
else()
message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
endif()
endif()
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser)
set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
add_subdirectory(metadef)
add_subdirectory(parser)
#add_subdirectory(metadef/graph)
#add_subdirectory(metadef/register)
elseif (ENABLE_D OR ENABLE_ACL)
# compiling with MindSpore
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/external_libs/json.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)
# common libraries
find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
if (ENABLE_D)
# training
find_module(runtime libruntime.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(register libregister.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
endif ()
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
elseif(ENABLE_MS_TESTCASES)
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)
# common libraries
find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser)
set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
add_subdirectory(metadef)
add_subdirectory(parser)
#add_subdirectory(metadef/graph)
#add_subdirectory(metadef/register)
elseif (ENABLE_D OR ENABLE_ACL)
# compiling with MindSpore
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/external_libs/json.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)
# common libraries
find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
if (ENABLE_D)
# training
find_module(runtime libruntime.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(register libregister.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
endif ()
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
elseif(ENABLE_MS_TESTCASES)
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)
# common libraries
find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
else()
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef)
set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser)
set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
endif()
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
else()
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef)
set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser)
set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
endif()
add_subdirectory(ge)
add_subdirectory(ge)
endif () |
@@ -177,6 +177,9 @@ build_graphengine()
elif [ "X$ENABLE_GE_UT" = "Xon" ]
then
TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest"
elif [ "X$ENABLE_GE_ST" = "Xon" ]
then
TARGET="graph_engine_test"
elif [ "X$MINDSPORE_MODE" = "Xon" ]
then
TARGET="ge_common graph"
@@ -234,6 +237,27 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then
genhtml coverage.info
fi
if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then
#prepare engine & opskernel so
mkdir -p ${OUTPUT_PATH}/plugin/nnengine
mkdir -p ${OUTPUT_PATH}/plugin/nnengine/ge_config
mkdir -p ${OUTPUT_PATH}/plugin/opskernel
cp ${BUILD_PATH}/tests/st/libnnengine.so ${OUTPUT_PATH}/plugin/nnengine
cp ${BUILD_PATH}/engine_conf.json ${OUTPUT_PATH}/plugin/nnengine/ge_config
cp ${BUILD_PATH}/tests/st/libhost_cpu_engine.so ${OUTPUT_PATH}/plugin/opskernel
#prepare st execution bin
cp ${BUILD_PATH}/tests/st/testcase/graph_engine_test ${OUTPUT_PATH}
#execute st testcase
RUN_TEST_CASE=${OUTPUT_PATH}/graph_engine_test && ${RUN_TEST_CASE}
if [[ "$?" -ne 0 ]]; then
echo "!!! ST FAILED, PLEASE CHECK YOUR CHANGES !!!"
echo -e "\033[31m${RUN_TEST_CASE}\033[0m"
exit 1;
fi
# remove plugin
rm -rf ${OUTPUT_PATH}/plugin
fi
# generate output package in tar form, including ut/st libraries/executables
generate_package()
{
@@ -337,7 +361,7 @@ generate_package()
fi
}
if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then
if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$ENABLE_GE_ST" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then
generate_package
elif [ "X$MINDSPORE_MODE" = "Xon" ]
then
@@ -25,6 +25,7 @@
#include "framework/common/op/op_parser_util.h"
#include "graph/types.h"
#include "task/task_factory.h"
#include "ge/common/math/math_util.h"
namespace ge {
namespace model_runner {
@@ -500,7 +501,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
}
uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data()));
uint32_t head_len = kOffsetUnit * kStringHeadElems;
if (ge::CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) {
if (CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) {
GELOGE(FAILED, "Shape size is invalid");
return false;
}
@@ -83,7 +83,7 @@ bool AicpuTask::Distribute() {
return false;
}
GELOGI("ext info size:", ext_size);
GELOGI("ext info size: %u", ext_size);
aicpu_param_head.extInfoLength = ext_size;
aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
}
@@ -130,7 +130,7 @@ bool HcclTask::SetSecondaryStream() {
Status ret;
std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_);
if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) {
GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %ld.", rt_model_handle_, master_stream_id);
GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %u.", rt_model_handle_, master_stream_id);
ret = CreateStream(hccl_secondary_stream_num, master_stream_id);
if (!ret) {
GELOGE(RT_FAILED, "Create hccl stream failed.");
@@ -189,7 +189,7 @@ bool HcclTask::SetSecondaryStream() {
}
GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num);
} else {
GELOGI("Need to create secondary stream for %s with new mainstream %ld.", task_info_->op_name().c_str(),
GELOGI("Need to create secondary stream for %s with new mainstream %u.", task_info_->op_name().c_str(),
master_stream_id);
ret = CreateStream(hccl_secondary_stream_num, master_stream_id);
if (!ret) {
@@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() {
return false;
}
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;
@@ -69,7 +69,7 @@ bool LabelSwitchTask::Distribute() {
return false;
}
label_list[i] = all_label_resource_[label_index];
GELOGI("Case %zu: label id %zu.", i, label_index);
GELOGI("Case %zu: label id %zu.", i, (size_t)label_index);
}
uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size();
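Every GELOGI change in the hunks above is a printf-style format fix: a missing conversion specifier (aicpu_task), %ld applied to a uint32_t (hccl_task), and %zu applied to a narrower index without a cast (label_switch_task). A minimal standalone sketch of the pattern, using plain printf rather than GELOGI and with assumed variable types, shows why the ST build's -Werror=format flag rejects the old forms:

#include <cstdio>
#include <cstdint>

int main() {
  uint32_t ext_size = 64;         // GELOGI("ext info size:", ext_size) silently dropped this argument
  uint32_t master_stream_id = 3;  // %ld expects long; a uint32_t matches %u
  uint16_t label_index = 7;       // hypothetical narrow index type
  size_t i = 0;
  printf("ext info size: %u\n", ext_size);
  printf("mainstream %u\n", master_stream_id);
  // %zu expects exactly size_t, so a narrower type needs an explicit cast:
  printf("Case %zu: label id %zu.\n", i, (size_t)label_index);
  return 0;
}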
@@ -0,0 +1,6 @@
project(graphengine_st)
include(cmake/graphengine.cmake)
add_subdirectory(framework)
add_subdirectory(testcase) |
@@ -0,0 +1,249 @@
# ---- Test coverage ----
if (ENABLE_GE_COV)
set(COVERAGE_COMPILER_FLAGS "-g --coverage -fprofile-arcs -fPIC -O0 -ftest-coverage")
set(CMAKE_CXX_FLAGS "${COVERAGE_COMPILER_FLAGS}")
endif()
# ---- Proto generate ----
file(GLOB_RECURSE PROTO_FILES CONFIGURE_DEPENDS "${GE_CODE_DIR}/metadef/proto/*.proto")
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_FILES})
# ---- File glob by group ----
file(GLOB_RECURSE METADEF_SRCS CONFIGURE_DEPENDS
"${GE_CODE_DIR}/metadef/graph/*.cc"
"${GE_CODE_DIR}/metadef/register/*.cc"
"${GE_CODE_DIR}/metadef/register/*.cpp"
"${GE_CODE_DIR}/metadef/ops/*.cc"
"${GE_CODE_DIR}/metadef/third_party/transformer/src/*.cc"
)
file(GLOB_RECURSE METADEF_REGISTER_SRCS CONFIGURE_DEPENDS
"${GE_CODE_DIR}/metadef/register/*.cc"
"${GE_CODE_DIR}/metadef/register/*.cpp"
)
file(GLOB_RECURSE PARSER_SRCS CONFIGURE_DEPENDS
"${GE_CODE_DIR}/parser/parser/common/*.cc"
)
file(GLOB_RECURSE LOCAL_ENGINE_SRC CONFIGURE_DEPENDS
"${GE_CODE_DIR}/ge/ge_local_engine/*.cc"
)
file(GLOB_RECURSE HOST_ENGINE_SRC CONFIGURE_DEPENDS
"${GE_CODE_DIR}/ge/host_cpu_engine/*.cc"
)
file(GLOB_RECURSE NN_ENGINE_SRC CONFIGURE_DEPENDS
"${GE_CODE_DIR}/ge/plugin/*.cc"
)
file(GLOB_RECURSE OFFLINE_SRC CONFIGURE_DEPENDS
"${GE_CODE_DIR}/ge/offline/*.cc"
)
file(GLOB_RECURSE GE_SRCS CONFIGURE_DEPENDS
"${GE_CODE_DIR}/ge/*.cc"
)
list(REMOVE_ITEM GE_SRCS ${LOCAL_ENGINE_SRC} ${HOST_ENGINE_SRC} ${NN_ENGINE_SRC} ${OFFLINE_SRC})
list(APPEND INCLUDE_DIRECTORIES
"${CMAKE_CURRENT_SOURCE_DIR}"
"${GE_CODE_DIR}"
"${GE_CODE_DIR}/inc"
"${GE_CODE_DIR}/metadef/inc"
"${GE_CODE_DIR}/ge"
"${GE_CODE_DIR}/ge/inc"
"${GE_CODE_DIR}/ge/ir_build"
"${GE_CODE_DIR}/metadef"
"${GE_CODE_DIR}/metadef/graph"
"${GE_CODE_DIR}/inc/external"
"${GE_CODE_DIR}/inc/framework/common"
"${GE_CODE_DIR}/metadef/inc/external"
"${GE_CODE_DIR}/metadef/inc/external/graph"
"${GE_CODE_DIR}/metadef/inc/graph"
"${GE_CODE_DIR}/inc/framework"
"${GE_CODE_DIR}/metadef/inc/common"
"${GE_CODE_DIR}/metadef/third_party"
"${GE_CODE_DIR}/metadef/third_party/transformer/inc"
"${GE_CODE_DIR}/parser"
"${GE_CODE_DIR}/parser/parser"
"${GE_CODE_DIR}/third_party/fwkacllib/inc"
"${GE_CODE_DIR}/third_party/fwkacllib/inc/cce"
"${GE_CODE_DIR}/third_party/fwkacllib/inc/ops"
"${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain"
"${GE_CODE_DIR}/tests/ut/ge"
"${GE_CODE_DIR}/tests/ut/common"
"${CMAKE_BINARY_DIR}"
"${CMAKE_BINARY_DIR}/proto/ge"
"${CMAKE_BINARY_DIR}/proto/ge/proto"
)
list(APPEND STUB_LIBS
c_sec
slog_stub
cce_ge_stub
runtime_stub
profiler_stub
#mmpa_stub
hccl_stub
error_manager_stub
ascend_protobuf
json
)
# ---- Target : Local engine ----
add_library(localengine STATIC ${LOCAL_ENGINE_SRC} ${METADEF_REGISTER_SRCS})
target_include_directories(localengine
PUBLIC
"${INCLUDE_DIRECTORIES}"
"${GE_CODE_DIR}/ge/ge_local_engine"
)
target_compile_definitions(localengine PRIVATE
google=ascend_private
)
target_compile_options(localengine PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)
target_link_libraries(localengine PUBLIC
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov
)
set_target_properties(localengine PROPERTIES CXX_STANDARD 11)
# ---- Target : metadef graph ----
add_library(metadef_graph STATIC ${METADEF_SRCS} ${PROTO_SRCS} ${PROTO_HDRS})
target_include_directories(metadef_graph
PUBLIC
"${INCLUDE_DIRECTORIES}"
)
target_compile_definitions(metadef_graph PRIVATE
google=ascend_private
FMK_SUPPORT_DUMP
)
target_compile_options(metadef_graph PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)
target_link_libraries(metadef_graph PUBLIC
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov
)
set_target_properties(metadef_graph PROPERTIES CXX_STANDARD 11)
# ---- Target : Host engine ----
add_library(host_cpu_engine SHARED ${HOST_ENGINE_SRC} ${PROTO_HDRS})
target_include_directories(host_cpu_engine
PUBLIC
"${INCLUDE_DIRECTORIES}"
"${GE_CODE_DIR}/ge/host_cpu_engine"
)
target_compile_definitions(host_cpu_engine PRIVATE
google=ascend_private
FMK_SUPPORT_DUMP
)
target_compile_options(host_cpu_engine PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)
target_link_libraries(host_cpu_engine PUBLIC
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} metadef_graph -lmmpa -L${CMAKE_BINARY_DIR}/tests/depends/mmpa -lrt -ldl -lpthread -lgcov
)
set_target_properties(host_cpu_engine PROPERTIES CXX_STANDARD 11)
# ---- Target : engine plugin ----
#
add_library(nnengine SHARED ${NN_ENGINE_SRC})
target_include_directories(nnengine
PUBLIC
"${INCLUDE_DIRECTORIES}"
"${GE_CODE_DIR}/ge/plugin/engine"
)
target_compile_definitions(nnengine PRIVATE
google=ascend_private
)
target_compile_options(nnengine PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)
target_link_libraries(nnengine PUBLIC
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov
)
set_target_properties(nnengine PROPERTIES CXX_STANDARD 11)
# Target: engine_conf
add_custom_target(
engine_conf.json ALL
DEPENDS ${CMAKE_BINARY_DIR}/engine_conf.json
)
add_custom_command(
OUTPUT ${CMAKE_BINARY_DIR}/engine_conf.json
COMMAND cp ${GE_CODE_DIR}/ge/engine_manager/engine_conf.json ${CMAKE_BINARY_DIR}/
)
# Target: optimizer priority
add_custom_target(
optimizer_priority.pbtxt ALL
DEPENDS ${CMAKE_BINARY_DIR}/optimizer_priority.pbtxt
)
add_custom_command(
OUTPUT ${CMAKE_BINARY_DIR}/optimizer_priority.pbtxt
COMMAND cp ${GE_CODE_DIR}/ge/opskernel_manager/optimizer_priority.pbtxt ${CMAKE_BINARY_DIR}/
)
# ---- Target : Graph engine ----
add_library(graphengine STATIC ${PARSER_SRCS} ${GE_SRCS} ${PROTO_HDRS})
target_include_directories(graphengine
PUBLIC
"${INCLUDE_DIRECTORIES}"
"${GE_CODE_DIR}/ge/host_cpu_engine"
)
target_compile_definitions(graphengine PRIVATE
google=ascend_private
FMK_SUPPORT_DUMP
)
target_compile_options(graphengine PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)
target_link_libraries(graphengine PUBLIC
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS}
metadef_graph
localengine
host_cpu_engine
nnengine
mmpa -L${GE_CODE_DIR}/third_party/prebuild/x86_64 -lrt -ldl -lpthread -lgcov
)
set_target_properties(graphengine PROPERTIES CXX_STANDARD 11)
add_dependencies(graphengine engine_conf.json optimizer_priority.pbtxt) |
@@ -0,0 +1,16 @@
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++")
#todo
file(GLOB_RECURSE stub_engine CONFIGURE_DEPENDS
"stub_engine/*.cc"
)
list(REMOVE_ITEM SOURCES ${stub_engine})
add_library(framework STATIC ${SOURCES})
target_include_directories(framework
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}
)
set_target_properties(framework PROPERTIES CXX_STANDARD 11)
target_link_libraries(framework PUBLIC graphengine) |
@@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdlib.h>
#include "framework.h"
namespace ge {
namespace st {
Status Framework::SetUp() {
  // Placeholder: ST environment preparation (stub engines, configs) goes here.
  return SUCCESS;
}
Status Framework::TearDown() {
  // Placeholder: ST environment cleanup goes here.
  return SUCCESS;
}
} // namespace st
} // namespace ge |
@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GRAPHENGINE_LLT_ST_FRAMEWORK_H_
#define GRAPHENGINE_LLT_ST_FRAMEWORK_H_
#include <string>
#include "common/ge_inner_error_codes.h"
namespace ge {
namespace st {
class Framework {
public:
Framework() = default;
Status SetUp();
Status TearDown();
};
} // namespace st
} // namespace ge
#endif // GRAPHENGINE_LLT_ST_FRAMEWORK_H_ |
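The Framework class above is only the harness surface: one SetUp/TearDown pair wrapped around the actual test cases. The entry point lives in tests/st/test.cc, whose contents are not shown in this diff, so the following gtest-style main is a hedged sketch of how the pieces could be wired together (all wiring here is an assumption):

#include <gtest/gtest.h>
#include "framework.h"

int main(int argc, char **argv) {
  ge::st::Framework framework;
  if (framework.SetUp() != ge::SUCCESS) {  // stage stub engines, configs, ...
    return -1;
  }
  testing::InitGoogleTest(&argc, argv);
  int ret = RUN_ALL_TESTS();  // runs the cases from tests/st/testcase
  framework.TearDown();       // clean up plugins and temporary artifacts
  return ret;
}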
@@ -0,0 +1,259 @@
set(PROTO_LIST
"${METADEF_DIR}/proto/task.proto"
)
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST})
set(SRC_LIST
"engine/stub_engine.cc"
"ops_kernel_store/host_cpu_ops_kernel_info.cc"
"ops_kernel_store/op/op_factory.cc"
"ops_kernel_store/op/host_op.cc"
)
set(CPU_OPS_KERNEL_LIST
"ops_kernel_store/host_cpu_ops_kernel_builder.cc"
)
############ libfe.so ############
add_library(fe SHARED ${SRC_LIST} ${PROTO_HDRS})
target_compile_options(fe PRIVATE
-Werror
-fno-common
-fvisibility=hidden
)
target_compile_definitions(fe PRIVATE
google=ascend_private
FUNC_VISIBILITY
)
target_include_directories(fe PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
${GE_CODE_DIR}/third_party/fwkacllib/inc
)
target_link_options(fe PRIVATE
-Wl,-Bsymbolic
)
target_link_libraries(fe PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
ascend_protobuf
c_sec
graph
slog
-Wl,--as-needed
)
############ atcstub/libfe.so ############
add_library(atc_fe SHARED ${SRC_LIST} ${PROTO_ATCSTUB_HDRS})
target_compile_options(atc_fe PRIVATE
-Werror
-fno-common
-fvisibility=hidden
)
target_compile_definitions(atc_fe PRIVATE
google=ascend_private
FUNC_VISIBILITY
)
target_include_directories(atc_fe PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge_atcstub
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
${GE_CODE_DIR}/third_party/fwkacllib/inc
)
target_link_options(atc_fe PRIVATE
-Wl,-Bsymbolic
)
target_link_libraries(atc_fe PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
ascend_protobuf
c_sec
graph
slog
-Wl,--as-needed
)
set_target_properties(atc_fe PROPERTIES
OUTPUT_NAME fe
LIBRARY_OUTPUT_DIRECTORY atclib
)
############ libhost_cpu_opskernel_builder.so ############
add_library(host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST})
target_compile_options(host_cpu_opskernel_builder PRIVATE
-Werror
-fno-common
-fvisibility=hidden
)
target_compile_definitions(host_cpu_opskernel_builder PRIVATE
google=ascend_private
FUNC_VISIBILITY
)
target_include_directories(host_cpu_opskernel_builder PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
${GE_CODE_DIR}/third_party/fwkacllib/inc
)
target_link_options(host_cpu_opskernel_builder PRIVATE
-Wl,-Bsymbolic
)
target_link_libraries(host_cpu_opskernel_builder PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
ascend_protobuf
c_sec
slog
graph
register
-Wl,--as-needed
)
############ atclib/libhost_cpu_opskernel_builder.so ############
add_library(atc_host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST})
target_compile_options(atc_host_cpu_opskernel_builder PRIVATE
-Werror
-fno-common
-fvisibility=hidden
)
target_compile_definitions(atc_host_cpu_opskernel_builder PRIVATE
google=ascend_private
FUNC_VISIBILITY
)
target_include_directories(atc_host_cpu_opskernel_builder PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
${GE_CODE_DIR}/third_party/fwkacllib/inc
)
target_link_options(atc_host_cpu_opskernel_builder PRIVATE
-Wl,-Bsymbolic
)
target_link_libraries(atc_host_cpu_opskernel_builder PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
ascend_protobuf
c_sec
slog
graph
register
-Wl,--as-needed
)
set_target_properties(atc_host_cpu_opskernel_builder PROPERTIES
OUTPUT_NAME host_cpu_opskernel_builder
LIBRARY_OUTPUT_DIRECTORY atclib
)
############ libhost_cpu_opskernel_builder.a ############
add_library(host_cpu_opskernel_builder_static STATIC ${CPU_OPS_KERNEL_LIST})
target_compile_options(host_cpu_opskernel_builder_static PRIVATE
-Werror
-fno-common
-fvisibility=hidden
)
target_compile_definitions(host_cpu_opskernel_builder_static PRIVATE
google=ascend_private
LOG_CPP
FUNC_VISIBILITY
)
target_include_directories(host_cpu_opskernel_builder_static PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
${GE_CODE_DIR}/third_party/fwkacllib/inc
)
target_link_libraries(host_cpu_opskernel_builder_static PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf
c_sec
)
############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)
install(TARGETS fe host_cpu_opskernel_builder OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)
install(TARGETS atc_fe atc_host_cpu_opskernel_builder OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/atclib
) |
@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_
#define GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_
#include <string>
namespace ge {
namespace host_cpu {
// engine name
const char kHostCpuEngineName[] = "DNN_VM_HOST_CPU";
const char kHostCpuOpKernelLibName[] = "DNN_VM_HOST_CPU_OP_STORE";
} // namespace host_cpu
} // namespace ge
#endif // GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ |
@@ -0,0 +1,74 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "stub_engine.h"
#include <map>
#include <memory>
#include <string>
#include <securec.h>
#include "framework/common/debug/ge_log.h"
#include "common/ge/ge_util.h"
#include "host_cpu_engine/common/constant/constant.h"
#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h"
namespace ge {
StubEngine &StubEngine::Instance() {
static StubEngine instance;
return instance;
}
Status StubEngine::Initialize(const std::map<std::string, std::string> &options) {
if (ops_kernel_store_ == nullptr) {
ops_kernel_store_ = MakeShared<host_cpu::HostCpuOpsKernelInfoStore>();
if (ops_kernel_store_ == nullptr) {
GELOGE(FAILED, "[Create][StubEngine]Make HostCpuOpsKernelInfoStore failed.");
REPORT_INNER_ERROR("E19999", "StubEngine::Initialize failed for new HostCpuOpsKernelInfoStore.");
return FAILED;
}
}
return SUCCESS;
}
void StubEngine::GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map) {
if (ops_kernel_store_ != nullptr) {
// add buildin opsKernel to opsKernelInfoMap
ops_kernel_map[host_cpu::kHostCpuOpKernelLibName] = ops_kernel_store_;
}
}
void StubEngine::GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &) {
// no optimizer for host cpu engine
}
Status StubEngine::Finalize() {
ops_kernel_store_ = nullptr;
return SUCCESS;
}
} // namespace ge
ge::Status Initialize(const std::map<std::string, std::string> &options) {
return ge::StubEngine::Instance().Initialize(options);
}
void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map) {
ge::StubEngine::Instance().GetOpsKernelInfoStores(ops_kernel_map);
}
void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers) {
ge::StubEngine::Instance().GetGraphOptimizerObjs(graph_optimizers);
}
ge::Status Finalize() { return ge::StubEngine::Instance().Finalize(); }
@@ -0,0 +1,126 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_
#define GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_
#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif
#include <map>
#include <vector>
#include <memory>
#include <string>
#include "common/opskernel/ops_kernel_info_store.h"
#include "common/optimizer/graph_optimizer.h"
using OpsKernelInfoStorePtr = std::shared_ptr<ge::OpsKernelInfoStore>;
using GraphOptimizerPtr = std::shared_ptr<ge::GraphOptimizer>;
namespace ge {
namespace {
std::vector<std::string> extern_engine_name_vec = {"fe", "rts_engine", "aicpu_ascend_engine", "aicpu_tf_engine"};
} // namespace
/**
* Stub engine.
* Used for the ops which execute on host.
*/
class GE_FUNC_VISIBILITY StubEngine {
public:
/**
* get StubEngine instance.
* @return StubEngine instance.
*/
static StubEngine &Instance();
virtual ~StubEngine() = default;
/**
* When GE starts, GE will invoke this interface
* @return The status whether initialize successfully
*/
Status Initialize(const std::map<std::string, std::string> &options);
/**
* After the initialize, GE will invoke this interface
* to get the Ops kernel Store.
* @param ops_kernel_map The host cpu's ops kernel info
*/
void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map);
/**
* After the initialize, GE will invoke this interface
* to get the Graph Optimizer.
* @param graph_optimizers The host cpu's Graph Optimizer objs
*/
void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers);
/**
* When the graph is finished, GE will invoke this interface
* @return The status whether finalize successfully
*/
Status Finalize();
StubEngine(const StubEngine &) = delete;
StubEngine(StubEngine &&) = delete;
StubEngine &operator=(const StubEngine &) = delete;
StubEngine &operator=(StubEngine &&) = delete;
private:
StubEngine() = default;
OpsKernelInfoStorePtr ops_kernel_store_ = nullptr;
};
} // namespace ge
extern "C" {
/**
* When GE starts, GE will invoke this interface
* @return The status whether initialize successfully
*/
GE_FUNC_VISIBILITY ge::Status Initialize(const std::map<std::string, std::string> &options);
/**
* After the initialize, GE will invoke this interface to get the Ops kernel Store
* @param ops_kernel_map The host cpu's ops kernel info
*/
GE_FUNC_VISIBILITY void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map);
/**
* After the initialize, GE will invoke this interface to get the Graph Optimizer
* @param graph_optimizers The host cpu's Graph Optimizer objs
*/
GE_FUNC_VISIBILITY void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers);
/**
* When the graph is finished, GE will invoke this interface
* @return The status whether finalize successfully
*/
GE_FUNC_VISIBILITY ge::Status Finalize();
}
#endif // GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_ |
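This extern "C" block is the whole contract between GE and an engine plugin: GE dlopen()s the shared library and resolves exactly these four symbols, which is why build.sh stages the .so files under plugin/ before running the ST binary. A hedged sketch of the equivalent manual binding with POSIX dlopen/dlsym follows; the library path and the Status-as-uint32_t mapping are assumptions:

#include <dlfcn.h>
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>

using Status = uint32_t;  // assumed layout of ge::Status
using InitFn = Status (*)(const std::map<std::string, std::string> &);
using FiniFn = Status (*)();

int main() {
  // Illustrative path; build.sh copies the stub engines under plugin/.
  void *so = dlopen("plugin/opskernel/libfe.so", RTLD_NOW | RTLD_LOCAL);
  if (so == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  auto init = reinterpret_cast<InitFn>(dlsym(so, "Initialize"));
  auto fini = reinterpret_cast<FiniFn>(dlsym(so, "Finalize"));
  if (init == nullptr || fini == nullptr) {
    std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
    return 1;
  }
  std::map<std::string, std::string> options;  // engine options passed by GE
  Status ret = init(options);
  // GetOpsKernelInfoStores / GetGraphOptimizerObjs would be resolved and
  // called the same way before any graph is built.
  fini();
  dlclose(so);
  return ret == 0 ? 0 : 2;
}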
@@ -0,0 +1,114 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "host_cpu_ops_kernel_builder.h" | |||
#include <memory> | |||
#include "common/ge_inner_error_codes.h" | |||
#include "ge/ge_api_types.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "graph/utils/type_utils.h" | |||
#include <securec.h> | |||
#include "framework/common/debug/ge_log.h" | |||
#include "host_cpu_engine/common/constant/constant.h" | |||
#include "register/ops_kernel_builder_registry.h" | |||
namespace ge { | |||
namespace host_cpu { | |||
REGISTER_OPS_KERNEL_BUILDER(kHostCpuOpKernelLibName, HostCpuOpsKernelBuilder); | |||
Status HostCpuOpsKernelBuilder::Finalize() { | |||
return SUCCESS; | |||
} | |||
Status HostCpuOpsKernelBuilder::Initialize(const map<std::string, std::string> &options) { | |||
return SUCCESS; | |||
} | |||
Status HostCpuOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) { | |||
OpDescPtr op_desc = ge_node.GetOpDesc(); | |||
if (op_desc == nullptr) { | |||
GELOGE(FAILED, "[Get][OpDesc]CalcOpRunningParam failed, as op desc is null"); | |||
REPORT_INNER_ERROR("E19999", "GetOpDesc failed."); | |||
return FAILED; | |||
} | |||
bool is_shape_unknown = false; | |||
if (NodeUtils::GetNodeUnknownShapeStatus(ge_node, is_shape_unknown) == GRAPH_SUCCESS) { | |||
if (is_shape_unknown) { | |||
GELOGI("op:%s is unknown shape, does not need to calc output size.", ge_node.GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
} | |||
const string name = ge_node.GetName(); | |||
const string type = ge_node.GetType(); | |||
GELOGD("Calc op[%s:%s] running param, output size=%zu.", name.c_str(), type.c_str(), op_desc->GetOutputsSize()); | |||
for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { | |||
GeTensorDesc output_tensor = op_desc->GetOutputDesc(static_cast<uint32_t>(i)); | |||
Format format = output_tensor.GetFormat(); | |||
DataType data_type = output_tensor.GetDataType(); | |||
int64_t mem_size = 0; | |||
// If mem size has been set, there is no need to reset it. | |||
if ((TensorUtils::GetSize(output_tensor, mem_size) == GRAPH_SUCCESS) && (mem_size > 0)) { | |||
GELOGD("Op[%s:%s] out[%zu] mem size has been set, no need calc again, format=%s, data_type=%s, mem_size=%ld.", | |||
name.c_str(), type.c_str(), i, TypeUtils::FormatToSerialString(format).c_str(), | |||
TypeUtils::DataTypeToSerialString(data_type).c_str(), mem_size); | |||
continue; | |||
} | |||
int64_t output_mem_size = 0; | |||
GeShape output_shape = output_tensor.GetShape(); | |||
if ((TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_mem_size) != GRAPH_SUCCESS) || | |||
(output_mem_size < 0)) { | |||
GELOGE(FAILED, | |||
"[Calc][TensorMemSize] fail for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.", | |||
name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), | |||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
REPORT_CALL_ERROR("E19999", | |||
"CalcTensorMemSize failed for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.", | |||
name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), | |||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
return FAILED; | |||
} | |||
GELOGI("Calc op[%s:%s] out[%zu] mem size is %ld, format=%s, data_type=%s.", | |||
name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), | |||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
TensorUtils::SetSize(output_tensor, output_mem_size); | |||
if (op_desc->UpdateOutputDesc(static_cast<uint32_t>(i), output_tensor) != GRAPH_SUCCESS) { | |||
GELOGE(FAILED, | |||
"[Update][OutputDesc] fail for op[%s:%s] out[%zu] desc , format=%s, data_type=%s.", | |||
name.c_str(), type.c_str(), i, | |||
TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
REPORT_CALL_ERROR("E19999", "UpdateOutputDesc failed for op[%s:%s] out[%zu] desc , format=%s, data_type=%s.", | |||
name.c_str(), type.c_str(), i, | |||
TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
return FAILED; | |||
} | |||
} | |||
GELOGD("Calc op[%s:%s] running param success.", name.c_str(), type.c_str()); | |||
return SUCCESS; | |||
} | |||
Status HostCpuOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, vector<domi::TaskDef> &tasks) { | |||
// no need to generate device task | |||
return SUCCESS; | |||
} | |||
} // namespace host_cpu | |||
} // namespace ge |
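CalcOpRunningParam above ultimately reduces to element count times element size, with overflow protection. A simplified standalone equivalent of TensorUtils::CalcTensorMemSize (ignoring format-specific padding, which the real helper handles) might look like this:

```cpp
#include <cstdint>
#include <vector>

// Simplified stand-in for TensorUtils::CalcTensorMemSize: multiply all dims,
// then multiply by the element size, guarding against overflow and unknown
// (negative) dims. The real helper also applies format-specific padding.
bool CalcTensorMemSizeSimple(const std::vector<int64_t> &dims,
                             int64_t bytes_per_elem, int64_t &mem_size) {
  if (bytes_per_elem <= 0) {
    return false;
  }
  int64_t count = 1;
  for (const int64_t dim : dims) {
    if (dim < 0) {
      return false;  // unknown dim: size cannot be computed statically
    }
    if (dim != 0 && count > INT64_MAX / dim) {
      return false;  // overflow
    }
    count *= dim;
  }
  if (count != 0 && count > INT64_MAX / bytes_per_elem) {
    return false;  // overflow
  }
  mem_size = count * bytes_per_elem;  // e.g. shape [2,3,4] of float32 -> 96 bytes
  return true;
}
```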
@@ -0,0 +1,51 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_ | |||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_ | |||
#if defined(_MSC_VER) | |||
#ifdef FUNC_VISIBILITY | |||
#define GE_FUNC_VISIBILITY __declspec(dllexport) | |||
#else | |||
#define GE_FUNC_VISIBILITY | |||
#endif | |||
#else | |||
#ifdef FUNC_VISIBILITY | |||
#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||
#else | |||
#define GE_FUNC_VISIBILITY | |||
#endif | |||
#endif | |||
#include "common/opskernel/ops_kernel_builder.h" | |||
namespace ge { | |||
namespace host_cpu { | |||
class GE_FUNC_VISIBILITY HostCpuOpsKernelBuilder : public OpsKernelBuilder { | |||
public: | |||
Status Initialize(const map<std::string, std::string> &options) override; | |||
Status Finalize() override; | |||
Status CalcOpRunningParam(Node &node) override; | |||
Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override; | |||
}; | |||
} // namespace host_cpu | |||
} // namespace ge | |||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_ |
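The GE_FUNC_VISIBILITY boilerplate repeated in these headers only takes effect when the target is built with FUNC_VISIBILITY defined. A minimal sketch of the assumed build contract (the compile flags are illustrative, not taken from this repo's CMake):

```cpp
// Assumed build contract:
//   g++ -shared -fPIC -fvisibility=hidden -DFUNC_VISIBILITY engine.cc -o libengine.so
// With -fvisibility=hidden every symbol is hidden by default, and only the
// declarations annotated with the visibility macro are exported for dlsym.
#ifdef FUNC_VISIBILITY
#define EXAMPLE_API __attribute__((visibility("default")))
#else
#define EXAMPLE_API
#endif

static int InternalHelper() { return 1; }                     // never visible outside the .so
EXAMPLE_API int ExportedEntry() { return InternalHelper(); }  // exported when FUNC_VISIBILITY is set
```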
@@ -0,0 +1,67 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h" | |||
#include <memory> | |||
#include "common/constant/constant.h" | |||
#include "ge/ge_api_types.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "graph/utils/tensor_utils.h" | |||
#include "graph/utils/type_utils.h" | |||
#include "op/op_factory.h" | |||
namespace ge { | |||
namespace host_cpu { | |||
using domi::TaskDef; | |||
using std::map; | |||
using std::string; | |||
using std::vector; | |||
Status HostCpuOpsKernelInfoStore::Initialize(const map<string, string> &options) { | |||
GELOGI("HostCpuOpsKernelInfoStore init start."); | |||
OpInfo default_op_info = {.engine = kHostCpuEngineName, | |||
.opKernelLib = kHostCpuOpKernelLibName, | |||
.computeCost = 0, | |||
.flagPartial = false, | |||
.flagAsync = false, | |||
.isAtomic = false}; | |||
// Init op_info_map_ | |||
auto all_ops = OpFactory::Instance().GetAllOps(); | |||
for (auto &op : all_ops) { | |||
op_info_map_[op] = default_op_info; | |||
} | |||
GELOGI("HostCpuOpsKernelInfoStore inited success. op num=%zu", op_info_map_.size()); | |||
return SUCCESS; | |||
} | |||
Status HostCpuOpsKernelInfoStore::Finalize() { | |||
op_info_map_.clear(); | |||
return SUCCESS; | |||
} | |||
void HostCpuOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, OpInfo> &infos) const { infos = op_info_map_; } | |||
bool HostCpuOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const { | |||
if (op_desc == nullptr) { | |||
return false; | |||
} | |||
return op_info_map_.count(op_desc->GetType()) > 0; | |||
} | |||
} // namespace host_cpu | |||
} // namespace ge |
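A short usage sketch of the store above, assuming the GraphEngine headers are available (error handling trimmed):

```cpp
// Probe the store: initialize, list registered op types, finalize.
ge::Status ProbeHostCpuStore() {
  ge::host_cpu::HostCpuOpsKernelInfoStore store;
  if (store.Initialize({}) != ge::SUCCESS) {
    return ge::FAILED;
  }
  std::map<std::string, ge::OpInfo> infos;
  store.GetAllOpsKernelInfo(infos);  // one entry per type registered in OpFactory
  GELOGI("host cpu store supports %zu op types", infos.size());
  return store.Finalize();
}
```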
@@ -0,0 +1,86 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ | |||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ | |||
#if defined(_MSC_VER) | |||
#ifdef FUNC_VISIBILITY | |||
#define GE_FUNC_VISIBILITY __declspec(dllexport) | |||
#else | |||
#define GE_FUNC_VISIBILITY | |||
#endif | |||
#else | |||
#ifdef FUNC_VISIBILITY | |||
#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||
#else | |||
#define GE_FUNC_VISIBILITY | |||
#endif | |||
#endif | |||
#include <map> | |||
#include <string> | |||
#include <vector> | |||
#include "common/opskernel/ops_kernel_info_store.h" | |||
namespace ge { | |||
namespace host_cpu { | |||
class GE_FUNC_VISIBILITY HostCpuOpsKernelInfoStore : public OpsKernelInfoStore { | |||
public: | |||
HostCpuOpsKernelInfoStore() = default; | |||
~HostCpuOpsKernelInfoStore() override = default; | |||
/** | |||
* Initialize related resources of the host CPU kernel info store | |||
* @return status indicating whether this operation succeeded | |||
*/ | |||
Status Initialize(const std::map<std::string, std::string> &options) override; | |||
/** | |||
* Release related resources of the host cpu kernel info store | |||
* @return status indicating whether this operation succeeded | |||
*/ | |||
Status Finalize() override; | |||
/** | |||
* Check to see if an operator is fully supported or partially supported. | |||
* @param op_desc OpDesc information | |||
* @param reason unsupported reason | |||
* @return bool value indicating whether the operator is fully supported | |||
*/ | |||
bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override; | |||
/** | |||
* Returns the full operator information. | |||
* @param infos reference of a map, | |||
* contain operator's name and detailed information | |||
*/ | |||
void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override; | |||
HostCpuOpsKernelInfoStore(const HostCpuOpsKernelInfoStore &ops_kernel_store) = delete; | |||
HostCpuOpsKernelInfoStore(HostCpuOpsKernelInfoStore &&ops_kernel_store) = delete; | |||
HostCpuOpsKernelInfoStore &operator=(const HostCpuOpsKernelInfoStore &ops_kernel_store) = delete; | |||
HostCpuOpsKernelInfoStore &operator=(HostCpuOpsKernelInfoStore &&ops_kernel_store) = delete; | |||
private: | |||
// store op name and OpInfo key-value pair | |||
std::map<std::string, ge::OpInfo> op_info_map_; | |||
}; | |||
} // namespace host_cpu | |||
} // namespace ge | |||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ |
@@ -0,0 +1,40 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "host_cpu_engine/ops_kernel_store/op/host_op.h" | |||
#include "framework/common/util.h" | |||
#include "host_cpu_engine/ops_kernel_store/op/op_factory.h" | |||
namespace ge { | |||
namespace host_cpu { | |||
Status HostOp::Run() { | |||
// host op is a stub with no computation; simply report success | |||
return SUCCESS; | |||
} | |||
REGISTER_OP_CREATOR(NoOp, HostOp); | |||
REGISTER_OP_CREATOR(Variable, HostOp); | |||
REGISTER_OP_CREATOR(Constant, HostOp); | |||
REGISTER_OP_CREATOR(Assign, HostOp); | |||
REGISTER_OP_CREATOR(RandomUniform, HostOp); | |||
REGISTER_OP_CREATOR(Add, HostOp); | |||
REGISTER_OP_CREATOR(Mul, HostOp); | |||
REGISTER_OP_CREATOR(ConcatV2, HostOp); | |||
REGISTER_OP_CREATOR(Data, HostOp); | |||
REGISTER_OP_CREATOR(Fill, HostOp); | |||
REGISTER_OP_CREATOR(NetOutput, HostOp); | |||
} // namespace host_cpu | |||
} // namespace ge |
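Each REGISTER_OP_CREATOR line above installs a creator into the OpFactory singleton at static-initialization time, so the engine can dispatch by node type. A hypothetical dispatch path (the function name is invented for illustration; the factory calls are the ones defined in this change):

```cpp
// Hypothetical dispatch sketch: look up the creator for the node's type and
// run the resulting stub op. CreateOp returns nullptr for unregistered types.
ge::Status RunHostNode(const ge::Node &node, ge::RunContext &run_context) {
  auto op = ge::host_cpu::OpFactory::Instance().CreateOp(node, run_context);
  if (op == nullptr) {
    return ge::FAILED;  // type was never registered via REGISTER_OP_CREATOR
  }
  return op->Run();  // HostOp::Run is a stub that simply returns SUCCESS
}
```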
@@ -0,0 +1,36 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ | |||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ | |||
#include "host_cpu_engine/ops_kernel_store/op/op.h" | |||
namespace ge { | |||
namespace host_cpu { | |||
class GE_FUNC_VISIBILITY HostOp : public Op { | |||
public: | |||
HostOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} | |||
~HostOp() override = default; | |||
HostOp &operator=(const HostOp &op) = delete; | |||
HostOp(const HostOp &op) = delete; | |||
Status Run() override; | |||
}; | |||
} // namespace host_cpu | |||
} // namespace ge | |||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ |
@@ -0,0 +1,45 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ | |||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ | |||
#include <climits> | |||
#include <string> | |||
#include <vector> | |||
#include "common/ge_inner_error_codes.h" | |||
#include "common/opskernel/ops_kernel_info_types.h" | |||
#include "graph/node.h" | |||
namespace ge { | |||
namespace host_cpu { | |||
/** | |||
* The base class for all ops. | |||
*/ | |||
class GE_FUNC_VISIBILITY Op { | |||
public: | |||
Op(const Node &node, RunContext &run_context) : run_context_(run_context), node_(node) {} | |||
virtual ~Op() = default; | |||
virtual Status Run() = 0; | |||
protected: | |||
const RunContext &run_context_; | |||
const Node &node_; | |||
}; | |||
} // namespace host_cpu | |||
} // namespace ge | |||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ |
@@ -0,0 +1,55 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "host_cpu_engine/ops_kernel_store/op/op_factory.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "common/ge_inner_error_codes.h" | |||
#include "graph/op_desc.h" | |||
namespace ge { | |||
namespace host_cpu { | |||
OpFactory &OpFactory::Instance() { | |||
static OpFactory instance; | |||
return instance; | |||
} | |||
std::shared_ptr<Op> OpFactory::CreateOp(const Node &node, RunContext &run_context) { | |||
auto iter = op_creator_map_.find(node.GetType()); | |||
if (iter != op_creator_map_.end()) { | |||
return iter->second(node, run_context); | |||
} | |||
GELOGE(FAILED, "Not supported OP, type = %s, name = %s", node.GetType().c_str(), node.GetName().c_str()); | |||
return nullptr; | |||
} | |||
void OpFactory::RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func) { | |||
if (func == nullptr) { | |||
GELOGW("Func is NULL."); | |||
return; | |||
} | |||
auto iter = op_creator_map_.find(type); | |||
if (iter != op_creator_map_.end()) { | |||
GELOGW("%s creator already exist", type.c_str()); | |||
return; | |||
} | |||
op_creator_map_[type] = func; | |||
all_ops_.emplace_back(type); | |||
} | |||
} // namespace host_cpu | |||
} // namespace ge |
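RegisterCreator can also be invoked directly with a lambda, which is all the REGISTER_OP_CREATOR macro (defined in the header that follows) does under the hood. A hypothetical direct registration, with "MyOp" as an invented type name:

```cpp
// Direct registration sketch; "MyOp" is illustrative. A second registration
// for the same type is ignored with a GELOGW warning, per RegisterCreator above.
void RegisterMyOp() {
  ge::host_cpu::OpFactory::Instance().RegisterCreator(
      "MyOp",
      [](const ge::Node &node, ge::RunContext &ctx) -> std::shared_ptr<ge::host_cpu::Op> {
        return ge::MakeShared<ge::host_cpu::HostOp>(node, ctx);
      });
}
```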
@@ -0,0 +1,94 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ | |||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ | |||
#include <functional> | |||
#include <map> | |||
#include <memory> | |||
#include <string> | |||
#include <vector> | |||
#include "common/ge/ge_util.h" | |||
#include "host_cpu_engine/ops_kernel_store/op/op.h" | |||
namespace ge { | |||
namespace host_cpu { | |||
using OP_CREATOR_FUNC = std::function<std::shared_ptr<Op>(const Node &, RunContext &)>; | |||
/** | |||
* Manages all ops and supports op creation. | |||
*/ | |||
class GE_FUNC_VISIBILITY OpFactory { | |||
public: | |||
static OpFactory &Instance(); | |||
/** | |||
* @brief create Op. | |||
* @param [in] node reference of the node | |||
* @param [in] run_context run context | |||
* @return non-null shared_ptr on success | |||
* @return nullptr on failure | |||
*/ | |||
std::shared_ptr<Op> CreateOp(const Node &node, RunContext &run_context); | |||
/** | |||
* @brief Register Op create function. | |||
* @param [in] type Op type | |||
* @param [in] func Op create func | |||
*/ | |||
void RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func); | |||
const std::vector<std::string> &GetAllOps() const { return all_ops_; } | |||
bool CheckSupported(const std::string &type) { return op_creator_map_.find(type) != op_creator_map_.end(); } | |||
OpFactory(const OpFactory &) = delete; | |||
OpFactory &operator=(const OpFactory &) = delete; | |||
OpFactory(OpFactory &&) = delete; | |||
OpFactory &operator=(OpFactory &&) = delete; | |||
private: | |||
OpFactory() = default; | |||
~OpFactory() = default; | |||
// the op creator function map | |||
std::map<std::string, OP_CREATOR_FUNC> op_creator_map_; | |||
std::vector<std::string> all_ops_; | |||
}; | |||
class GE_FUNC_VISIBILITY OpRegistrar { | |||
public: | |||
OpRegistrar(const std::string &type, const OP_CREATOR_FUNC &func) { | |||
OpFactory::Instance().RegisterCreator(type, func); | |||
} | |||
~OpRegistrar() = default; | |||
OpRegistrar(const OpRegistrar &) = delete; | |||
OpRegistrar &operator=(const OpRegistrar &) = delete; | |||
OpRegistrar(OpRegistrar &&) = delete; | |||
OpRegistrar &operator=(OpRegistrar &&) = delete; | |||
}; | |||
#define REGISTER_OP_CREATOR(type, clazz) \ | |||
std::shared_ptr<Op> Creator_##type##Op(const Node &node, RunContext &run_context) { \ | |||
return MakeShared<clazz>(node, run_context); \ | |||
} \ | |||
OpRegistrar g_##type##Op_creator(#type, Creator_##type##Op) | |||
} // namespace host_cpu | |||
} // namespace ge | |||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ |
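For clarity, this is roughly what REGISTER_OP_CREATOR(NoOp, HostOp) from host_op.cc expands to:

```cpp
// Expansion of REGISTER_OP_CREATOR(NoOp, HostOp), per the macro above.
std::shared_ptr<Op> Creator_NoOpOp(const Node &node, RunContext &run_context) {
  return MakeShared<HostOp>(node, run_context);
}
OpRegistrar g_NoOpOp_creator("NoOp", Creator_NoOpOp);
// The global OpRegistrar constructor runs before main(), which is why merely
// linking host_op.cc is enough to populate the factory.
```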
@@ -0,0 +1,179 @@ | |||
/* Copyright 2021. Huawei Technologies Co., Ltd. All rights reserved. | |||
* | |||
* This program is free software; you can redistribute it and/or modify | |||
* it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License. | |||
* | |||
* This program is distributed in the hope that it will be useful, | |||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | |||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |||
* Apache License for more details at | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
*/ | |||
syntax = "proto3"; | |||
package domi; | |||
message ModelTaskDef { | |||
string version = 1; | |||
map<string, string> attr = 9; // Extended field | |||
repeated TaskDef task = 10; | |||
uint64 memory_size = 11; | |||
uint32 stream_num = 12; | |||
uint32 event_num = 13; | |||
uint64 weight_size = 14; | |||
repeated bytes op = 15; // input/output opdef in bytes | |||
uint64 base_addr = 16; // base addr | |||
uint64 weight_addr = 17; // weight addr | |||
uint32 batch_num = 18; | |||
} | |||
message TaskDef { | |||
uint32 id = 1; | |||
uint32 type = 2; | |||
uint32 stream_id = 10; | |||
uint32 event_id = 11; | |||
KernelDef kernel = 20; | |||
KernelExDef kernel_ex = 21; | |||
KernelHcclDef kernel_hccl = 25; | |||
EventExDef event_ex = 26; | |||
LogTimeStampDef log_timestamp = 28; | |||
uint32 label_id = 30; | |||
MemcpyAsyncDef memcpy_async = 31; | |||
StreamSwitchDef stream_switch = 32; | |||
StreamActiveDef stream_active = 33; | |||
bytes private_def = 34; | |||
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future | |||
StreamSwitchNDef stream_switch_n = 36; | |||
LabelSetDef label_set = 37; | |||
LabelGotoExDef label_goto_ex = 38; | |||
LabelSwitchByIndexDef label_switch_by_index = 39; | |||
KernelDefWithHandle kernel_with_handle = 40; | |||
} | |||
message KernelDef { | |||
KernelContext context = 1; | |||
string stub_func = 10; | |||
uint32 block_dim = 11; | |||
uint32 args_size = 12; | |||
bytes args = 13; | |||
bytes sm_desc = 14; | |||
bytes flowtable = 15; | |||
string so_name = 16; | |||
string kernel_name = 17; | |||
bytes kernel_ext_info = 18; | |||
uint32 kernel_ext_info_size = 19; | |||
} | |||
message KernelDefWithHandle { | |||
KernelContext context = 1; | |||
uint64 handle = 10; | |||
string dev_func = 11; | |||
uint32 block_dim = 12; | |||
uint32 args_size = 13; | |||
bytes args = 14; | |||
bytes sm_desc = 15; | |||
string original_kernel_key = 16; | |||
string node_info = 17; | |||
} | |||
message KernelContext { | |||
uint32 kernel_type = 1; | |||
uint32 op_id = 2; // OP type in CCE | |||
uint32 kernel_func_id = 3; | |||
uint32 op_index = 4; // TE/Custom operator | |||
bool is_flowtable = 5; // Identify whether args is a flowtable structure | |||
bytes args_offset = 6; // args offset information | |||
uint32 args_count = 7; // args count | |||
repeated uint32 origin_op_index = 8; | |||
} | |||
message KernelExDef { | |||
uint32 flags = 1; | |||
uint32 op_index = 4; | |||
uint32 args_size = 12; | |||
bytes args = 13; | |||
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput | |||
uint32 task_info_size = 15; | |||
bytes kernel_ext_info = 16; | |||
uint32 kernel_ext_info_size = 17; | |||
} | |||
message KernelHcclDef { | |||
uint32 op_index = 8; | |||
string hccl_type = 9; | |||
} | |||
message EventExDef { | |||
uint32 op_index = 1; | |||
uint32 event_type = 2; | |||
} | |||
message LogTimeStampDef { | |||
uint64 logid = 1; | |||
bool notify = 2; | |||
uint32 flat = 3; | |||
} | |||
message MemcpyAsyncDef { | |||
uint64 dst = 1; | |||
uint64 dst_max = 2; | |||
uint64 src = 3; | |||
uint64 count = 4; | |||
uint32 kind = 5; | |||
uint32 op_index = 6; | |||
} | |||
message StreamSwitchDef { | |||
uint32 op_index = 1; | |||
uint32 true_stream_id = 2; | |||
int64 value = 3; | |||
uint64 value_ptr = 4; | |||
uint32 data_type = 5; | |||
} | |||
message StreamActiveDef { | |||
uint32 op_index = 1; | |||
uint32 active_stream_id = 2; | |||
} | |||
message StreamSwitchNDef { | |||
uint32 op_index = 1; | |||
uint32 size = 2; | |||
repeated int64 target_value = 3; | |||
repeated uint32 true_stream_id = 4; | |||
uint32 element_size = 5; | |||
uint32 data_type = 6; | |||
} | |||
message LabelSetDef { | |||
uint32 op_index = 1; | |||
uint32 label_id = 2; | |||
uint32 model_id = 3; | |||
} | |||
message LabelGotoExDef { | |||
uint32 op_index = 1; | |||
uint32 label_id = 2; | |||
uint32 model_id = 3; | |||
} | |||
message LabelSwitchByIndexDef { | |||
uint32 op_index = 1; | |||
uint32 label_max = 2; | |||
} |
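As a quick illustration of consuming this schema, the following sketch builds a TaskDef with the C++ API that protoc generates (the generated header name task.pb.h and the field values are assumptions):

```cpp
#include "task.pb.h"  // generated by protoc from this file (assumed name)

// Build a kernel task; the numeric type/ids are illustrative placeholders.
domi::TaskDef MakeKernelTask() {
  domi::TaskDef task;
  task.set_id(0);
  task.set_type(1);  // task type codes are defined by the runtime
  task.set_stream_id(0);
  domi::KernelDef *kernel = task.mutable_kernel();
  kernel->set_kernel_name("dummy_kernel");  // illustrative name
  kernel->set_block_dim(1);
  kernel->mutable_context()->set_op_index(0);
  return task;
}
```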
@@ -0,0 +1,711 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file array_ops.h | |||
* \brief | |||
*/ | |||
#ifndef OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ | |||
#define OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ | |||
#include "graph/operator_reg.h" | |||
#include "graph/operator.h" | |||
namespace ge { | |||
/** | |||
*@brief Finds unique elements in a 1D tensor. \n | |||
*@par Inputs: | |||
*x: A 1D tensor. \n | |||
*@par Attributes: | |||
*out_idx: An optional DType from: "int32, int64". Defaults to "int32". \n | |||
*@par Outputs: | |||
*@li y: "x" in the unique output "y". | |||
*@li idx: A tensor the same size as "x". The index of each value of "x". \n | |||
*@attention Constraints: | |||
*Unique runs on the Ascend AI CPU, which delivers poor performance. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Unique. | |||
*/ | |||
REG_OP(Unique) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ | |||
DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ | |||
DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) | |||
.OUTPUT(idx, TensorType({DT_INT32, DT_INT64})) | |||
.ATTR(out_idx, Type, DT_INT32) | |||
.OP_END_FACTORY_REG(Unique) | |||
/** | |||
*@brief Creates a constant tensor from a tensor-like object. This operator is used for inference. | |||
Operator Const has the same definition as operator Constant. \n | |||
*@par Attributes: | |||
*value: Required. The value and type of the resulting tensor; there are no restrictions on the type. \n | |||
*@par Outputs: | |||
*y: A constant tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Const. | |||
*/ | |||
REG_OP(Const) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.ATTR(value, Tensor, Tensor()) | |||
.OP_END_FACTORY_REG(Const) | |||
/** | |||
*@brief Creates a constant tensor for training. \n | |||
*@par Attributes: | |||
*value: Required. The value and type of the resulting tensor; there are no restrictions on the type. \n | |||
*@par Outputs: | |||
*y: The constant tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Const. | |||
*/ | |||
REG_OP(Constant) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.ATTR(value, Tensor, Tensor()) | |||
.OP_END_FACTORY_REG(Constant) | |||
/** | |||
*@brief Returns a copy of the input tensor. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Outputs: | |||
*y: A tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Snapshot. | |||
*/ | |||
REG_OP(Snapshot) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OP_END_FACTORY_REG(Snapshot) | |||
/** | |||
*@brief Gives a guarantee to the runtime that the input tensor is a constant. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Outputs: | |||
*y: The input tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator GuaranteeConst. | |||
*/ | |||
REG_OP(GuaranteeConst) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OP_END_FACTORY_REG(GuaranteeConst) | |||
/** | |||
*@brief Returns the target shape for broadcasting shapes "x1" and "x2". \n | |||
*@par Inputs: | |||
*@li x1: A tensor of type int32 or int64. A shape. | |||
*@li x2: A tensor of the same type as "x1". The other shape. \n | |||
*@par Outputs: | |||
*y: A tensor. The broadcasted shape. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator BroadcastArgs. | |||
*/ | |||
REG_OP(BroadcastArgs) | |||
.INPUT(x1, TensorType({DT_INT32, DT_INT64})) | |||
.INPUT(x2, TensorType({DT_INT32, DT_INT64})) | |||
.OUTPUT(y, TensorType({DT_INT32, DT_INT64})) | |||
.OP_END_FACTORY_REG(BroadcastArgs) | |||
/** | |||
*@brief Outputs its input tensor as is and triggers an error if a gradient is requested. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Attributes: | |||
*message: Will be printed in the error at the attempt to request a gradient. \n | |||
*@par Outputs: | |||
*y: The input tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator PreventGradient. | |||
*/ | |||
REG_OP(PreventGradient) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.ATTR(message, String, "") | |||
.OP_END_FACTORY_REG(PreventGradient) | |||
/** | |||
*@brief Returns the reduction indices for computing gradients of "x1" and "x2" with broadcast. \n | |||
*@par Inputs: | |||
*@li x1: A tensor of type int32 or int64. | |||
*@li x2: A tensor of type int32 or int64. | |||
"x2" has the same type as "x1". \n | |||
*@par Outputs: | |||
*@li y1: A tensor. Reduction indices of "x1". | |||
*@li y2: A tensor. Reduction indices of "x2". \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator BroadcastGradientArgs. | |||
*/ | |||
REG_OP(BroadcastGradientArgs) | |||
.INPUT(x1, TensorType({DT_INT32, DT_INT64})) | |||
.INPUT(x2, TensorType({DT_INT32, DT_INT64})) | |||
.OUTPUT(y1, TensorType({DT_INT32, DT_INT64})) | |||
.OUTPUT(y2, TensorType({DT_INT32, DT_INT64})) | |||
.OP_END_FACTORY_REG(BroadcastGradientArgs) | |||
/** | |||
*@brief Stops gradient computation. None is returned for the node where the gradient computation is stopped. | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Outputs: | |||
*y: The input tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator StopGradient. | |||
*/ | |||
REG_OP(StopGradient) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OP_END_FACTORY_REG(StopGradient) | |||
/** | |||
*@brief Return a tensor with the same shape and contents as input. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Outputs: | |||
*y: A tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Identity. | |||
*/ | |||
REG_OP(Identity) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OP_END_FACTORY_REG(Identity) | |||
/** | |||
*@brief Returns a list of tensors with the same shapes and contents as the input tensors. \n | |||
*@par Inputs: | |||
*x: A list of input tensors. It's a dynamic input. \n | |||
*@par Outputs: | |||
*y: A list of Tensor objects, with the same length as the input tensor list. | |||
It's a dynamic output. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator IdentityN. | |||
*/ | |||
REG_OP(IdentityN) | |||
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OP_END_FACTORY_REG(IdentityN) | |||
/** | |||
*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n | |||
*@par Inputs: | |||
*@li x: A tensor. | |||
*@li axis: The dimension index at which to expand. \n | |||
*@par Outputs: | |||
*y: A tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator ExpandDims. | |||
*/ | |||
REG_OP(ExpandDims) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, | |||
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.INPUT(axis, TensorType({DT_INT32, DT_INT64})) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, | |||
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OP_END_FACTORY_REG(ExpandDims) | |||
/** | |||
*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n | |||
*@par Inputs: | |||
*@li x: Original tensor. | |||
*@li axis: List of ints. \n | |||
*@par Outputs: | |||
*y: Reshape tensor with same data as input. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the Onnx operator Unsqueeze. | |||
*/ | |||
REG_OP(Unsqueeze) | |||
.INPUT(x, TensorType({DT_FLOAT32, DT_INT32, DT_UINT8, DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_FLOAT32, DT_INT32, DT_UINT8, DT_BOOL})) | |||
.ATTR(axes, ListInt, {}) | |||
.OP_END_FACTORY_REG(Unsqueeze) | |||
/** | |||
*@brief Reshapes a tensor. Only the tensor shape is changed, without changing the data. \n | |||
*@par Inputs: | |||
*@li x: A tensor. | |||
*@li shape: A tensor. Defines the shape of the output tensor. \n | |||
*@par Attributes: | |||
*@li axis: An optional int32 or int64. The first dimension to reshape. Defaults to "0". | |||
*@li num_axes: An optional int32 or int64. The extent of the reshape. Defaults to "-1". \n | |||
*@par Outputs: | |||
*y: A tensor. \n | |||
*@par Attention: | |||
*This operator cannot be directly called by the aclopExecute API. \n | |||
*@par Third-party framework compatibility | |||
*@li Compatible with the TensorFlow operator Reshape. | |||
*@li Compatible with the Caffe operator Reshape. | |||
*/ | |||
REG_OP(Reshape) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, | |||
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.INPUT(shape, TensorType({DT_INT32, DT_INT64})) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, | |||
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.ATTR(axis, Int, 0) | |||
.ATTR(num_axes, Int, -1) | |||
.OP_END_FACTORY_REG(Reshape) | |||
/** | |||
*@brief Removes dimensions of size 1 from the shape of a tensor. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Attributes: | |||
*axis: An optional list of int32 or int64. If not specified, squeezes all dimensions of size 1. If specified, only squeezes the dimensions listed. It is an error to squeeze a dimension that is not 1. \n | |||
*@par Outputs: | |||
*y: A tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Squeeze. | |||
*/ | |||
REG_OP(Squeeze) | |||
.INPUT(x, TensorType::ALL()) | |||
.OUTPUT(y, TensorType::ALL()) | |||
.ATTR(axis, ListInt, {}) | |||
.OP_END_FACTORY_REG(Squeeze) | |||
/** | |||
*@brief Returns an integer representing the rank of input tensor. The rank of a tensor is the number of indices required to uniquely select each element of the tensor, that is, the dimension size of the tensor. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Outputs: | |||
*y: A tensor. The rank of input tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Rank. | |||
*/ | |||
REG_OP(Rank) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_INT32})) | |||
.OP_END_FACTORY_REG(Rank) | |||
/** | |||
*@brief Returns the size of a tensor, that is, an integer of the number of elements of the tensor. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Attributes: | |||
*dtype: An optional int32 or int64. The output data type. Defaults to "int32". \n | |||
*@par Outputs: | |||
*y: A tensor. The size of the input tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Size. | |||
*/ | |||
REG_OP(Size) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_INT32, DT_INT64})) | |||
.ATTR(dtype, Int, DT_INT32) | |||
.OP_END_FACTORY_REG(Size) | |||
/** | |||
*@brief Input data for other operators. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Attributes: | |||
*index: Index of the input tensor. The data type must be int32 or int64. | |||
For example, if the net has three data nodes, one should be set to 0, another | |||
to 1, and the last to 2. \n | |||
*@par Outputs: | |||
*y: A tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the Caffe operator Data. | |||
*/ | |||
REG_OP(Data) | |||
.INPUT(x, TensorType::ALL()) | |||
.OUTPUT(y, TensorType::ALL()) | |||
.ATTR(index, Int, 0) | |||
.OP_END_FACTORY_REG(Data) | |||
/** | |||
*@brief Inserts a placeholder for a tensor that will be always fed. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Attributes: | |||
*@li peerIndex: An integer type. The index of the corresponding "End" node it is connected to. | |||
*@li parentId: A string, used to check if the nodes are from the saved parent node. | |||
*@li parentOpType: A string. Op type of the original node. | |||
*@li anchorIndex: An integer, used to check if the node is from the saved anchor. \n | |||
*@par Outputs: | |||
*y: The created placeholder tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator PlaceHolder. | |||
*/ | |||
REG_OP(PlaceHolder) | |||
.INPUT(x, TensorType::ALL()) | |||
.OUTPUT(y, TensorType::ALL()) | |||
.ATTR(peerIndex, Int, 0) // the index of the corresponding 'end' node it's connected to | |||
.ATTR(parentId, String, "") // check if these node are from save parent node | |||
.ATTR(parentOpType, String, "") // op type of original node | |||
.ATTR(anchorIndex, Int, 0) // check if these nodes are from the saved anchor | |||
.OP_END_FACTORY_REG(PlaceHolder) | |||
/** | |||
*@brief Inserts a placeholder with default value for a tensor. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Attributes: | |||
*shape: Required. The shape of the tensor. \n | |||
*@par Outputs: | |||
*y: The created placeholder tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator PlaceholderWithDefault. | |||
*/ | |||
REG_OP(PlaceholderWithDefault) | |||
.INPUT(x, TensorType::ALL()) | |||
.OUTPUT(y, TensorType::ALL()) | |||
.REQUIRED_ATTR(shape, ListInt) | |||
.OP_END_FACTORY_REG(PlaceholderWithDefault) | |||
/** | |||
*@brief Reads and returns the value of the input variable tensor. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Attributes: | |||
*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n | |||
*@par Outputs: | |||
*y: A tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator ReadVariableOp. | |||
*/ | |||
REG_OP(ReadVariableOp) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.ATTR(dtype, Int, DT_INT32) | |||
.OP_END_FACTORY_REG(ReadVariableOp) | |||
/** | |||
*@brief Marks the outputs of a subgraph partitioned by engine type. | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Outputs: | |||
*y: A tensor. \n | |||
*@par Attributes: | |||
*@li peerIndex: The index of the corresponding 'placeholder' node it's connected to. | |||
*@li parentOpType: Op type of original node. | |||
*@par Restrictions: | |||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||
*/ | |||
REG_OP(End) | |||
.INPUT(x, TensorType::ALL()) | |||
.OUTPUT(y, TensorType::ALL()) | |||
.ATTR(peerIndex, Int, 0) | |||
.ATTR(parentOpType, String, "") | |||
.OP_END_FACTORY_REG(End) | |||
/** | |||
*@brief Operations for writing summary data, for use in analysis and visualization. | |||
*@par Inputs: | |||
* One input: | |||
*x: Collections of summary data. | |||
*@par Restrictions: | |||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||
*/ | |||
REG_OP(Summary) | |||
.INPUT(x, TensorType::ALL()) | |||
.OP_END_FACTORY_REG(Summary) | |||
/** | |||
*@brief Returns the shape of a tensor. \n | |||
*@par Inputs: | |||
*x: A tensor. \n | |||
*@par Attributes: | |||
*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n | |||
*@par Outputs: | |||
*y: A tensor. The shape of the input tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Shape. | |||
*/ | |||
REG_OP(Shape) | |||
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_INT32, DT_INT64})) | |||
.ATTR(dtype, Int, DT_INT32) | |||
.OP_END_FACTORY_REG(Shape) | |||
/** | |||
*@brief Returns the shapes of the input tensors. \n | |||
*@par Inputs: | |||
*x: A list of input tensors. It's a dynamic input. \n | |||
*@par Attributes: | |||
*dtype: An optional int32 or int64. The output data type. Defaults to "int32". \n | |||
*@par Outputs: | |||
*y: A list of tensors with the same length as the input list of tensors. | |||
It's a dynamic output. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator ShapeN. | |||
*/ | |||
REG_OP(ShapeN) | |||
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.DYNAMIC_OUTPUT(y, TensorType({DT_INT32, DT_INT64})) | |||
.ATTR(dtype, Int, DT_INT32) | |||
.OP_END_FACTORY_REG(ShapeN) | |||
/** | |||
*@brief Creates a tensor with the given "shape" and "dtype". \n | |||
*@par Inputs: | |||
*shape: The shape of the output tensor. \n | |||
*@par Attributes: | |||
*@li dtype: Optional. The data type of the output tensor. Defaults to "int32". | |||
*@li init: An optional bool. If true, initializes the returned tensor with the default value of "dtype". Defaults to "false". \n | |||
*@par Outputs: | |||
*y: A tensor. \n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Empty. | |||
*/ | |||
REG_OP(Empty) | |||
.INPUT(shape, TensorType({DT_INT32})) | |||
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
.ATTR(dtype, Int, DT_INT32) | |||
.ATTR(init, Bool, false) | |||
.OP_END_FACTORY_REG(Empty) | |||
/** | |||
*@brief Returns locations of nonzero / true values in a tensor. \n | |||
*@par Inputs: | |||
*One input, including: | |||
*x: A Tensor. Must be one of the following types: | |||
DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, | |||
DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL. \n | |||
*@par Outputs: | |||
*y: A Tensor of type DT_INT64. \n | |||
*@attention Constraints: | |||
*Where runs on the Ascend AI CPU, which delivers poor performance.\n | |||
*@par Third-party framework compatibility | |||
*Compatible with the TensorFlow operator Where. | |||
*/ | |||
REG_OP(Where) | |||
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ | |||
DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_INT64})) | |||
.OP_END_FACTORY_REG(Where) | |||
/** | |||
*@brief Changes the shape of the output according to the attribute outShape. | |||
* | |||
*@par Inputs: | |||
*x: A Tensor. \n | |||
*@par Outputs: | |||
*y: A Tensor. Has the same type as "x". \n | |||
*@par Attributes: | |||
*outShape: The shape of the output is inferred from this attribute. | |||
*/ | |||
REG_OP(TransShape) | |||
.INPUT(x, TensorType::ALL()) | |||
.OUTPUT(y, TensorType::ALL()) | |||
.ATTR(outShape, ListInt, {}) | |||
.OP_END_FACTORY_REG(TransShape) | |||
/** | |||
* @brief Sorts a tensor along a given dimension (SortV2). | |||
* @par Inputs: | |||
* @li x: An ND tensor of type float16, float32, or double. | |||
* @par Attributes: | |||
* @li axis: An optional int. The dimension to sort along. This value defaults to -1. | |||
* @li descending: An optional bool. Controls the sorting order (ascending or descending). This value defaults to False. | |||
* @par Outputs: | |||
* @li y: An ND tensor of the same type as "x". | |||
* @attention Constraints: | |||
* @li Axis must select the last dimension. | |||
* @li This TBE op is recommended when the data to be sorted contains fewer than 150K elements; | |||
descending order performs better than ascending order. | |||
* @li On Ascend 910, the upper limit is 2000K elements. | |||
*/ | |||
REG_OP(SortV2) | |||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | |||
.ATTR(axis, Int, -1) | |||
.ATTR(descending, Bool, false) | |||
.OP_END_FACTORY_REG(SortV2) | |||
/** | |||
* @brief Expand the input tensor to a compatible shape. \n | |||
* @par Inputs: | |||
* Two inputs, including: | |||
* @li x: A Tensor. Must be one of the following types: | |||
* float16, float32, int32, int8 ,uint8. \n | |||
* @li shape: A Tensor specifying the shape that the input tensor is expanded to. \n | |||
* @par Outputs: | |||
* @li y: A Tensor. Has the same type as "x", and the shape specified by input and attr shape \n | |||
* @par Third-party framework compatibility | |||
* Compatible with the ONNX operator Expand. | |||
*/ | |||
REG_OP(Expand) | |||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||
.INPUT(shape, TensorType({DT_INT16, DT_INT32, DT_INT64})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||
.OP_END_FACTORY_REG(Expand) | |||
/** | |||
* @brief Expand the input tensor to a compatible shape. \n | |||
* @par Inputs: | |||
* One input, including: | |||
* @li x: A Tensor. Must be one of the following types: | |||
* float16, float32, int32, int8 ,uint8. \n | |||
* @par Attributes: | |||
* @li shape: A required listInt specifying the shape that the input tensor is expanded to. \n | |||
* @par Outputs: | |||
* @li y: A Tensor. Has the same type as "x", and the shape specified by input and attr shape \n | |||
* @par Third-party framework compatibility | |||
* Compatible with the ONNX operator Expand. | |||
*/ | |||
REG_OP(ExpandD) | |||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8})) | |||
.REQUIRED_ATTR(shape, ListInt) | |||
.OP_END_FACTORY_REG(ExpandD) | |||
} // namespace ge | |||
#endif // OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_ |
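All the registrations above follow the same builder pattern: inputs, outputs, attributes, then OP_END_FACTORY_REG. A hypothetical registration for a new stub op in the same style (the op name and attribute are invented for illustration):

```cpp
// Hypothetical op following the file's registration pattern; "MyStubOp" and
// the "scale" attribute are illustrative, not ops defined in this repo.
REG_OP(MyStubOp)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(scale, Float, 1.0)
    .OP_END_FACTORY_REG(MyStubOp)
```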
@@ -0,0 +1,392 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file control_flow_ops.cpp | |||
* \brief | |||
*/ | |||
#include "control_flow_ops.h" | |||
#include "./util/common_shape_fns.h" | |||
#include "./util/error_util.h" | |||
#include "util/util.h" | |||
namespace ge { | |||
namespace { | |||
graphStatus MergeInferImpl(Operator& op) { | |||
TensorDesc td = op.GetOutputDesc("value_index"); | |||
TensorDesc td_y = op.GetOutputDesc("y"); | |||
td.SetShape(ge::Shape()); | |||
td.SetDataType(DT_INT32); | |||
auto ret = op.UpdateOutputDesc("value_index", td); | |||
if (ret != GRAPH_SUCCESS) { | |||
return GRAPH_FAILED; | |||
} | |||
// check N of "x" >= 1 | |||
size_t in_num = op.GetInputsSize(); | |||
if (in_num < 1) { | |||
string reason = "inputs size[" + std::to_string(in_num) + "] must be greater than or equal to 1"; | |||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "input", reason); | |||
return GRAPH_FAILED; | |||
} else if (in_num == 2) { | |||
// Check whether this is a loop merge. InferShape order is Enter->Merge->NextIteration, | |||
// so when InferShape runs on the Merge op, the NextIteration op still has its default shape & datatype. | |||
// Therefore, the shape & datatype of the Merge op should be taken from the Enter op. | |||
auto x0_type = op.GetDynamicInputDesc("x", 0).GetDataType(); | |||
auto x0_dims = op.GetDynamicInputDesc("x", 0).GetShape().GetDims(); | |||
bool not_handle_flag0 = (x0_type == DT_FLOAT) && (x0_dims.size() == 0); | |||
auto x1_type = op.GetDynamicInputDesc("x", 1).GetDataType(); | |||
auto x1_dims = op.GetDynamicInputDesc("x", 1).GetShape().GetDims(); | |||
bool not_handle_flag1 = (x1_type == DT_FLOAT) && (x1_dims.size() == 0); | |||
if ((x0_type != x1_type) && (not_handle_flag0 || not_handle_flag1)) { | |||
if (not_handle_flag0) { | |||
td_y.SetShape(ge::Shape(x1_dims)); | |||
td_y.SetDataType(x1_type); | |||
} else { | |||
td_y.SetShape(ge::Shape(x0_dims)); | |||
td_y.SetDataType(x0_type); | |||
} | |||
(void)op.UpdateOutputDesc("y", td_y); | |||
return GRAPH_SUCCESS; | |||
} | |||
} | |||
// check "x" be same type | |||
auto x0_type = op.GetDynamicInputDesc("x", 0).GetDataType(); | |||
for (size_t i = 1; i < op.GetInputsSize(); i++) { | |||
auto xi_type = op.GetDynamicInputDesc("x", i).GetDataType(); | |||
if (xi_type != x0_type) { | |||
string reason = "x[0]'s dtype[" + std::to_string(x0_type) + "] must be equal to x[" + std::to_string(i) + | |||
"]'s dtype[" + std::to_string(xi_type) + "]"; | |||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", reason); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
// infer "y" be unknown shape | |||
auto x0_dims = op.GetDynamicInputDesc("x", 0).GetShape().GetDims(); | |||
bool x0_unknown = (x0_dims.size() == 1) && (x0_dims[0] == 0); | |||
if (x0_unknown) { | |||
Shape unknown_shape(ge::UNKNOWN_SHAPE); | |||
td_y.SetShape(unknown_shape); | |||
td_y.SetDataType(x0_type); | |||
(void)op.UpdateOutputDesc("y", td_y); | |||
return GRAPH_SUCCESS; | |||
} | |||
// find the input with the max size among all inputs, and set its data type/shape on the output | |||
std::map<int64_t, size_t> size_to_index; | |||
for (size_t i = 0; i < op.GetInputsSize(); i++) { | |||
auto xi_dims = op.GetDynamicInputDesc("x", i).GetShape().GetDims(); | |||
bool xi_unknown = (xi_dims.size() == 1) && (xi_dims[0] == 0); | |||
if (xi_unknown) { | |||
continue; | |||
} | |||
int64_t size = static_cast<int64_t>(GetSizeByDataType(op.GetDynamicInputDesc("x", i).GetDataType())); | |||
if (size < 0) { | |||
continue; | |||
} | |||
if (!xi_dims.empty()) { | |||
for (auto& dim : xi_dims) { | |||
if (dim <= 0) { | |||
size = -1; | |||
break; | |||
} | |||
if (size != 0 && INT64_MAX / size < dim) { | |||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dim", "the dim size is overflow"); | |||
return GRAPH_FAILED; | |||
} | |||
size *= dim; | |||
} | |||
if (size < 0) { | |||
continue; | |||
} | |||
} | |||
if (size_to_index.count(size) == 0) { | |||
size_to_index[size] = i; | |||
} | |||
} | |||
if (size_to_index.empty()) { | |||
return GRAPH_FAILED; | |||
} | |||
auto index = size_to_index.rbegin()->second; | |||
td_y.SetShape(ge::Shape(op.GetDynamicInputDesc("x", index).GetShape().GetDims())); | |||
td_y.SetDataType(op.GetDynamicInputDesc("x", index).GetDataType()); | |||
(void)op.UpdateOutputDesc("y", td_y); | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus SwitchInferImpl(Operator& op) { | |||
auto op_desc = OpDescUtils::GetOpDescFromOperator(op); | |||
auto data_desc = op_desc->MutableInputDesc("data"); | |||
auto pred_desc = op_desc->MutableInputDesc("pred"); | |||
auto output_false_desc = op_desc->MutableOutputDesc("output_false"); | |||
auto output_true_desc = op_desc->MutableOutputDesc("output_true"); | |||
std::vector<std::pair<int64_t, int64_t>> data_range; | |||
data_desc->GetShapeRange(data_range); | |||
// check "pred" scalar type be bool | |||
auto pred_dims = pred_desc->GetShape().GetDims(); | |||
if (pred_dims.size() != 0) { | |||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "pred dims", "pred should be a scalar"); | |||
return GRAPH_FAILED; | |||
} | |||
DataType pred_type = pred_desc->GetDataType(); | |||
if (pred_type != DT_BOOL) { | |||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "pred should be bool type"); | |||
return GRAPH_FAILED; | |||
} | |||
DataType data_type = data_desc->GetDataType(); | |||
auto data_dims = data_desc->GetShape().GetDims(); | |||
output_false_desc->SetShapeRange(data_range); | |||
output_true_desc->SetShapeRange(data_range); | |||
output_false_desc->SetShape(GeShape(data_dims)); | |||
output_false_desc->SetOriginShape(GeShape(data_dims)); | |||
output_true_desc->SetShape(GeShape(data_dims)); | |||
output_true_desc->SetOriginShape(GeShape(data_dims)); | |||
output_false_desc->SetDataType(data_type); | |||
output_true_desc->SetDataType(data_type); | |||
auto context = op.GetInferenceContext(); | |||
std::vector<std::vector<ShapeAndType>> in_shapes_and_types = context->GetInputHandleShapesAndTypes(); | |||
if ((!in_shapes_and_types.empty()) && (!in_shapes_and_types.at(0).empty())) { | |||
ShapeAndType shape_and_type = in_shapes_and_types.at(0).at(0); | |||
std::vector<ShapeAndType> grad_handle_shape_and_type; | |||
grad_handle_shape_and_type.reserve(1); | |||
grad_handle_shape_and_type.emplace_back(shape_and_type); | |||
std::vector<std::vector<ShapeAndType>> shapes_and_types(2); | |||
shapes_and_types[0] = grad_handle_shape_and_type; | |||
shapes_and_types[1] = grad_handle_shape_and_type; | |||
context->SetOutputHandleShapesAndTypes(shapes_and_types); | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
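// Example behavior (illustrative only): a "pred" of shape {} and type DT_BOOL | |||
// passes both checks above; a "pred" of shape {1} fails the scalar check, and | |||
// a DT_INT32 "pred" fails the type check. On success both outputs inherit the | |||
// shape, shape range and data type of "data". | |||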
graphStatus EnterInferImpl(Operator& op) { | |||
auto op_desc = OpDescUtils::GetOpDescFromOperator(op); | |||
auto input_desc_x = op_desc->MutableInputDesc("x"); | |||
auto output_desc_y = op_desc->MutableOutputDesc("y"); | |||
std::vector<std::pair<int64_t, int64_t>> x_range; | |||
std::vector<std::pair<int64_t, int64_t>> y_range; | |||
input_desc_x->GetShapeRange(x_range); | |||
auto input_dims = input_desc_x->MutableShape().GetDims(); | |||
DataType input_type = input_desc_x->GetDataType(); | |||
output_desc_y->SetShape(ge::GeShape(input_dims)); | |||
output_desc_y->SetOriginShape(ge::GeShape(input_dims)); | |||
output_desc_y->SetDataType(input_type); | |||
if (!x_range.empty()) { | |||
output_desc_y->SetShapeRange(x_range); | |||
} | |||
auto context = op.GetInferenceContext(); | |||
std::vector<std::vector<ShapeAndType>> in_shapes_and_types = context->GetInputHandleShapesAndTypes(); | |||
if ((!in_shapes_and_types.empty()) && (!in_shapes_and_types.at(0).empty())) { | |||
ShapeAndType shape_and_type = in_shapes_and_types.at(0).at(0); | |||
std::vector<ShapeAndType> grad_handle_shape_and_type; | |||
grad_handle_shape_and_type.reserve(1); | |||
grad_handle_shape_and_type.emplace_back(shape_and_type); | |||
std::vector<std::vector<ShapeAndType>> shapes_and_types(1); | |||
shapes_and_types[0] = grad_handle_shape_and_type; | |||
context->SetOutputHandleShapesAndTypes(shapes_and_types); | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus PassThroughInferImpl(Operator& op, const std::string& in_name, const std::string& out_name) { | |||
auto input_dims = op.GetInputDesc(in_name).GetShape().GetDims(); | |||
DataType input_type = op.GetInputDesc(in_name).GetDataType(); | |||
TensorDesc tensordesc_output = op.GetOutputDesc(out_name); | |||
tensordesc_output.SetShape(ge::Shape(input_dims)); | |||
tensordesc_output.SetDataType(input_type); | |||
(void)op.UpdateOutputDesc(out_name, tensordesc_output); | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus LoopCondInferImpl(Operator& op) { | |||
auto input_dims = op.GetInputDesc("x").GetShape().GetDims(); | |||
if (input_dims.size() != 0) { | |||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "x dims", "x should be a scalar"); | |||
return GRAPH_FAILED; | |||
} | |||
TensorDesc tensordesc_output = op.GetOutputDesc("y"); | |||
tensordesc_output.SetShape(ge::Shape(input_dims)); | |||
DataType input_type = op.GetInputDesc("x").GetDataType(); | |||
if (input_type != DT_BOOL) { | |||
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "x should be bool type"); | |||
return GRAPH_FAILED; | |||
} | |||
tensordesc_output.SetDataType(input_type); | |||
(void)op.UpdateOutputDesc("y", tensordesc_output); | |||
return GRAPH_SUCCESS; | |||
} | |||
} // namespace | |||
IMPLEMT_INFERFUNC(Merge, MergeInfer) { | |||
return MergeInferImpl(op); | |||
} | |||
INFER_FUNC_REG(Merge, MergeInfer); | |||
IMPLEMT_INFERFUNC(RefMerge, RefMergeInfer) { | |||
return MergeInferImpl(op); | |||
} | |||
INFER_FUNC_REG(RefMerge, RefMergeInfer); | |||
IMPLEMT_INFERFUNC(Switch, SwitchInfer) { | |||
return SwitchInferImpl(op); | |||
} | |||
INFER_FUNC_REG(Switch, SwitchInfer); | |||
IMPLEMT_INFERFUNC(RefSwitch, RefSwitchInfer) { | |||
return SwitchInferImpl(op); | |||
} | |||
INFER_FUNC_REG(RefSwitch, RefSwitchInfer); | |||
IMPLEMT_INFERFUNC(SwitchN, SwitchNInfer) { | |||
return GRAPH_SUCCESS; | |||
} | |||
INFER_FUNC_REG(SwitchN, SwitchNInfer); | |||
IMPLEMT_INFERFUNC(Enter, EnterInfer) { | |||
return EnterInferImpl(op); | |||
} | |||
INFER_FUNC_REG(Enter, EnterInfer); | |||
IMPLEMT_INFERFUNC(RefEnter, RefEnterInfer) { | |||
return PassThroughInferImpl(op, "x", "y"); | |||
} | |||
INFER_FUNC_REG(RefEnter, RefEnterInfer); | |||
IMPLEMT_INFERFUNC(LoopCond, LoopCondInfer) { | |||
return LoopCondInferImpl(op); | |||
} | |||
INFER_FUNC_REG(LoopCond, LoopCondInfer); | |||
IMPLEMT_INFERFUNC(NextIteration, NextIterationInfer) { | |||
return PassThroughInferImpl(op, "x", "y"); | |||
} | |||
INFER_FUNC_REG(NextIteration, NextIterationInfer); | |||
IMPLEMT_INFERFUNC(RefNextIteration, RefNextIterationInfer) { | |||
return PassThroughInferImpl(op, "x", "y"); | |||
} | |||
INFER_FUNC_REG(RefNextIteration, RefNextIterationInfer); | |||
IMPLEMT_INFERFUNC(Exit, ExitInfer) { | |||
return PassThroughInferImpl(op, "x", "y"); | |||
} | |||
INFER_FUNC_REG(Exit, ExitInfer); | |||
IMPLEMT_INFERFUNC(RefExit, RefExitInfer) { | |||
return PassThroughInferImpl(op, "x", "y"); | |||
} | |||
INFER_FUNC_REG(RefExit, RefExitInfer); | |||
// ----------------MapIndex------------------- | |||
IMPLEMT_VERIFIER(MapIndex, MapIndexVerify) { | |||
return GRAPH_SUCCESS; | |||
} | |||
IMPLEMT_COMMON_INFERFUNC(MapIndexInferShape) { | |||
OP_LOGI("MapIndex", "infer shape begin---"); | |||
auto x_shape = op.GetInputDesc("x").GetShape().GetDims(); | |||
if (x_shape.empty()) { | |||
OP_LOGE(op.GetName().c_str(), "x_shape is empty"); | |||
OpsOneInputShapeErrReport(op.GetName().c_str(), "x", "x_shape is empty"); | |||
return GRAPH_FAILED; | |||
} | |||
int64_t x_length = x_shape[0]; | |||
auto data_seq_shape = op.GetInputDesc("data_seq").GetShape().GetDims(); | |||
if (data_seq_shape.empty()) { | |||
OP_LOGE(op.GetName().c_str(), "data_seq_shape is empty"); | |||
OpsOneInputShapeErrReport(op.GetName().c_str(), "data_seq", "data_seq_shape is empty"); | |||
return GRAPH_FAILED; | |||
} | |||
int64_t data_seq_length = data_seq_shape[0]; | |||
if (x_length > 8 || x_length == 0) { | |||
OP_LOGE(op.GetName().c_str(), "the length of x should be less than or equal to 8"); | |||
OpsOneInputShapeErrReport(op.GetName().c_str(), "x", "the length of x should be less than or equal to 8 and not 0"); | |||
return GRAPH_FAILED; | |||
} | |||
if (data_seq_length % x_length != 0) { | |||
OP_LOGE(op.GetName().c_str(), "the length of data_seq must be multiple of the length of x"); | |||
OpsTwoInputShapeErrReport(op.GetName().c_str(), "data_seq", "x", | |||
"the length of data_seq must be multiple of the length of x"); | |||
return GRAPH_FAILED; | |||
} | |||
if (data_seq_length / x_length > 100) { | |||
OP_LOGE(op.GetName().c_str(), "data_seq_length / x_length should be be less than or equal to 100"); | |||
OpsTwoInputShapeErrReport(op.GetName().c_str(), "data_seq", "x", | |||
"data_seq_length / x_length should be be less than or equal to 100"); | |||
return GRAPH_FAILED; | |||
} | |||
auto level_index_shape = op.GetInputDesc("level_index").GetShape().GetDims(); | |||
if (!level_index_shape.empty()) { | |||
int64_t level_index_length = level_index_shape[0]; | |||
if (level_index_length != (data_seq_length / x_length)) { | |||
OP_LOGE(op.GetName().c_str(), | |||
"the length of level_index must be equal to " | |||
"the length of data_seq divided by the length of x"); | |||
OpsOneInputShapeErrReport(op.GetName().c_str(), "level_index", | |||
"the length of level_index must be equal to " | |||
"the length of data_seq divided by the length of x"); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
TensorDesc y_desc = op.GetOutputDesc("y"); | |||
y_desc.SetShape(ge::Shape()); | |||
y_desc.SetDataType(ge::DT_INT32); | |||
(void)op.UpdateOutputDesc("y", y_desc); | |||
return GRAPH_SUCCESS; | |||
} | |||
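// Worked example (illustrative only): x with shape {4} and data_seq with shape | |||
// {400} satisfy all checks above (4 <= 8, 400 % 4 == 0, 400 / 4 == 100 <= 100); | |||
// an optional level_index must then have shape {100}. | |||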
COMMON_INFER_FUNC_REG(MapIndex, MapIndexInferShape); | |||
VERIFY_FUNC_REG(MapIndex, MapIndexVerify); | |||
} // namespace ge |
@@ -0,0 +1,407 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file control_flow_ops.h | |||
* \brief | |||
*/ | |||
#ifndef OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_ | |||
#define OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_ | |||
#include "graph/operator_reg.h" | |||
#include "graph/operator.h" | |||
namespace ge { | |||
/** | |||
*@brief Forwards the value of an available tensor from input "x" to output "y". | |||
* Merge waits for at least one of the input tensors to become available. | |||
* It is usually combined with Switch to implement branching. | |||
* Merge forwards the first tensor to become available to output "y", | |||
* and sets "value_index" to the index of the chosen input tensor . \n | |||
*@par Inputs: | |||
*x: The input tensors, one of which will become available. | |||
* Must be one of the following types: float16, float32, float64, int8, | |||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n | |||
*@par Outputs: | |||
*@li y: The available tensor. Has the same type as "x". | |||
*@li value_index: A scalar of type int32, for the index of the chosen input | |||
* tensor . \n | |||
*@see Switch() | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator Merge. | |||
*/ | |||
REG_OP(Merge) | |||
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(value_index, TensorType({DT_INT32})) | |||
.OP_END_FACTORY_REG(Merge) | |||
/** | |||
*@brief Forwards the value of an available tensor from input "x" to output "y". | |||
* Merge waits for at least one of the input tensors to become available. | |||
* It is usually combined with Switch to implement branching. | |||
* Merge forwards the first tensor to become available to output "y", | |||
* and sets "value_index" to the index of the chosen input tensor . \n | |||
*@par Inputs: | |||
*x: The input tensors, one of which will become available. | |||
* Must be one of the following types: float16, float32, float64, int8, | |||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n | |||
*@par Outputs: | |||
*@li y: The available tensor. Has the same type as "x". | |||
*@li value_index: A scalar of type int32, for the index of the chosen input | |||
* tensor . \n | |||
*@see Switch() | Merge() | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator RefMerge. | |||
*/ | |||
REG_OP(RefMerge) | |||
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(value_index, TensorType({DT_INT32})) | |||
.OP_END_FACTORY_REG(RefMerge) | |||
/** | |||
*@brief Forwards "data" to the output port determined by "pred". | |||
* If "pred" is "true", the data input is forwarded to "output_true". | |||
* Otherwise, the data is forwarded to "output_false" . \n | |||
*@par Inputs: | |||
*@li data: The tensor to be forwarded. \n | |||
* Must be one of the following types: float16, float32, float64, | |||
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool. | |||
*@li pred: A boolean scalar. The output port that will receive data . \n | |||
*@par Outputs: | |||
*@li output_false: If "pred" is "false", data will be forwarded to this output. | |||
* Has the same type as "data". | |||
*@li output_true: If "pred" is "true", data will be forwarded to this output. | |||
* Has the same type as "data" . \n | |||
*@see Merge() | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator Switch. | |||
*/ | |||
REG_OP(Switch) | |||
.INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.INPUT(pred, TensorType({DT_BOOL})) | |||
.OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OP_END_FACTORY_REG(Switch) | |||
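/* | |||
* Dataflow sketch (illustrative only, not part of the original header): a | |||
* conditional "y = pred ? f(x) : g(x)" is usually wired as | |||
*   Switch(x, pred).output_true  -> f -> Merge -> y | |||
*   Switch(x, pred).output_false -> g -> Merge -> y | |||
* Only the taken branch becomes live; Merge forwards whichever tensor arrives | |||
* and records its input index in "value_index". | |||
*/ | |||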
/** | |||
*@brief Forwards "data" to the output port determined by "pred". | |||
* If "pred" is "true", the data input is forwarded to "output_true". | |||
* Otherwise, the data is forwarded to "output_false" . \n | |||
*@par Inputs: | |||
*@li data: The ref tensor to be forwarded. | |||
* Must be one of the following types: float16, float32, float64, | |||
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool. | |||
*@li pred: A boolean scalar. The output port that will receive data . \n | |||
*@par Outputs: | |||
*@li output_false: If "pred" is "false", data will be forwarded to this output. | |||
* Has the same type as "data". | |||
*@li output_true: If "pred" is "true", data will be forwarded to this output. | |||
* Has the same type as "data" . \n | |||
*@see Merge() | Switch() | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator RefSwitch. | |||
*/ | |||
REG_OP(RefSwitch) | |||
.INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.INPUT(pred, TensorType({DT_BOOL})) | |||
.OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OP_END_FACTORY_REG(RefSwitch) | |||
/** | |||
*@brief Forwards "data" to the output port determined by "pred_value" . \n | |||
*@par Inputs: | |||
*@li data: The tensor to be forwarded. \n | |||
* Must be one of the following types: float16, float32, float64, | |||
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool. | |||
*@li pred_value: An int64 tensor which determines the output port that will receive data . \n | |||
*@par Outputs: | |||
*output: The output tensors, one of which will become available. | |||
* Has the same type as "data". | |||
*/ | |||
REG_OP(SwitchN) | |||
.INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.INPUT(pred_value, TensorType({DT_INT64})) | |||
.DYNAMIC_OUTPUT(output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OP_END_FACTORY_REG(SwitchN) | |||
/** | |||
*@brief Creates or finds a child frame, and makes "x" available to the child | |||
* frame. This op is used together with Exit to create loops in the graph. | |||
* The Executor uses the unique "frame_name" to identify frames. | |||
* If "is_constant" is "true", output "y" is a constant in the child | |||
* frame; otherwise it may be changed in the child frame . \n | |||
*@par Inputs: | |||
*x: The tensor to be made available to the child frame. | |||
* Must be one of the following types: float16, float32, float64, int8, | |||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||
*@par Attributes: | |||
*@li frame_name: A required string. The name of the child frame. | |||
*@li is_constant: A required bool. If true, the output is constant in | |||
* the child frame . \n | |||
*@par Outputs: | |||
*y: A Tensor. Has the same type as "x" . \n | |||
*@see Exit() | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator Enter. | |||
*/ | |||
REG_OP(Enter) | |||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.REQUIRED_ATTR(frame_name, String) | |||
.REQUIRED_ATTR(is_constant, Bool) | |||
.OP_END_FACTORY_REG(Enter) | |||
/** | |||
*@brief Creates or finds a child frame, and makes "x" available to the child | |||
* frame. This op is used together with Exit to create loops in the graph. | |||
* The Executor uses the unique "frame_name" to identify frames. | |||
* If "is_constant" is "true", output "y" is a constant in the child | |||
* frame; otherwise it may be changed in the child frame . \n | |||
*@par Inputs: | |||
*x: The tensor to be made available to the child frame. | |||
* Must be one of the following types: float16, float32, float64, int8, | |||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||
*@par Attributes: | |||
*@li frame_name: A required string. The name of the child frame. | |||
*@li is_constant: A required bool. If true, the output is constant in | |||
* the child frame . \n | |||
*@par Outputs: | |||
*y: A tensor. Has the same type as "x" . \n | |||
*@see Exit() | Enter() | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator RefEnter. | |||
*/ | |||
REG_OP(RefEnter) | |||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.REQUIRED_ATTR(frame_name, String) | |||
.REQUIRED_ATTR(is_constant, Bool) | |||
.OP_END_FACTORY_REG(RefEnter) | |||
/** | |||
*@brief Forwards the input to the output. This op represents the loop | |||
* termination condition . \n | |||
*@par Inputs: | |||
*x: A boolean scalar. The condition of the Switch op . \n | |||
*@par Outputs: | |||
*y: The tensor "x" . \n | |||
*@see Switch() | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator LoopCond. | |||
*/ | |||
REG_OP(LoopCond) | |||
.INPUT(x, TensorType({DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_BOOL})) | |||
.OP_END_FACTORY_REG(LoopCond) | |||
/** | |||
*@brief Makes the input available to the next iteration . \n | |||
*@par Inputs: | |||
*x: The tensor to be made available to the next iteration. | |||
* Must be one of the following types: float16, float32, float64, int8, | |||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||
*@par Outputs: | |||
*y: A Tensor. Has the same type as "x" . \n | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator NextIteration. | |||
*/ | |||
REG_OP(NextIteration) | |||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OP_END_FACTORY_REG(NextIteration) | |||
/** | |||
*@brief Makes the input available to the next iteration . \n | |||
*@par Inputs: | |||
*x: The tensor to be made available to the next iteration. | |||
* Must be one of the following types: float16, float32, float64, int8, | |||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||
*@par Outputs: | |||
*y: A tensor. Has the same type as "x" . \n | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator RefNextIteration. | |||
*/ | |||
REG_OP(RefNextIteration) | |||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OP_END_FACTORY_REG(RefNextIteration) | |||
/** | |||
*@brief Exits the current frame to its parent frame . \n | |||
*@par Inputs: | |||
*x: The tensor to be made available to the parent frame. | |||
* Must be one of the following types: float16, float32, float64, int8, | |||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||
*@par Outputs: | |||
*y: A Tensor. Has the same type as "x" . \n | |||
*@see Enter() | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator Exit. | |||
*/ | |||
REG_OP(Exit) | |||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OP_END_FACTORY_REG(Exit) | |||
/** | |||
*@brief Exits the current frame to its parent frame . \n | |||
*@par Inputs: | |||
*x: The tensor to be made available to the parent frame. | |||
* Must be one of the following types: float16, float32, float64, int8, | |||
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n | |||
*@par Outputs: | |||
*y: A tensor. Has the same type as "x" . \n | |||
*@see Enter() | Exit() | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator RefExit. | |||
*/ | |||
REG_OP(RefExit) | |||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, | |||
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, | |||
DT_UINT64, DT_BOOL})) | |||
.OP_END_FACTORY_REG(RefExit) | |||
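/* | |||
* Loop dataflow sketch (illustrative only, assuming the usual TensorFlow-style | |||
* wiring): "while cond(i): i = body(i)" composes the ops above as | |||
*   Enter(i) -> Merge -> Switch (data input) | |||
*   Merge -> cond -> LoopCond -> Switch (pred input) | |||
*   Switch.output_true  -> body -> NextIteration -> back to Merge | |||
*   Switch.output_false -> Exit -> final value of i | |||
* Enter moves "i" into the child frame named by "frame_name"; Exit returns the | |||
* result to the parent frame once the predicate becomes false. | |||
*/ | |||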
/** | |||
*@brief Only useful as a placeholder for control edges. | |||
* It is similar to a no-op that always produces a live control output | |||
* even when some control inputs are dead . \n | |||
*@par Third-party framework compatibility | |||
*@Compatible with the TensorFlow operator ControlTrigger. | |||
*/ | |||
REG_OP(ControlTrigger) | |||
.OP_END_FACTORY_REG(ControlTrigger) | |||
/** | |||
*@brief Returns index of shape in the map. | |||
*@par Inputs: | |||
* Three inputs, including: | |||
*@li x: One-dimensional tensor of type int32, specifying the queried shape; max size is 8. | |||
*@li data_seq: One-dimensional tensor of type int32, specifying the mapping table to be queried. | |||
*@li level_index: One-dimensional tensor of type int32, specifying the secondary index. \n | |||
*@par Outputs: | |||
*@li y: A Tensor with shape [batch, 8], of type int32, specifying index of shape in the map. | |||
*@par Third-party framework compatibility | |||
* It is a custom operator. It has no corresponding operator in Caffe. | |||
*/ | |||
REG_OP(MapIndex) | |||
.INPUT(x, TensorType({DT_INT32})) | |||
.INPUT(data_seq, TensorType({DT_INT32})) | |||
.OPTIONAL_INPUT(level_index, TensorType({DT_INT32})) | |||
.OUTPUT(y, TensorType({DT_INT32})) | |||
.OP_END_FACTORY_REG(MapIndex) | |||
} // namespace ge | |||
#endif // OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_ |
@@ -0,0 +1,234 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file array_ops_shape_fns.cpp | |||
* \brief | |||
*/ | |||
#include "array_ops_shape_fns.h" | |||
#include "graph/types.h" | |||
#include "op_log.h" | |||
#include "error_util.h" | |||
#include "common_shape_fns.h" | |||
#include "axis_util.h" | |||
namespace ge { | |||
static graphStatus PadKnown(Operator& op, const Tensor& paddings_tensor, const int64_t input_dim_num) { | |||
TensorDesc paddings_tensor_desc = paddings_tensor.GetTensorDesc(); | |||
DataType data_type = paddings_tensor_desc.GetDataType(); | |||
std::vector<int64_t> data; | |||
// every dim has 2 elements | |||
int64_t element_num = input_dim_num * 2; | |||
data.reserve(element_num); | |||
if (data_type == DT_INT32) { | |||
const int32_t* paddings_data = reinterpret_cast<const int32_t*>(paddings_tensor.GetData()); | |||
CHECK(paddings_tensor.GetSize() / sizeof(int32_t) < element_num, | |||
OP_LOGE(op.GetName().c_str(), "invalid padding data."), return GRAPH_FAILED); | |||
for (int64_t i = 0; i < element_num; ++i) { | |||
data.push_back(static_cast<int64_t>(paddings_data[i])); | |||
} | |||
} else if (data_type == DT_INT64) { | |||
const int64_t* paddings_data = reinterpret_cast<const int64_t*>(paddings_tensor.GetData()); | |||
CHECK(paddings_tensor.GetSize() / sizeof(int64_t) < element_num, | |||
OP_LOGE(op.GetName().c_str(), "invalid padding data."), return GRAPH_FAILED); | |||
for (int64_t i = 0; i < element_num; ++i) { | |||
data.push_back(paddings_data[i]); | |||
} | |||
} else { | |||
string err_msg = ConcatString("paddings data type invalid, ", "should be DT_INT32 or DT_INT64"); | |||
InferShapeOtherErrReport(op.GetName(), err_msg); | |||
OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
auto dims = op.GetInputDesc(0).GetShape().GetDims(); | |||
std::vector<int64_t> output_dims(input_dim_num, UNKNOWN_DIM); | |||
if (dims != UNKNOWN_SHAPE) { | |||
output_dims.assign(dims.begin(), dims.end()); | |||
} | |||
for (size_t i = 0; i < data.size(); i += 2) { | |||
if ((data[i] < 0) || (data[i + 1] < 0)) { | |||
std::string err_msg = ConcatString("paddings", DebugString(data), " must be non-negative"); | |||
InferShapeOtherErrReport(op.GetName(), err_msg); | |||
OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
graphStatus status = Add(output_dims[i / 2], data[i] + data[i + 1], output_dims[i / 2]); | |||
if (status != GRAPH_SUCCESS) { | |||
std::string err_msg = ConcatString("the sum input[0] shape", DebugString(dims), " and input[1] value", | |||
DebugString(data), " must be non-negative"); | |||
OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str()); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
auto output_desc = op.GetOutputDesc("y"); | |||
output_desc.SetShape(Shape(output_dims)); | |||
return op.UpdateOutputDesc("y", output_desc); | |||
} | |||
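// Worked example (illustrative only): for input shape [2, 3] and a constant | |||
// paddings tensor [[1, 1], [2, 2]], data becomes {1, 1, 2, 2} and the output | |||
// shape is [2 + 1 + 1, 3 + 2 + 2] = [4, 7]. | |||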
graphStatus PadShapeFn(Operator& op) { | |||
Shape paddings; | |||
int64_t input_dim_num; | |||
graphStatus status = WithRank(op.GetInputDesc(1), 2, paddings, op.GetName().c_str()); | |||
if (status != GRAPH_SUCCESS) { | |||
ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), "2D"); | |||
return GRAPH_FAILED; | |||
} | |||
status = WithValue(paddings.GetDim(1), 2, input_dim_num, op.GetName().c_str()); | |||
if (status != GRAPH_SUCCESS) { | |||
ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), | |||
ConcatString(2, " of dim[1]")); | |||
return GRAPH_FAILED; | |||
} | |||
Shape input; | |||
int64_t dim0 = paddings.GetDim(0); | |||
if (dim0 != UNKNOWN_DIM) { | |||
status = WithRank(op.GetInputDesc(0), dim0, input, op.GetName().c_str()); | |||
if (status != GRAPH_SUCCESS) { | |||
ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(dim0, "D")); | |||
return GRAPH_FAILED; | |||
} | |||
} else if (op.GetInputDesc(0).GetShape().GetDim(0) != 0) { | |||
status = WithValue(dim0, op.GetInputDesc(0).GetShape().GetDimNum(), input_dim_num, op.GetName().c_str()); | |||
if (status != GRAPH_SUCCESS) { | |||
ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(dim0, "D")); | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
TensorDesc output_desc = op.GetOutputDesc("y"); | |||
Tensor paddings_tensor; | |||
status = op.GetInputConstData("paddings", paddings_tensor); | |||
if (status != GRAPH_SUCCESS) { | |||
if (dim0 != UNKNOWN_DIM) { | |||
std::vector<int64_t> output_shape(dim0, UNKNOWN_DIM); | |||
output_desc.SetShape(Shape(output_shape)); | |||
} else { | |||
output_desc.SetShape(Shape(UNKNOWN_SHAPE)); | |||
} | |||
return op.UpdateOutputDesc("y", output_desc); | |||
} | |||
input_dim_num = paddings_tensor.GetTensorDesc().GetShape().GetDim(0); | |||
status = WithRank(op.GetInputDesc(0), input_dim_num, input, op.GetName().c_str()); | |||
if (status == GRAPH_FAILED) { | |||
OP_LOGE(op.GetName().c_str(), "WithRank fail"); | |||
return GRAPH_FAILED; | |||
} | |||
status = WithValue(dim0, input_dim_num, dim0, op.GetName().c_str()); | |||
if (status == GRAPH_FAILED) { | |||
OP_LOGE(op.GetName().c_str(), "WithValue fail"); | |||
return GRAPH_FAILED; | |||
} | |||
return PadKnown(op, paddings_tensor, input_dim_num); | |||
} | |||
static graphStatus CalcPadGradOutDims(const Shape& input_shape, const Tensor& paddings_tensor, | |||
std::vector<int64_t>& output_dims, const char* op_name) { | |||
graphStatus status; | |||
size_t input_rank = input_shape.GetDimNum(); | |||
if (output_dims.size() < input_rank) { | |||
return GRAPH_FAILED; | |||
} | |||
DataType padding_type = paddings_tensor.GetTensorDesc().GetDataType(); | |||
if (padding_type == DT_INT32) { | |||
const int32_t* paddings_data = reinterpret_cast<const int32_t*>(paddings_tensor.GetData()); | |||
// paddings holds 2 values per dim, so require at least 2 * input_rank elements | |||
CHECK(paddings_tensor.GetSize() / sizeof(int32_t) < input_rank * 2, | |||
OP_LOGE(op_name, "invalid padding data."), return GRAPH_FAILED); | |||
for (size_t i = 0; i < input_rank; ++i) { | |||
const int64_t pad0 = static_cast<int64_t>(paddings_data[2 * i]); | |||
const int64_t pad1 = static_cast<int64_t>(paddings_data[(2 * i) + 1]); | |||
if ((pad0 < 0) || (pad1 < 0)) { | |||
OP_LOGE(op_name, "Paddings must be non-negative, pad0= %lld, pad1=%lld.", pad0, pad1); | |||
return GRAPH_FAILED; | |||
} | |||
status = Subtract(input_shape.GetDim(i), pad0 + pad1, output_dims[i], op_name); | |||
if (status != GRAPH_SUCCESS) { | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
} else if (padding_type == DT_INT64) { | |||
const int64_t* paddings_data = reinterpret_cast<const int64_t*>(paddings_tensor.GetData()); | |||
CHECK(paddings_tensor.GetSize() / sizeof(int64_t) < input_rank * 2, | |||
OP_LOGE(op_name, "invalid padding data."), return GRAPH_FAILED); | |||
for (size_t i = 0; i < input_rank; ++i) { | |||
const int64_t pad0 = paddings_data[2 * i]; | |||
const int64_t pad1 = paddings_data[(2 * i) + 1]; | |||
if ((pad0 < 0) || (pad1 < 0)) { | |||
OP_LOGE(op_name, "Paddings must be non-negative, pad0=%lld, pad1=%lld.", pad0, pad1); | |||
return GRAPH_FAILED; | |||
} | |||
status = Subtract(input_shape.GetDim(i), pad0 + pad1, output_dims[i], op_name); | |||
if (status != GRAPH_SUCCESS) { | |||
return GRAPH_FAILED; | |||
} | |||
} | |||
} else { | |||
OP_LOGE(op_name, "Data type invalid, should be DT_INT32 or DT_INT64"); | |||
return GRAPH_FAILED; | |||
} | |||
return GRAPH_SUCCESS; | |||
} | |||
graphStatus PadGradShapeFn(Operator& op) { | |||
Shape paddings; | |||
graphStatus status = WithRank(op.GetInputDesc(1), 2, paddings, op.GetName().c_str()); | |||
if (status != GRAPH_SUCCESS) { | |||
ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), "2D"); | |||
return GRAPH_FAILED; | |||
} | |||
int64_t input_rank = paddings.GetDim(0); | |||
TensorDesc output_desc = op.GetOutputDesc("y"); | |||
output_desc.SetDataType(op.GetInputDesc(0).GetDataType()); | |||
if (input_rank == UNKNOWN_DIM) { | |||
OP_LOGE(op.GetName().c_str(), "paddings inputShape of 0 dims is unknown, set out shape unknown."); | |||
output_desc.SetShape(Shape(UNKNOWN_SHAPE)); | |||
return op.UpdateOutputDesc("y", output_desc); | |||
} | |||
Shape input_shape; | |||
if (WithRank(op.GetInputDesc(0), input_rank, input_shape, op.GetName().c_str()) != GRAPH_SUCCESS) { | |||
ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(input_rank)); | |||
return GRAPH_FAILED; | |||
} | |||
Shape check_shape({input_rank, 2}); | |||
if (Merge(paddings, check_shape, paddings, op.GetName().c_str()) != GRAPH_SUCCESS) { | |||
string err_msg = ConcatString("merge input[1] shape", DebugString(paddings.GetDims()), " and shape", | |||
DebugString(check_shape.GetDims()), " failed"); | |||
InferShapeOtherErrReport(op.GetName(), err_msg); | |||
OP_LOGE(op.GetName().c_str(), "Input dimension mismatch, inputRank=%lld.", input_rank); | |||
return GRAPH_FAILED; | |||
} | |||
Tensor paddings_tensor; | |||
if (op.GetInputConstData("paddings", paddings_tensor) != GRAPH_SUCCESS) { | |||
std::vector<int64_t> unknown_dim_vec(input_rank, UNKNOWN_DIM); | |||
OP_LOGE(op.GetName().c_str(), "Failed to get const data of input paddings, set output shape to unknown."); | |||
output_desc.SetShape(Shape(unknown_dim_vec)); | |||
return op.UpdateOutputDesc("y", output_desc); | |||
} | |||
std::vector<int64_t> output_dims(input_rank); | |||
auto result = CalcPadGradOutDims(input_shape, paddings_tensor, output_dims, op.GetName().c_str()); | |||
if (result != GRAPH_SUCCESS) { | |||
string err_msg = ConcatString("calculate out dims failed,", "please check the validity of input and attribute"); | |||
InferShapeOtherErrReport(op.GetName(), err_msg); | |||
OP_LOGE(op.GetName().c_str(), "Calculation PadGrad out dimensions failed."); | |||
return GRAPH_FAILED; | |||
} | |||
output_desc.SetShape(Shape(output_dims)); | |||
return op.UpdateOutputDesc("y", output_desc); | |||
} | |||
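// Worked example (illustrative only): PadGrad is the inverse of Pad, so for | |||
// input shape [4, 7] and constant paddings [[1, 1], [2, 2]] the output dims | |||
// are [4 - (1 + 1), 7 - (2 + 2)] = [2, 3], computed by CalcPadGradOutDims above. | |||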
} // namespace ge |
@@ -0,0 +1,42 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file array_ops_shape_fns.h | |||
* \brief | |||
*/ | |||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ | |||
#define OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ | |||
#include "graph/operator.h" | |||
namespace ge { | |||
/** | |||
* infer pad op shape | |||
* @param op Operator which needs to infer shape | |||
* @return status whether infershape succeeded | |||
*/ | |||
graphStatus PadShapeFn(Operator& op); | |||
/** | |||
* infer pad grad op shape | |||
* @param op Operator which needs to infer shape | |||
* @return status whether infershape succeeded | |||
*/ | |||
graphStatus PadGradShapeFn(Operator& op); | |||
} // namespace ge | |||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_ |
@@ -0,0 +1,195 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file axis_util.cpp | |||
* \brief get the axis value | |||
*/ | |||
#include "axis_util.h" | |||
#include "framework/omg/omg_inner_types.h" | |||
#include "framework/common/types.h" | |||
namespace ge { | |||
AxisUtil::AxisUtil() { | |||
getAxisValueFuncMap = {{FORMAT_NCHW, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNCHW)}, | |||
{FORMAT_NHWC, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNHWC)}, | |||
{FORMAT_NC1HWC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNC1HWC0)}, | |||
{FORMAT_HWCN, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByHWCN)}, | |||
{FORMAT_ND, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByND)}, | |||
{FORMAT_C1HWNCoC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByC1HWNCoC0)}}; | |||
} | |||
int64_t DivisionCeiling(int64_t dividend, int64_t divisor) { | |||
if (divisor == 0) { | |||
return 0; | |||
} else { | |||
return (dividend + divisor - 1) / divisor; | |||
} | |||
} | |||
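// Examples (illustrative only): DivisionCeiling(17, 16) == 2, | |||
// DivisionCeiling(16, 16) == 1, DivisionCeiling(0, 16) == 0, and | |||
// DivisionCeiling(x, 0) == 0 by the convention above. | |||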
bool AxisUtil::GetAxisValueByOriginFormat(const Format& format, const vector<int64_t>& dimVec, const uint32_t& c0, | |||
vector<int64_t>& axisValue, vector<int64_t>& ndValue) { | |||
auto iterGetAxisFunc = getAxisValueFuncMap.find(format); | |||
if (iterGetAxisFunc == getAxisValueFuncMap.end()) { | |||
LOG_INFO("Can not get axis value of old format %u!", format); | |||
return false; | |||
} | |||
GetAxisValueInfoByFormatPtr getAxisFunc = iterGetAxisFunc->second; | |||
CHECK_NOTNULL(getAxisFunc); | |||
return (*getAxisFunc)(dimVec, c0, axisValue, ndValue); | |||
} | |||
bool AxisUtil::HasAxisValueFunc(const Format& format) { | |||
auto iterGetAxisFunc = getAxisValueFuncMap.find(format); | |||
if (iterGetAxisFunc == getAxisValueFuncMap.end()) { | |||
LOG_INFO("Can not get axis value of format %u!", format); | |||
return false; | |||
} | |||
return true; | |||
} | |||
bool AxisUtil::CheckParams(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||
vector<int64_t>& ndValue) { | |||
ndValue = originalDimVec; | |||
auto dimSize = originalDimVec.size(); | |||
if (dimSize < ge::DIM_DEFAULT_SIZE) { | |||
/* PadDimensionTo4 should be called before this function. */ | |||
LOG_INFO("Dimension size %zu is invalid.", dimSize); | |||
return false; | |||
} | |||
if (c0 == 0) { | |||
LOG_ERROR("[ERROR]c0 is zero!"); | |||
return false; | |||
} | |||
return true; | |||
} | |||
bool AxisUtil::GetAxisValueByND(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||
vector<int64_t>& ndValue) { | |||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||
ndValue = originalDimVec; | |||
/* To differentiate the input datatype of int8 and others */ | |||
axisValue[AXIS_C0] = c0; | |||
if (originalDimVec.size() == NCHW_DIMENSION_NUM) { | |||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||
axisValue[AXIS_Co] = c0; | |||
} | |||
return true; | |||
} | |||
bool AxisUtil::GetAxisValueByNCHW(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||
vector<int64_t>& ndValue) { | |||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||
/* C0 Must be set for case ND or 2D-NCHW to NZ */ | |||
axisValue[AXIS_C0] = c0; | |||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||
return false); | |||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||
axisValue[AXIS_Co] = c0; | |||
return true; | |||
} | |||
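// Worked example (illustrative only): for originalDimVec = {2, 17, 5, 5} (NCHW) | |||
// and c0 = 16, axisValue becomes N = 2, C = 17, H = 5, W = 5, | |||
// C1 = DivisionCeiling(17, 16) = 2, C0 = Co = 16. | |||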
bool AxisUtil::GetAxisValueByNHWC(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||
vector<int64_t>& ndValue) { | |||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||
/* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||
axisValue[AXIS_C0] = c0; | |||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||
return false); | |||
axisValue[AXIS_N] = originalDimVec[AXIS_NHWC_DIM_N]; | |||
axisValue[AXIS_C] = originalDimVec[AXIS_NHWC_DIM_C]; | |||
axisValue[AXIS_H] = originalDimVec[AXIS_NHWC_DIM_H]; | |||
axisValue[AXIS_W] = originalDimVec[AXIS_NHWC_DIM_W]; | |||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NHWC_DIM_C], (int64_t)c0); | |||
axisValue[AXIS_Co] = c0; | |||
return true; | |||
} | |||
bool AxisUtil::GetAxisValueByNC1HWC0(const vector<int64_t>& originalDimVec, const uint32_t& c0, | |||
vector<int64_t>& axisValue, vector<int64_t>& ndValue) { | |||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||
return false); | |||
auto dimSize = originalDimVec.size(); | |||
if (dimSize == ge::DIM_DEFAULT_SIZE + 1) { | |||
axisValue[AXIS_C1] = originalDimVec[AXIS_NC1HWC0_DIM_C1]; | |||
axisValue[AXIS_C0] = originalDimVec[AXIS_NC1HWC0_DIM_C0]; | |||
axisValue[AXIS_C] = axisValue[AXIS_C1] * axisValue[AXIS_C0]; | |||
} else { | |||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||
axisValue[AXIS_C0] = c0; | |||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||
} | |||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||
return true; | |||
} | |||
bool AxisUtil::GetAxisValueByHWCN(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue, | |||
vector<int64_t>& ndValue) { | |||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||
/* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||
axisValue[AXIS_C0] = c0; | |||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||
return false); | |||
axisValue[AXIS_N] = originalDimVec[AXIS_HWCN_DIM_N]; | |||
axisValue[AXIS_C] = originalDimVec[AXIS_HWCN_DIM_C]; | |||
axisValue[AXIS_H] = originalDimVec[AXIS_HWCN_DIM_H]; | |||
axisValue[AXIS_W] = originalDimVec[AXIS_HWCN_DIM_W]; | |||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_HWCN_DIM_C], (int64_t)c0); | |||
axisValue[AXIS_Co] = c0; | |||
return true; | |||
} | |||
bool AxisUtil::GetAxisValueByC1HWNCoC0(const vector<int64_t>& originalDimVec, const uint32_t& c0, | |||
vector<int64_t>& axisValue, vector<int64_t>& ndValue) { | |||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true); | |||
/* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||
axisValue[AXIS_C0] = c0; | |||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"), | |||
return false); | |||
axisValue[AXIS_N] = originalDimVec[AXIS_C1HWNCoC0_DIM_N]; | |||
axisValue[AXIS_C] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1] * c0; | |||
axisValue[AXIS_H] = originalDimVec[AXIS_C1HWNCoC0_DIM_H]; | |||
axisValue[AXIS_W] = originalDimVec[AXIS_C1HWNCoC0_DIM_W]; | |||
axisValue[AXIS_C1] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1]; | |||
axisValue[AXIS_Co] = originalDimVec[AXIS_C1HWNCoC0_DIM_Co]; | |||
return true; | |||
} | |||
} // namespace ge |
@@ -0,0 +1,144 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file axis_util.h | |||
* \brief get the axis value | |||
*/ | |||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_ | |||
#define OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_ | |||
#include <map> | |||
#include <memory> | |||
#include <functional> | |||
#include <vector> | |||
#include "framework/omg/omg_inner_types.h" | |||
#include "operator.h" | |||
#include "graph/operator_reg.h" | |||
#include "op_log.h" | |||
#define LOG_ERROR(format, args...) printf(format, ##args) | |||
#define LOG_INFO(format, args...) printf(format, ##args) | |||
namespace ge { | |||
const uint32_t NCHW_DIMENSION_NUM = 4; | |||
const int32_t AXIS_NCHW_DIM_N = 0; | |||
const int32_t AXIS_NCHW_DIM_C = 1; | |||
const int32_t AXIS_NCHW_DIM_H = 2; | |||
const int32_t AXIS_NCHW_DIM_W = 3; | |||
const int32_t AXIS_NHWC_DIM_N = 0; | |||
const int32_t AXIS_NHWC_DIM_H = 1; | |||
const int32_t AXIS_NHWC_DIM_W = 2; | |||
const int32_t AXIS_NHWC_DIM_C = 3; | |||
const int32_t AXIS_NC1HWC0_DIM_N = 0; | |||
const int32_t AXIS_NC1HWC0_DIM_C1 = 1; | |||
const int32_t AXIS_NC1HWC0_DIM_C0 = 4; | |||
const int32_t AXIS_NC1HWC0_DIM_H = 2; | |||
const int32_t AXIS_NC1HWC0_DIM_W = 3; | |||
const int32_t AXIS_HWCN_DIM_H = 0; | |||
const int32_t AXIS_HWCN_DIM_W = 1; | |||
const int32_t AXIS_HWCN_DIM_C = 2; | |||
const int32_t AXIS_HWCN_DIM_N = 3; | |||
const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0; | |||
const int32_t AXIS_C1HWNCoC0_DIM_H = 1; | |||
const int32_t AXIS_C1HWNCoC0_DIM_W = 2; | |||
const int32_t AXIS_C1HWNCoC0_DIM_N = 3; | |||
const int32_t AXIS_C1HWNCoC0_DIM_Co = 4; | |||
const int32_t AXIS_C1HWNCoC0_DIM_C0 = 5; | |||
#define CHECK_NOTNULL(val) \ | |||
do { \ | |||
if ((val) == nullptr) { \ | |||
LOG_ERROR("[ERROR]Parameter[%s] must not be null.", #val); \ | |||
return false; \ | |||
} \ | |||
} while (0) | |||
#define CHECK(cond, log_func, return_expr) \ | |||
do { \ | |||
if (cond) { \ | |||
log_func; \ | |||
return_expr; \ | |||
} \ | |||
} while (0) | |||
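// Usage sketch (illustrative only): both macros expand to statements, so they | |||
// are meant for use inside functions (here, ones returning bool), e.g. | |||
//   CHECK_NOTNULL(funcPtr);                                  // returns false when null | |||
//   CHECK(dimVec.empty(), LOG_INFO("empty!"), return true);  // early exit on condition | |||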
enum AxisValueType { | |||
AXIS_N = 0, | |||
AXIS_C = 1, | |||
AXIS_H = 2, | |||
AXIS_W = 3, | |||
AXIS_C1 = 4, | |||
AXIS_C0 = 5, | |||
AXIS_Co = 6, | |||
AXIS_D = 7, | |||
AXIS_BOTTOM = 8 | |||
}; | |||
int64_t DivisionCeiling(int64_t dividend, int64_t divisor); | |||
/* Axis value is arranged as {N,C,H,W,C1,C0,...} */ | |||
/* The first parameter is old shape's dimension, | |||
* second is c0 and third is axis value. */ | |||
using GetAxisValueInfoByFormat = | |||
std::function<bool(const std::vector<int64_t>&, const uint32_t&, std::vector<int64_t>&, std::vector<int64_t>&)>; | |||
using GetAxisValueInfoByFormatPtr = std::shared_ptr<GetAxisValueInfoByFormat>; | |||
class AxisUtil { | |||
public: | |||
AxisUtil(); | |||
~AxisUtil() {} | |||
bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector<int64_t>& dimVec, const uint32_t& c0, | |||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||
bool HasAxisValueFunc(const ge::Format& format); | |||
private: | |||
static bool CheckParams(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||
static bool GetAxisValueByNCHW(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||
static bool GetAxisValueByNHWC(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||
static bool GetAxisValueByNC1HWC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||
static bool GetAxisValueByFz(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||
static bool GetAxisValueByHWCN(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||
static bool GetAxisValueByND(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||
static bool GetAxisValueByC1HWNCoC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||
/* map of GetAxisValueInfoByFormat, get axis value by different original | |||
* formats. */ | |||
std::map<ge::Format, GetAxisValueInfoByFormatPtr> getAxisValueFuncMap; | |||
}; | |||
} // namespace ge | |||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_ |
@@ -0,0 +1,417 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file common_shape_fns.h | |||
* \brief | |||
*/ | |||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ | |||
#define OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ | |||
#include <string> | |||
#include <vector> | |||
#include "graph/tensor.h" | |||
#include "graph/operator.h" | |||
#include "graph/op_desc.h" | |||
#include "graph/ge_tensor.h" | |||
#include "error_code.h" | |||
namespace ge { | |||
/** | |||
* Check whether Shape's rank is at least rank | |||
* @param tensor Input tensor | |||
* @param rank expect val of Shape | |||
* @param out Output Shape | |||
* @return status whether Shape's condition is satisfied | |||
*/ | |||
graphStatus WithRankAtLeast(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name); | |||
/** | |||
* Check whether Shape's rank is at least rank | |||
* @param tensor Input tensor | |||
* @param rank expect val of Shape | |||
* @param out Output Shape | |||
* @return status whether Shape's condition is satisfied | |||
*/ | |||
graphStatus WithRankAtLeast(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape); | |||
/** | |||
* Check whether Shape's rank is equal to rank | |||
* @param tensor Input tensor | |||
* @param rank expect val of Shape | |||
* @param out Output Shape | |||
* @return status whether Shape's condition is satisfied | |||
*/ | |||
graphStatus WithRank(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name); | |||
/** | |||
* Check whether Shape's rank is equal to rank | |||
* @param tensor Input tensor | |||
* @param rank expect val of Shape | |||
* @param out Output Shape | |||
* @return status whether Shape's condition is satisfied | |||
*/ | |||
graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape); | |||
/** | |||
* Check whether Shape's rank is equal to rank | |||
* @param tensor Input tensor | |||
* @param rank expect val of Shape | |||
* @param out Output Shape | |||
* @return status whether Shape's condition is satisfied | |||
*/ | |||
graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, Shape& out_shape); | |||
/** | |||
* Check whether dim is equal to value | |||
* @param dim Input dim | |||
* @param value expect val of dim | |||
* @param out Output dim | |||
* @return status whether Dim is equal to value | |||
*/ | |||
graphStatus WithValue(int64_t dim, int64_t value, int64_t& out, const char* op_name); | |||
/** | |||
* Merge two dims of Shape | |||
* @param dim1 first dim val | |||
* @param dim2 second dim val | |||
* @param out merged dim val | |||
* @return status whether this operation succeeded | |||
*/ | |||
graphStatus Merge(int64_t dim1, int64_t dim2, int64_t& out); | |||
/** | |||
* Merge two shapes | |||
* @param s0 first shape val | |||
* @param s1 second shape val | |||
* @param out merged shape val | |||
* @return status whether this operation succeeded | |||
*/ | |||
graphStatus Merge(const Shape& s0, const Shape& s1, Shape& out, const char* op_name); | |||
/** | |||
* Merge two shapes | |||
* @param s0 first GeShape val | |||
* @param s1 second GeShape val | |||
* @param out merged GeShape val | |||
* @return status whether this operation succeeded | |||
*/ | |||
graphStatus Merge(const GeShape& s0, const GeShape& s1, GeShape& out, const char* op_name); | |||
/** | |||
* Replace one dim in a given shape | |||
* @param s original shape | |||
* @param dim_index_in dim index | |||
* @param new_dim new dim value | |||
* @param out new shape | |||
* @return status whether this operation succeeded | |||
*/ | |||
graphStatus ReplaceDim(const Shape& s, int64_t dim_index_in, int64_t new_dim, Shape& out, const char* op_name); | |||
/** | |||
* Replace one dim in a given shape | |||
* @param s original shape | |||
* @param dim_index_in dim index | |||
* @param new_dim new dim value | |||
* @param out new shape | |||
* @return status whether this operation succeeded | |||
*/ | |||
graphStatus ReplaceDim(const GeShape& s, int64_t dim_index_in, int64_t new_dim, GeShape& out, const char* op_name); | |||
/** | |||
* Check if it satisfies 0 <= index < limit | |||
* @param index first input | |||
* @param limit second input | |||
* @return status whether this operation succeeded | |||
*/ | |||
template <typename Ta, typename Tb> | |||
bool FastBoundsCheck(const Ta index, const Tb limit); | |||
/** | |||
* Add two dims | |||
* @param dim1 first dim val | |||
* @param dim2 second dim val | |||
* @param out sum dim val | |||
* @return status whether this operation succeeded | |||
*/ | |||
graphStatus Add(int64_t dim1, int64_t dim2, int64_t& out); | |||
/** | |||
* Subtract two dims | |||
* @param dim1 first dim val | |||
* @param dim2 second dim val | |||
* @param out difference dim val | |||
* @return status whether this operation succeeded | |||
*/ | |||
graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t& out, const char* op_name); | |||
/** | |||
* Get SubShape according to start/end indices and step size stride | |||
* @param s input Shape | |||
* @param start sub start index | |||
* @param end sub end index | |||
* @param stride sub step size | |||
* @param out sub shape output | |||
* @return status whether this operation succeeded | |||
*/ | |||
graphStatus SubShape(const Shape& s, int64_t start, int64_t end, int64_t stride, Shape& out, const char* op_name); | |||
/** | |||
* Get SubShape according to start/end indices and step size stride | |||
* @param s input Shape | |||
* @param start sub start index | |||
* @param end sub end index | |||
* @param stride sub step size | |||
* @param out sub shape output | |||
* @return status whether this operation succeeded | |||
*/ | |||
graphStatus SubShape(const GeShape& s, size_t start, size_t end, size_t stride, GeShape& out); | |||
/** | |||
* Get SubShape according to start/end indices and step size stride | |||
* @param s input Shape | |||
* @param start sub start index | |||
* @param end sub end index | |||
* @param stride sub step size | |||
* @param out sub shape output | |||
* @return status whether this operation succeeded | |||
*/ | |||
graphStatus SubShape(const GeShape& s, int64_t start, int64_t end, int64_t stride, GeShape& out, const char* op_name); | |||
/**
 * Concatenate two shapes
 * @param s1 first shape
 * @param s2 second shape
 * @param out concatenated shape
 * @return status whether this operation succeeded
 */
graphStatus Concatenate(const Shape& s1, const Shape& s2, Shape& out);
/**
 * Concatenate two GeShapes
 * @param s1 first shape
 * @param s2 second shape
 * @param out concatenated shape
 * @return status whether this operation succeeded
 */
graphStatus Concatenate(const GeShape& s1, const GeShape& s2, GeShape& out);
/**
 * Generate a matrix shape from dim1 and dim2
 * @param dim1 first dim val
 * @param dim2 second dim val
 * @param out matrix shape
 * @return status whether this operation succeeded
 */
graphStatus Matrix(int64_t dim1, int64_t dim2, Shape& out);
/**
 * Generate a vector shape from dim
 * @param dim dim val
 * @param out vector shape
 * @return status whether this operation succeeded
 */
graphStatus Vector(int64_t dim, Shape& out);
/**
 * Make shape from shape tensor
 * @param tensor shape tensor
 * @param out shape
 * @return status whether this operation succeeded
 */
graphStatus MakeShapeFromShapeTensor(const Tensor& tensor, Shape& out, const char* op_name);
/**
 * Make shape from a shape tensor held by a named input
 * @param op Operator
 * @param dst_name name of the input holding the shape tensor
 * @param out output GeShape
 * @param op_name operator name for logging
 * @return status whether this operation succeeded
 */
graphStatus MakeShapeFromShapeTensor(Operator& op, const string& dst_name, GeShape& out, const char* op_name);
/**
 * Make dim from scalar tensor
 * @param tensor scalar tensor
 * @param out dim value
 * @return status whether this operation succeeded
 */
graphStatus MakeDimForScalarInput(const Tensor& tensor, int64_t& out, const char* op_name);
/**
 * Check whether the shape's rank is at most rank
 * @param tensor input tensor desc
 * @param rank expected upper bound of the rank
 * @param out output Shape
 * @return status whether the shape satisfies the condition
 */
graphStatus WithRankAtMost(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name);
/**
 * Check whether the shape's rank is at most rank
 * @param tensorDesc input tensor desc
 * @param rank expected upper bound of the rank
 * @param out_shape output GeShape
 * @return status whether the shape satisfies the condition
 */
graphStatus WithRankAtMost(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape);
/**
 * Make an empty (rank-0) shape
 * @param out output Shape
 * @return status whether this operation succeeded
 */
graphStatus Scalar(Shape& out);
/**
 * Set input_name shape to output_name shape
 * @param op Operator which needs to infer shape
 * @param input_name input name of Operator
 * @param output_name output name of Operator
 * @return status whether infershape succeeded
 */
graphStatus UnchangedShape(Operator& op, const string input_name, const string output_name);
/**
 * Divide dim
 * @param dividend dim value to divide
 * @param divisor value to divide by
 * @param evenlyDivisible whether the division must be exact
 * @param out quotient dim
 * @return status whether this operation succeeded
 */
graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t& out,
                   const char* op_name);
/** | |||
* check shape fully defined or not | |||
* @param shape Shape is checked | |||
* @return whether shape is fully defined | |||
*/ | |||
bool ShapeFullDefined(const Shape& shape); | |||
/** | |||
* check shape fully defined or not | |||
* @param shape Shape is checked | |||
* @return whether shape is fully defined | |||
*/ | |||
bool ShapeFullyDefined(const GeShape& shape); | |||
/** | |||
* check shape known or not | |||
* @param shape Shape is checked | |||
* @return whether rank is known | |||
*/ | |||
bool RankKnown(const Shape& shape); | |||
/** | |||
* check ge_shape known or not | |||
* @param shape GeShape is checked | |||
* @return whether rank is known | |||
*/ | |||
bool RankKnown(const GeShape& shape); | |||
/**
 * Make an unknown shape with the given rank
 * @param rank rank of the shape
 * @return unknown shape
 */
Shape UnknownShapeOfRank(int64_t rank);
/**
 * Check whether the dim value is known
 * @param shape Shape whose dim value is checked
 * @param dim_index the index of the dim
 * @return whether the dim value is known
 */
bool ValueKnown(const Shape& shape, const size_t& dim_index);
/**
 * Validate that the three component tensors of a sparse tensor
 * have the proper shapes.
 * @param indices sparse indices tensor desc
 * @param values sparse values tensor desc
 * @param shape sparse shape tensor desc
 * @return status whether this operation succeeded
 */
graphStatus ValidateSparseTensor(const TensorDesc& indices, const TensorDesc& values, const TensorDesc& shape,
                                 const char* op_name);
/**
 * DecodeWavShapeFn, infer-shape function of the DecodeWav op
 * @param op Operator
 * @return status whether the shape inference succeeded
 */
graphStatus DecodeWavShapeFn(Operator& op);
/**
 * EncodeWavShapeFn, infer-shape function of the EncodeWav op
 * @param op Operator
 * @return status whether the shape inference succeeded
 */
graphStatus EncodeWavShapeFn(Operator& op);
/**
 * Infer-shape function of the SparseSegmentReduction op
 * @param op Operator
 * @return status whether the shape inference succeeded
 */
graphStatus SparseSegmentReductionShapeFn(Operator& op);
/**
 * Infer-shape function of the SparseSegmentReductionGrad op
 * @param op Operator
 * @return status whether the shape inference succeeded
 */
graphStatus SparseSegmentReductionGradShapeFn(Operator& op);
/**
 * Validate variable resource handle
 * @param op Operator
 * @param shape_and_type ShapeAndType vector
 * @return status whether this operation succeeded
 */
graphStatus ValidateVariableResourceHandle(Operator& op, std::vector<ShapeAndType>& shape_and_type);
/**
 * Fill tensor desc with shape and datatype
 * @param op_desc tensor desc ptr to fill
 * @param shape input tensor shape
 * @param data_type input tensor datatype
 */
void FillOpDesc(GeTensorDescPtr& op_desc, const GeShape& shape, const DataType& data_type = DT_FLOAT);
/**
 * Report infer-shape error info
 * @param op_name Operator name
 * @param op_type Operator type
 * @param value Operator value
 * @param reason error reason
 */
void InferShapeErrorReport(const std::string& op_name, const std::string& op_type,
                           const std::string& value, const std::string& reason);
} // namespace ge | |||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_ |
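To show how these helpers are typically composed, here is a minimal sketch of an infer-shape function for a hypothetical element-wise op; the op type and tensor names ("x1", "x2", "y") are illustrative only and are not part of the header above.
// Minimal sketch, assuming a binary op whose inputs must have mergeable
// shapes. Merge() fails when two known dims conflict, which is exactly the
// error this function wants to surface.
graphStatus TwoInputsSameShapeInfer(Operator& op) {
  auto x_desc = op.GetInputDesc("x1");
  auto y_desc = op.GetInputDesc("x2");
  Shape merged;
  if (Merge(x_desc.GetShape(), y_desc.GetShape(), merged, "TwoInputsSameShape") != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  TensorDesc out_desc = op.GetOutputDesc("y");
  out_desc.SetShape(merged);
  out_desc.SetDataType(x_desc.GetDataType());
  return op.UpdateOutputDesc("y", out_desc);
}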
@@ -0,0 +1,60 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file error_code.h | |||
* \brief | |||
*/ | |||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_ | |||
#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_ | |||
namespace ge { | |||
// Error codes for report purposes:
// 30000~34999 are reserved for AI CPU engine errors,
// 35000~39999 for infer-shape errors of AI CPU ops;
// other values map to framework-defined ranges.
enum ViewErrorCode { | |||
INVALID_INFER_SHAPE = 14001, | |||
INVALID_INPUT_SHAPE = 35000, | |||
INVALID_ATTR_VALUE = 35001, | |||
INVALID_ATTR_SIZE = 35002, | |||
OTHER_ERROR = 35003, | |||
INVALID_CONV_ATTR_VALUE = 50029, | |||
INVALID_CONV_SET_ATTR = 50057, | |||
INVALID_CONV_SHAPE = 50058, | |||
INVALID_MISS_INPUT = 70001, | |||
INVALID_INPUT_FORMAT = 70002, | |||
INVALID_INPUT_DTYPE = 70003, | |||
INVALID_INPUT_TYPE = 70004, | |||
INVALID_GET_ATTR = 70005, | |||
INVALID_SET_ATTR = 70006, | |||
INVALID_OPS_ATTR_VALUE = 70007, | |||
FAILED_UPDATE_OP = 70008, | |||
INVALID_SHAPE = 70009, | |||
INVALID_SHAPE_SIZE = 70010, | |||
INVALID_SHAPE_DIM = 70011, | |||
INVALID_BROADCAST_SHAPE = 70012, | |||
INVALID_TWO_INPUT_DTYPE = 70013, | |||
INVALID_AIPP_ERROR = 70014, | |||
INVALID_ONE_INPUT_SHAPE = 70015, | |||
INVALID_TWO_INPUT_SHAPE = 70016, | |||
INVALID_ONE_OUTPUT_SHAPE = 70017, | |||
FAILED_GET_COMPILIE_PARAMS = 70018, | |||
}; | |||
} // namespace ge | |||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_ |
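For reference, these numeric values become the reporting code strings by prefixing "E" (the GetViewErrorCodeStr helper in error_util.cpp below does exactly this):
// Illustrative only: INVALID_INPUT_SHAPE (35000) is reported as "E35000".
std::string code = "E" + std::to_string(static_cast<int>(ge::INVALID_INPUT_SHAPE));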
@@ -0,0 +1,318 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file error_util.cpp | |||
* \brief | |||
*/ | |||
#include <map> | |||
#include "common/util/error_manager/error_manager.h" | |||
#include "error_util.h" | |||
#include "error_code.h" | |||
#include "op_log.h" | |||
using namespace std;
namespace ge { | |||
inline static std::string GetViewErrorCodeStr(ge::ViewErrorCode errCode) { | |||
return "E" + std::to_string(errCode); | |||
} | |||
void ShapeErrReport(uint32_t index, const std::string& opname, const std::string& wrong_shape, | |||
const std::string& correct_shape) { | |||
map<string, string> err_map; | |||
err_map["index"] = std::to_string(index); | |||
err_map["opname"] = opname; | |||
err_map["wrong_shape"] = wrong_shape; | |||
err_map["correct_shape"] = correct_shape; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_SHAPE); | |||
(void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void AttrValueErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_value, | |||
const std::string& correct_value) { | |||
map<string, string> err_map; | |||
err_map["attrname"] = attrName; | |||
err_map["opname"] = opname; | |||
err_map["wrong_value"] = wrong_value; | |||
err_map["correct_value"] = correct_value; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ATTR_VALUE); | |||
(void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void AttrSizeErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_size, | |||
const std::string& correct_size) { | |||
map<string, string> err_map; | |||
err_map["attrname"] = attrName; | |||
err_map["opname"] = opname; | |||
err_map["wrong_size"] = wrong_size; | |||
err_map["correct_size"] = correct_size; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ATTR_SIZE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void InferShapeOtherErrReport(const std::string& opname, const std::string& err_msg) { | |||
map<string, string> err_map; | |||
err_map["opname"] = opname; | |||
err_map["err_msg"] = err_msg; | |||
string report_error_code = GetViewErrorCodeStr(ViewErrorCode::OTHER_ERROR); | |||
(void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsMissInputErrReport(const std::string& op_name, const std::string& param_name) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_MISS_INPUT); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsInputFormatErrReport(const std::string& op_name, const std::string& param_name, | |||
const std::string& expected_format_list, const std::string& data_format) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
err_map["expected_format_list"] = expected_format_list; | |||
err_map["format"] = data_format; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_FORMAT); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsInputDtypeErrReport(const std::string& op_name, const std::string& param_name, | |||
const std::string& expected_data_type_list, const std::string& data_type) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
err_map["expected_data_type_list"] = expected_data_type_list; | |||
err_map["data_type"] = data_type; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_DTYPE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsInputTypeErrReport(const std::string& op_name, const std::string& param_name, const std::string& param_type, | |||
const std::string& actual_type) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
err_map["param_type"] = param_type; | |||
err_map["actual_type"] = actual_type; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_TYPE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsGetAttrErrReport(const std::string& op_name, const std::string& param_name) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_GET_ATTR); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsSetAttrErrReport(const std::string& op_name, const std::string& param_name) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SET_ATTR); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& excepted_value, | |||
const std::string& input_value) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
err_map["excepted_value"] = excepted_value; | |||
err_map["input_value"] = input_value; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_OPS_ATTR_VALUE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsOPUpdateErrReport(const std::string& op_name, const std::string& param_name) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::FAILED_UPDATE_OP); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsInputShapeErrReport(const std::string& op_name, const std::string& rule_desc, const std::string& param_name, | |||
const std::string& param_value) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["rule_desc"] = rule_desc; | |||
err_map["param_name"] = param_name; | |||
err_map["param_value"] = param_value; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsOneInputShapeErrReport(const std::string& op_name, const std::string& param_name, | |||
const std::string& error_detail) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
err_map["error_detail"] = error_detail; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ONE_INPUT_SHAPE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsTwoInputShapeErrReport(const std::string& op_name, const std::string& param_name1, | |||
const std::string& param_name2, const std::string& error_detail) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name1"] = param_name1; | |||
err_map["param_name2"] = param_name2; | |||
err_map["error_detail"] = error_detail; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_TWO_INPUT_SHAPE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsOneOutputShapeErrReport(const std::string& op_name, const std::string& param_name, | |||
const std::string& error_detail) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
err_map["error_detail"] = error_detail; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ONE_OUTPUT_SHAPE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsGetCompileParamsErrReport(const std::string& op_name, const std::string& param_name) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::FAILED_GET_COMPILIE_PARAMS); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsInputShapeSizeErrReport(const std::string& op_name, const std::string& input_name, const std::string& max_value, | |||
const std::string& real_value) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["input_name"] = input_name; | |||
err_map["max_value"] = max_value; | |||
err_map["real_value"] = real_value; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE_SIZE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsInputShapeDimErrReport(const std::string& op_name, const std::string& param_name, const std::string& max_value, | |||
const std::string& min_value, const std::string& real_value) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
err_map["max_value"] = max_value; | |||
err_map["min_value"] = min_value; | |||
err_map["real_value"] = real_value; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE_DIM); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsInputShapeBroadcastErrReport(const std::string& op_name, const std::string& input1_name, | |||
const std::string& input2_name, const std::string& input1_shape, | |||
const std::string& input2_shape) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["input1_name"] = input1_name; | |||
err_map["input2_name"] = input2_name; | |||
err_map["input1_shape"] = input1_shape; | |||
err_map["input2_shape"] = input2_shape; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_BROADCAST_SHAPE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void TbeInputDataTypeErrReport(const std::string& op_name, const std::string& param_name, | |||
const std::string& expected_dtype_list, const std::string& dtype) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
err_map["expected_dtype_list"] = expected_dtype_list; | |||
err_map["dtype"] = dtype; | |||
std::string report_error_code = "E50034"; | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsTwoInputDtypeErrReport(const std::string& op_name, const std::string& input1_name, | |||
const std::string& input2_name, const std::string& input1_dtype, | |||
const std::string& input2_dtype) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["input1_name"] = input1_name; | |||
err_map["input2_name"] = input2_name; | |||
err_map["input1_dtype"] = input1_dtype; | |||
err_map["input2_dtype"] = input2_dtype; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_TWO_INPUT_DTYPE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsAippErrReport(const std::string& aipp_output_H, const std::string& aipp_output_W, const std::string& data_H, | |||
const std::string& data_W) { | |||
map<string, string> err_map; | |||
err_map["aipp_output_H"] = aipp_output_H; | |||
err_map["aipp_output_W"] = aipp_output_W; | |||
err_map["data_H"] = data_H; | |||
err_map["data_W"] = data_W; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_AIPP_ERROR); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsConvAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& expected_value, | |||
const std::string& input_value) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param_name"] = param_name; | |||
err_map["expected_value"] = expected_value; | |||
err_map["input_value"] = input_value; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_ATTR_VALUE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsConvSetAttrErrReport(const std::string& op_name, const std::string& param1_name, | |||
const std::string& param2_name) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["param1_name"] = param1_name; | |||
err_map["param2_name"] = param2_name; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_SET_ATTR); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void OpsConvShapeErrReport(const std::string& op_name, const std::string& description) { | |||
map<string, string> err_map; | |||
err_map["op_name"] = op_name; | |||
err_map["description"] = description; | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_SHAPE); | |||
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map); | |||
} | |||
void GeInfershapeErrReport(const std::string& op_name, const std::string& op_type, const std::string& value, | |||
const std::string& reason) { | |||
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INFER_SHAPE); | |||
ErrorManager::GetInstance().ATCReportErrMessage(report_error_code, {"opname", "optype", "value", "reason"}, | |||
{op_name, op_type, value, reason}); | |||
} | |||
void CommonRuntimeErrLog(const std::string& opname, const std::string& description) {
  map<string, string> err_map;
  err_map["op_name"] = opname;
  err_map["description"] = description;
  // Pass the message through a "%s" format; handing a std::string to the
  // format position would log the literal token instead of its value.
  OP_LOGE(opname.c_str(), "%s", description.c_str());
  (void)ErrorManager::GetInstance().ReportErrMessage("E50058", err_map);
}
} // namespace ge |
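A typical call site looks like the following sketch; the op type "Transpose" and attribute "perm" are illustrative names, not taken from this file.
// Hypothetical usage: report a missing attribute during shape inference.
ge::graphStatus CheckPermAttr(ge::Operator& op) {
  std::vector<int64_t> perm;
  if (op.GetAttr("perm", perm) != ge::GRAPH_SUCCESS) {
    // Emits code E70005 with the op/param fields filled into err_map.
    ge::OpsGetAttrErrReport("Transpose", "perm");
    return ge::GRAPH_FAILED;
  }
  return ge::GRAPH_SUCCESS;
}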
@@ -0,0 +1,184 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file error_util.h | |||
* \brief | |||
*/ | |||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_ | |||
#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_ | |||
#include <sstream> | |||
#include <string> | |||
#include <vector> | |||
#include "operator.h" | |||
namespace ge { | |||
/* | |||
* get debug string of vector | |||
* param[in] v vector | |||
* return vector's debug string | |||
*/ | |||
template <typename T> | |||
std::string DebugString(const std::vector<T>& v) { | |||
std::ostringstream oss; | |||
oss << "["; | |||
if (v.size() > 0) { | |||
for (size_t i = 0; i < v.size() - 1; ++i) { | |||
oss << v[i] << ", "; | |||
} | |||
oss << v[v.size() - 1]; | |||
} | |||
oss << "]"; | |||
return oss.str(); | |||
} | |||
/*
 * string concatenation utility
 * param[in] arg values to concatenate
 * return concatenated string
 */
template <typename T> | |||
std::string ConcatString(T arg) { | |||
std::ostringstream oss; | |||
oss << arg; | |||
return oss.str(); | |||
} | |||
template <typename T, typename... Ts> | |||
std::string ConcatString(T arg, Ts... arg_left) { | |||
std::ostringstream oss; | |||
oss << arg; | |||
oss << ConcatString(arg_left...); | |||
return oss.str(); | |||
} | |||
/* | |||
* report input shape error of infer shape | |||
* param[in] index the index of input | |||
* param[in] opname op name | |||
* param[in] wrong_shape wrong input shape | |||
* param[in] correct_shape correct input shape | |||
* return void | |||
*/ | |||
void ShapeErrReport(uint32_t index, const std::string& opname, const std::string& wrong_shape, | |||
const std::string& correct_shape); | |||
/* | |||
* report attr value error of infer shape | |||
* param[in] attrname the attr name | |||
* param[in] opname op name | |||
* param[in] wrong_value wrong attr value | |||
* param[in] correct_value correct attr value | |||
* return void | |||
*/ | |||
void AttrValueErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_value, | |||
const std::string& correct_value); | |||
/* | |||
* report attr size error of infer shape | |||
* param[in] attrname the attr name | |||
* param[in] opname op name | |||
* param[in] wrong_size wrong attr size | |||
* param[in] correct_size correct attr size | |||
* return void | |||
*/ | |||
void AttrSizeErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_size, | |||
const std::string& correct_size); | |||
/* | |||
* report common error of infer shape | |||
* param[in] opname op name | |||
* param[in] err_msg error message | |||
* return void | |||
*/ | |||
void InferShapeOtherErrReport(const std::string& opname, const std::string& err_msg); | |||
void OpsMissInputErrReport(const std::string& op_name, const std::string& param_name); | |||
void OpsInputFormatErrReport(const std::string& op_name, const std::string& param_name, | |||
const std::string& expected_format_list, const std::string& data_format); | |||
void OpsInputDtypeErrReport(const std::string& op_name, const std::string& param_name, | |||
const std::string& expected_data_type_list, const std::string& data_type); | |||
void OpsInputTypeErrReport(const std::string& op_name, const std::string& param_name, const std::string& param_type, | |||
const std::string& actual_type); | |||
void OpsGetAttrErrReport(const std::string& op_name, const std::string& param_name); | |||
void OpsSetAttrErrReport(const std::string& op_name, const std::string& param_name); | |||
void OpsAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& excepted_value, | |||
const std::string& input_value); | |||
void OpsOPUpdateErrReport(const std::string& op_name, const std::string& param_name); | |||
void OpsInputShapeErrReport(const std::string& op_name, const std::string& rule_desc, const std::string& param_name, | |||
const std::string& param_value); | |||
void OpsOneInputShapeErrReport(const std::string& op_name, const std::string& param_name, | |||
const std::string& error_detail); | |||
void OpsTwoInputShapeErrReport(const std::string& op_name, const std::string& param_name1, | |||
const std::string& param_name2, const std::string& error_detail); | |||
void OpsOneOutputShapeErrReport(const std::string& op_name, const std::string& param_name, | |||
const std::string& error_detail); | |||
void OpsGetCompileParamsErrReport(const std::string& op_name, const std::string& param_name); | |||
void OpsInputShapeSizeErrReport(const std::string& op_name, const std::string& input_name, const std::string& max_value, | |||
const std::string& real_value); | |||
void OpsInputShapeDimErrReport(const std::string& op_name, const std::string& param_name, const std::string& max_value, | |||
const std::string& min_value, const std::string& real_value); | |||
void OpsInputShapeBroadcastErrReport(const std::string& op_name, const std::string& input1_name, | |||
const std::string& input2_name, const std::string& input1_shape, | |||
const std::string& input2_shape); | |||
void TbeInputDataTypeErrReport(const std::string& op_name, const std::string& param_name, | |||
const std::string& expected_dtype_list, const std::string& dtype); | |||
void OpsTwoInputDtypeErrReport(const std::string& op_name, const std::string& input1_name, | |||
const std::string& input2_name, const std::string& input1_dtype, | |||
const std::string& input2_dtype); | |||
void OpsAippErrReport(const std::string& aipp_output_H, const std::string& aipp_output_W, const std::string& data_H, | |||
const std::string& data_W); | |||
void OpsConvAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& expected_value, | |||
const std::string& input_value); | |||
void OpsConvSetAttrErrReport(const std::string& op_name, const std::string& param1_name, | |||
const std::string& param2_name); | |||
void OpsConvShapeErrReport(const std::string& op_name, const std::string& description); | |||
void GeInfershapeErrReport(const std::string& op_name, const std::string& op_type, const std::string& value, | |||
const std::string& reason); | |||
/* | |||
* log common runtime error | |||
* param[in] opname op name | |||
* param[in] error description | |||
* return void | |||
*/ | |||
void CommonRuntimeErrLog(const std::string& opname, const std::string& description); | |||
} // namespace ge | |||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_ |
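As a quick illustration of the two helpers at the top of this header:
// Illustrative only.
std::vector<int64_t> dims{2, 3, 5};
std::string msg = ge::ConcatString("expected rank 4, got shape ", ge::DebugString(dims));
// msg == "expected rank 4, got shape [2, 3, 5]"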
@@ -0,0 +1,73 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file op_common_util.h | |||
* \brief common util for op, in this file only original type or class in C++ allowed | |||
*/ | |||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_ | |||
#define OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_ | |||
#include <set> | |||
#include <string> | |||
#include <vector> | |||
#include <iostream> | |||
#include <sstream> | |||
template <typename T1, typename T2> | |||
std::ostream& operator<<(std::ostream& os, const std::pair<T1, T2>& values) { | |||
os << "[" << values.first << ", " << values.second << "]"; | |||
return os; | |||
} | |||
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
  os << "[";
  for (size_t i = 0; i < values.size(); ++i) {
    if (i > 0) {
      os << ", ";  // separator between items only, no trailing comma
    }
    os << values[i];
  }
  os << "]";
  return os;
}
namespace ops {
template <typename T>
std::string to_string(const std::vector<T>& items) {
  std::ostringstream oss;
  oss << "[";
  bool first = true;
  for (const auto& item : items) {
    if (!first) {
      oss << ", ";
    }
    oss << item;
    first = false;
  }
  oss << "]";
  return oss.str();
}
template <typename T>
std::string to_string(const std::set<T>& items) {
  std::ostringstream oss;
  oss << "[";
  bool first = true;
  for (const auto& item : items) {
    if (!first) {
      oss << ", ";
    }
    oss << item;
    first = false;
  }
  oss << "]";
  return oss.str();
}
}  // namespace ops
#endif //OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_ |
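A small usage sketch of the stream inserters and ops::to_string defined above:
// Illustrative only.
int main() {
  std::set<int> axes{0, 2};
  std::vector<std::pair<int, int>> ranges{{1, 8}, {2, 16}};
  std::cout << "axes=" << ops::to_string(axes) << std::endl;   // axes=[0, 2]
  std::cout << "ranges=" << ranges << std::endl;               // ranges=[[1, 8], [2, 16]]
  return 0;
}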
@@ -0,0 +1,89 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file op_log.h | |||
* \brief | |||
*/ | |||
#ifndef GE_OP_LOG_H
#define GE_OP_LOG_H
#include <type_traits>  // needed by OP_CHECK's static_assert below
#if !defined(__ANDROID__) && !defined(ANDROID)
#include "toolchain/slog.h"
#else
#include <utils/Log.h>
#endif
#define OPPROTO_SUBMOD_NAME "OP_PROTO" | |||
#if !defined(__ANDROID__) && !defined(ANDROID)
#define OP_LOGI(opname, ...) D_OP_LOGI(opname, __VA_ARGS__) | |||
#define OP_LOGW(opname, ...) D_OP_LOGW(opname, __VA_ARGS__) | |||
#define OP_LOGE(opname, ...) D_OP_LOGE(opname, __VA_ARGS__) | |||
#define OP_LOGD(opname, ...) D_OP_LOGD(opname, __VA_ARGS__) | |||
#define GE_OP_LOGI(opname, ...) GE_D_OP_LOGI(opname, __VA_ARGS__) | |||
#define GE_OP_LOGW(opname, ...) GE_D_OP_LOGW(opname, __VA_ARGS__) | |||
#define GE_OP_LOGE(opname, ...) GE_D_OP_LOGE(opname, __VA_ARGS__) | |||
#define GE_OP_LOGD(opname, ...) GE_D_OP_LOGD(opname, __VA_ARGS__) | |||
#define FUSION_PASS_LOGI(...) D_FUSION_PASS_LOGI(__VA_ARGS__) | |||
#define FUSION_PASS_LOGW(...) D_FUSION_PASS_LOGW(__VA_ARGS__) | |||
#define FUSION_PASS_LOGE(...) D_FUSION_PASS_LOGE(__VA_ARGS__) | |||
#define FUSION_PASS_LOGD(...) D_FUSION_PASS_LOGD(__VA_ARGS__) | |||
#else | |||
#define OP_LOGI(opname, ...) | |||
#define OP_LOGW(opname, ...) | |||
#define OP_LOGE(opname, ...) | |||
#define OP_LOGD(opname, ...) | |||
#define FUSION_PASS_LOGI(...) | |||
#define FUSION_PASS_LOGW(...) | |||
#define FUSION_PASS_LOGE(...) | |||
#define FUSION_PASS_LOGD(...) | |||
#endif | |||
#if !defined(__ANDROID__) && !defined(ANDROID)
#define D_OP_LOGI(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||
#define D_OP_LOGW(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||
#define D_OP_LOGE(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||
#define D_OP_LOGD(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||
#define GE_D_OP_LOGI(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||
#define GE_D_OP_LOGW(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||
#define GE_D_OP_LOGE(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||
#define GE_D_OP_LOGD(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__) | |||
#define D_FUSION_PASS_LOGI(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||
#define D_FUSION_PASS_LOGW(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||
#define D_FUSION_PASS_LOGE(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||
#define D_FUSION_PASS_LOGD(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||
#else | |||
#define D_OP_LOGI(opname, fmt, ...) | |||
#define D_OP_LOGW(opname, fmt, ...) | |||
#define D_OP_LOGE(opname, fmt, ...) | |||
#define D_OP_LOGD(opname, fmt, ...) | |||
#define D_FUSION_PASS_LOGI(fmt, ...) | |||
#define D_FUSION_PASS_LOGW(fmt, ...) | |||
#define D_FUSION_PASS_LOGE(fmt, ...) | |||
#define D_FUSION_PASS_LOGD(fmt, ...) | |||
#endif | |||
#define OP_CHECK(condition, log_func, do_expr) \ | |||
static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \ | |||
do { \ | |||
if (condition) { \ | |||
log_func; \ | |||
do_expr; \ | |||
} \ | |||
} while (0) | |||
#endif //GE_OP_LOG_H |
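A sketch of how OP_CHECK is meant to be used inside an infer-shape function; "MyOp" and input "x" are illustrative names only.
// Illustrative only: log and bail out when the input has no dims.
// When the condition holds, log_func runs and do_expr exits the function.
ge::graphStatus InferShapeSketch(ge::Operator& op) {
  auto dims = op.GetInputDesc("x").GetShape().GetDims();
  OP_CHECK(dims.empty(),
           OP_LOGE("MyOp", "input x has no dims"),
           return ge::GRAPH_FAILED);
  return ge::GRAPH_SUCCESS;
}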
@@ -0,0 +1,258 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file transfer_shape_according_to_format.cpp | |||
* \brief set shape according to original format and current format | |||
*/ | |||
#include "transfer_shape_according_to_format.h" | |||
#include "framework/omg/omg_inner_types.h" | |||
namespace ge { | |||
ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat() {
getNewShapeFuncMap = { | |||
{ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)}, | |||
{ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)}, | |||
{ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)}, | |||
{ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)}, | |||
{ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)}, | |||
{ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)}, | |||
{ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}}; | |||
mapOfDtypeAndC0 = { | |||
{ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32}, | |||
{ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16}, | |||
{ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16}, | |||
{ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}}; | |||
} | |||
bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||
const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue) { | |||
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); | |||
/* axisValue is initialized as a size 6 vector. */ | |||
std::vector<int64_t> newDimVec; | |||
newDimVec.push_back(axisValue[AXIS_N]); | |||
newDimVec.push_back(axisValue[AXIS_C]); | |||
newDimVec.push_back(axisValue[AXIS_H]); | |||
newDimVec.push_back(axisValue[AXIS_W]); | |||
newShape = ge::GeShape(newDimVec); | |||
return true; | |||
} | |||
bool ShapeTransferAccordingToFormat::GetNHWCShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||
const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue) { | |||
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); | |||
/* axisValue is initialized as a size 6 vector. */ | |||
std::vector<int64_t> newDimVec; | |||
newDimVec.push_back(axisValue[AXIS_N]); | |||
newDimVec.push_back(axisValue[AXIS_H]); | |||
newDimVec.push_back(axisValue[AXIS_W]); | |||
newDimVec.push_back(axisValue[AXIS_C]); | |||
newShape = ge::GeShape(newDimVec); | |||
return true; | |||
} | |||
bool ShapeTransferAccordingToFormat::GetNC1HWC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||
const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue) { | |||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||
/* axisValue is initialized as a size 6 vector. */ | |||
std::vector<int64_t> newDimVec; | |||
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) { | |||
CHECK(axisValue.size() <= AXIS_C0, LOG_INFO("AxisValue is not correct!"), return true); | |||
newDimVec.push_back(axisValue[AXIS_N]); | |||
newDimVec.push_back(axisValue[AXIS_C1]); | |||
newDimVec.push_back(axisValue[AXIS_H]); | |||
newDimVec.push_back(axisValue[AXIS_W]); | |||
newDimVec.push_back(axisValue[AXIS_C0]); | |||
newShape = ge::GeShape(newDimVec); | |||
} else { | |||
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); | |||
newDimVec.push_back(axisValue[AXIS_N]); | |||
newDimVec.push_back(axisValue[AXIS_C]); | |||
newDimVec.push_back(axisValue[AXIS_H]); | |||
newDimVec.push_back(axisValue[AXIS_W]); | |||
newShape = ge::GeShape(newDimVec); | |||
} | |||
return true; | |||
} | |||
bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||
const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue) { | |||
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true); | |||
/* axisValue is initialized as a size 6 vector. */ | |||
std::vector<int64_t> newDimVec; | |||
if (ndValue.size() == SIZE_OF_CN) { | |||
CHECK(axisValue.size() <= AXIS_C0, LOG_INFO("AxisValue is not correct!"), return true); | |||
auto sizeOfOriginalVec = ndValue.size();
newDimVec = ndValue;  // assign the outer vector instead of shadowing it
/* sizeOfOriginalVec - 1 means the last value of the original vec,
 * sizeOfOriginalVec - 2 means the second-to-last value of the original vec */
newDimVec[sizeOfOriginalVec - MINUS_VALUE_ONE] = | |||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16); | |||
newDimVec[sizeOfOriginalVec - MINUS_VALUE_TWO] = | |||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]); | |||
newDimVec.push_back(SHAPE_NUMBER_16); | |||
newDimVec.push_back(axisValue[AXIS_C0]); | |||
newShape = ge::GeShape(newDimVec); | |||
} else { | |||
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) { | |||
CHECK(axisValue.size() <= AXIS_C1, LOG_INFO("AxisValue is not correct!"), return true); | |||
int64_t hwc1 = axisValue[AXIS_C1] * axisValue[AXIS_H] * axisValue[AXIS_W]; | |||
newDimVec.push_back(hwc1); | |||
newDimVec.push_back(DivisionCeiling(axisValue[AXIS_N], NI)); | |||
newDimVec.push_back(NI); | |||
newDimVec.push_back(axisValue[AXIS_C0]); | |||
newShape = ge::GeShape(newDimVec); | |||
} else { | |||
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); | |||
newDimVec.push_back(axisValue[AXIS_N]); | |||
newDimVec.push_back(axisValue[AXIS_C]); | |||
newDimVec.push_back(axisValue[AXIS_H]); | |||
newDimVec.push_back(axisValue[AXIS_W]); | |||
newShape = ge::GeShape(newDimVec); | |||
} | |||
} | |||
return true; | |||
} | |||
bool ShapeTransferAccordingToFormat::GetHWCNShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||
const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue) { | |||
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true); | |||
/* axisValue is initialized as a size 6 vector. */ | |||
std::vector<int64_t> newDimVec; | |||
newDimVec.push_back(axisValue[AXIS_H]); | |||
newDimVec.push_back(axisValue[AXIS_W]); | |||
newDimVec.push_back(axisValue[AXIS_C]); | |||
newDimVec.push_back(axisValue[AXIS_N]); | |||
newShape = ge::GeShape(newDimVec); | |||
return true; | |||
} | |||
bool ShapeTransferAccordingToFormat::GetC1HWNCoC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||
const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue) { | |||
CHECK(axisValue.size() <= AXIS_Co, LOG_INFO("AxisValue is not correct!"), return true); | |||
/* axisValue is initialized as a size 6 vector. */ | |||
std::vector<int64_t> newDimVec; | |||
newDimVec.push_back(axisValue[AXIS_C1]); | |||
newDimVec.push_back(axisValue[AXIS_H]); | |||
newDimVec.push_back(axisValue[AXIS_W]); | |||
newDimVec.push_back(axisValue[AXIS_N]); | |||
newDimVec.push_back(axisValue[AXIS_Co]); | |||
newDimVec.push_back(axisValue[AXIS_C0]); | |||
newShape = ge::GeShape(newDimVec); | |||
return true; | |||
} | |||
bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||
const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue) { | |||
CHECK(ndValue.empty(), LOG_INFO("ndValue is empty!"), return true); | |||
CHECK(axisValue.empty() || axisValue.size() <= AXIS_C0, | |||
LOG_INFO("AxisValue is empty or its size %zu <= AXIS_C0[%u]", axisValue.size(), AXIS_C0), return true); | |||
uint32_t sizeOfOriginalVec = ndValue.size(); | |||
if (sizeOfOriginalVec < MINIMUM_NZ_SHAPE_DIM_NUM) { | |||
LOG_INFO("ndValue's dim num is less than 2!"); | |||
return true; | |||
} | |||
/* axisValue is initialized as a size 6 vector. */ | |||
std::vector<int64_t> newDimVec = ndValue; | |||
/* sizeOfOriginalVec - 1 means the last value of the original vec,
 * sizeOfOriginalVec - 2 means the second-to-last value of the original vec */
newDimVec[sizeOfOriginalVec - MINUS_VALUE_ONE] = | |||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16); | |||
newDimVec[sizeOfOriginalVec - MINUS_VALUE_TWO] = | |||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]); | |||
newDimVec.push_back(SHAPE_NUMBER_16); | |||
newDimVec.push_back(axisValue[AXIS_C0]); | |||
newShape = ge::GeShape(newDimVec); | |||
return true; | |||
} | |||
bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& shapeAndFormatInfo, int64_t* c) { | |||
/* The default new shape is old shape */ | |||
shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape; | |||
if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) { | |||
LOG_ERROR("Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat); | |||
return false; | |||
} | |||
if (shapeAndFormatInfo.currentDataType >= ge::DT_UNDEFINED) { | |||
LOG_ERROR("currentDataType %u is invalid!", shapeAndFormatInfo.currentDataType); | |||
return false; | |||
} | |||
std::unique_ptr<AxisUtil> axisutil_object(new AxisUtil());
if (!axisutil_object->HasAxisValueFunc(shapeAndFormatInfo.oldFormat)) {
  return true;
}
auto iterGetNewShapeFunc = getNewShapeFuncMap.find(shapeAndFormatInfo.newFormat); | |||
if (iterGetNewShapeFunc == getNewShapeFuncMap.end()) {
  LOG_INFO("Can not get new shape of new format %u!", shapeAndFormatInfo.newFormat);
  return true;
}
LOG_INFO("Original format %u, new format %u", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat); | |||
GetNewShapeByAxisValueAndFormatPtr getNewShapeFunc = iterGetNewShapeFunc->second; | |||
CHECK_NOTNULL(getNewShapeFunc); | |||
std::vector<int64_t> axisValue; | |||
for (uint32_t i = 0; i < AXIS_BOTTOM; i++) { | |||
axisValue.push_back(1); | |||
} | |||
std::vector<int64_t> ndValue; | |||
uint32_t c0; | |||
if (mapOfDtypeAndC0.empty()) { | |||
c0 = SHAPE_NUMBER_16; | |||
} else { | |||
auto iterGetC0 = mapOfDtypeAndC0.find(shapeAndFormatInfo.currentDataType); | |||
if (iterGetC0 == mapOfDtypeAndC0.end()) {
  LOG_ERROR("Dtype is not supported.");
  return true;
}
c0 = iterGetC0->second; | |||
} | |||
// The value of C0 should be 4 while format is 5HD-4 or FRAZ-4 | |||
if (shapeAndFormatInfo.newFormat == ge::FORMAT_NC1HWC0_C04) { | |||
c0 = SHAPE_DIM_VALUE_C04; | |||
} | |||
bool status = axisutil_object->GetAxisValueByOriginFormat(
    shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape.GetDims(), c0, axisValue, ndValue);
if (!status && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) {
  return true;
}
(*getNewShapeFunc)(shapeAndFormatInfo.newShape, shapeAndFormatInfo.opImplType, axisValue, ndValue); | |||
if (c != nullptr) { | |||
*c = axisValue[AXIS_C]; | |||
} | |||
return true; | |||
} | |||
}  // namespace ge
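As a concrete example of the 5HD branch above, assuming AxisUtil derives the axis values in the usual way (C1 = ceil(C/C0), which the DivisionCeiling calls suggest): with float16 input (C0 = 16) and a TBE implementation type, an NCHW shape (8, 35, 224, 224) gives N=8, C1=ceil(35/16)=3, H=224, W=224, C0=16, so the NC1HWC0 result is (8, 3, 224, 224, 16); for non-TBE implementation types the shape stays NCHW.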
@@ -0,0 +1,129 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
/*! | |||
* \file transfer_shape_according_to_format.h | |||
* \brief set shape according to original format and current format | |||
*/ | |||
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ | |||
#define OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ | |||
#include "axis_util.h" | |||
#include <functional>
#include <map>
#include <memory>
#include <vector>
#include "framework/omg/omg_inner_types.h" | |||
#include "operator.h" | |||
#include "graph/operator_reg.h" | |||
#include "graph/tensor.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "op_log.h" | |||
#define LOG_ERROR(format, args...) printf(format, ##args) | |||
#define LOG_INFO(format, args...) printf(format, ##args) | |||
namespace ge { | |||
enum OpImplType { | |||
EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op | |||
EN_IMPL_CUSTOM_TIK, // custom tik op | |||
EN_IMPL_CUSTOM_TBE, // custom tbe op | |||
EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op | |||
EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op | |||
EN_IMPL_HW_TIK, // Huawei built-in tik op | |||
EN_IMPL_HW_TBE, // Huawei built-in tbe op | |||
EN_IMPL_RL, // RL op | |||
EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op | |||
EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op | |||
EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op | |||
EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op | |||
EN_RESERVED // reserved value | |||
}; | |||
const uint32_t SHAPE_NUMBER_16 = 16; | |||
const uint32_t SHAPE_NUMBER_32 = 32; | |||
const uint32_t SHAPE_DIM_VALUE_C04 = 4; | |||
const uint32_t NI = 16; | |||
const uint32_t MINUS_VALUE_ONE = 1; | |||
const uint32_t MINUS_VALUE_TWO = 2; | |||
const uint32_t SIZE_OF_CN = 2; | |||
const uint32_t MINIMUM_NZ_SHAPE_DIM_NUM = 2; | |||
/* The parameters are, in order: the new shape (output), the op
 * implementation type, the axis values and the original (ND) dim values. */
using GetNewShapeByAxisValueAndFormat = | |||
std::function<bool(ge::GeShape&, const int64_t&, vector<int64_t>&, vector<int64_t>&)>; | |||
using GetNewShapeByAxisValueAndFormatPtr = std::shared_ptr<GetNewShapeByAxisValueAndFormat>; | |||
struct ShapeAndFormatInfo { | |||
const ge::GeShape& oldShape; | |||
ge::GeShape& newShape; | |||
const ge::Format& oldFormat; | |||
const ge::Format& newFormat; | |||
const ge::DataType& currentDataType; | |||
const int64_t& opImplType; | |||
}; | |||
using ShapeAndFormat = struct ShapeAndFormatInfo; | |||
class ShapeTransferAccordingToFormat { | |||
public: | |||
ShapeTransferAccordingToFormat(); | |||
~ShapeTransferAccordingToFormat(){}; | |||
ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat&) = delete; | |||
ShapeTransferAccordingToFormat& operator=(const ShapeTransferAccordingToFormat&) = delete; | |||
bool GetShapeAccordingToFormat(ShapeAndFormat& inputAndOutputInfo, int64_t* c = nullptr); | |||
/* ----------Below is the function of getting new shape---------------------- */ | |||
static bool GetNCHWShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue); | |||
static bool GetNHWCShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue); | |||
static bool GetNC1HWC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||
const vector<int64_t>& axisValue, const vector<int64_t>& ndValue); | |||
static bool GetFzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue); | |||
static bool GetHWCNShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue); | |||
static bool GetC1HWNCoC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, | |||
const vector<int64_t>& axisValue, const vector<int64_t>& ndValue); | |||
static bool GetNzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue, | |||
const vector<int64_t>& ndValue); | |||
private: | |||
/* map from new format to the function that computes the new shape
 * for that format. */
std::map<ge::Format, GetNewShapeByAxisValueAndFormatPtr> getNewShapeFuncMap; | |||
std::map<ge::DataType, uint32_t> mapOfDtypeAndC0; | |||
}; | |||
} // namespace ge | |||
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ |
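A minimal usage sketch of this class, assuming the GE types behave as declared above and that AxisUtil fills AXIS_C with the original channel value; the concrete dims are illustrative only.
// Illustrative only: transform an NCHW fp16 shape to NC1HWC0.
bool TransferExample() {
  ge::GeShape old_shape({8, 35, 224, 224});
  ge::GeShape new_shape;
  ge::Format old_format = ge::FORMAT_NCHW;
  ge::Format new_format = ge::FORMAT_NC1HWC0;
  ge::DataType dtype = ge::DT_FLOAT16;
  int64_t impl_type = static_cast<int64_t>(ge::EN_IMPL_HW_TBE);
  ge::ShapeAndFormat info{old_shape, new_shape, old_format, new_format, dtype, impl_type};
  ge::ShapeTransferAccordingToFormat transfer;
  int64_t c = 0;
  if (!transfer.GetShapeAccordingToFormat(info, &c)) {
    return false;
  }
  // new_shape is expected to be (8, 3, 224, 224, 16); c should hold 35.
  return true;
}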
@@ -0,0 +1,363 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*!
 * \file util.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_

#include <memory.h>
#include <string>
#include <vector>
#include <map>
#include <algorithm>

#include "framework/omg/omg_inner_types.h"
#include "operator.h"
#include "graph/operator_reg.h"
#include "transfer_shape_according_to_format.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/tensor.h"
#include "graph/node.h"
#include "graph/ge_tensor.h"
#include "op_log.h"

#define LOG_ERROR(format, args...) printf(format, ##args)
namespace ge {
// mapping between enum data types and their string names
static const std::map<ge::DataType, std::string> DTYPE_STR_MAP{
    {ge::DT_FLOAT16, "float16"}, {ge::DT_FLOAT, "float32"}, {ge::DT_INT8, "int8"},     {ge::DT_INT16, "int16"},
    {ge::DT_INT32, "int32"},     {ge::DT_INT64, "int64"},   {ge::DT_UINT8, "uint8"},   {ge::DT_UINT16, "uint16"},
    {ge::DT_UINT32, "uint32"},   {ge::DT_UINT64, "uint64"}, {ge::DT_BOOL, "bool"}};

// named constants for the number of inputs
const size_t INPUT_NUM0 = 0;
const size_t INPUT_NUM1 = 1;
const size_t INPUT_NUM2 = 2;
const size_t INPUT_NUM3 = 3;
const size_t INPUT_NUM4 = 4;
const size_t INPUT_NUM5 = 5;
const size_t INPUT_NUM6 = 6;
const size_t INPUT_NUM7 = 7;
const size_t INPUT_NUM8 = 8;
const size_t INPUT_NUM9 = 9;

// named constants for the number of dims in a shape
const size_t DIM_SIZE0 = 0;
const size_t DIM_SIZE1 = 1;
const size_t DIM_SIZE2 = 2;
const size_t DIM_SIZE3 = 3;
const size_t DIM_SIZE4 = 4;
const size_t DIM_SIZE5 = 5;
const size_t DIM_SIZE6 = 6;
const size_t DIM_SIZE7 = 7;
const size_t DIM_SIZE8 = 8;

// named constants for dim indices within a shape
const size_t DIM_INDEX0 = 0;
const size_t DIM_INDEX1 = 1;
const size_t DIM_INDEX2 = 2;
const size_t DIM_INDEX3 = 3;
const size_t DIM_INDEX4 = 4;
const size_t DIM_INDEX5 = 5;
const size_t DIM_INDEX6 = 6;
const size_t DIM_INDEX7 = 7;
const size_t DIM_INDEX8 = 8;
/*
 * get the datatype of an input
 * param[in] dataType    input datatype as an enum value
 * param[in] supportList the datatypes supported by the op
 * return true : get type success
 *        false: get type failed
 */
bool GetInputDataType(const ge::DataType& data_type, const std::vector<ge::DataType>& supportList);
bool GetInputDataType(const ge::DataType& dataType, const std::vector<ge::DataType>& supportList, std::string& dType);
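// Illustrative call of the string-returning overload (values are assumptions;
// the name presumably comes from DTYPE_STR_MAP above):
//   std::string name;
//   if (GetInputDataType(ge::DT_FLOAT16, {ge::DT_FLOAT16, ge::DT_FLOAT}, name)) {
//     // name == "float16"
//   }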
/*
 * check whether the datatype of the named input is in the support list
 * param[in] op           op desc supplied by ge
 * param[in] input_name   name of the input to check
 * param[in] support_list supported datatypes
 * return true : check pass
 *        false: check failed
 */
bool CheckInputDataType(const Operator& op, const std::string& input_name,
                        const std::vector<ge::DataType>& support_list);
/*
 * check the datatype and shape of inputs
 * param[in] op             the operator
 * param[in] inputTensorMap map from input name to its supported datatypes
 * return true : check pass
 *        false: check failed
 */
bool CheckInputDtypeAndShape(const Operator& op, const std::map<std::string, std::vector<DataType>>& inputTensorMap);
/*
 * infer the shape of two inputs and one output with broadcast
 * param[in] op          op desc supplied by ge
 * param[in] inputName1  first input name
 * param[in] inputName2  second input name
 * param[in] outputName  output name
 * return SUCCESS: infer success
 *        FAILED : infer failed, e.g. the input shapes cannot be broadcast
 */
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
                                           const string& output_name);
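// A hedged sketch of how an op's infer-shape function might call this helper
// (the tensor names "x1", "x2", "y" are hypothetical):
//   x1 shape {2, 1, 3} broadcast with x2 shape {1, 4, 3} gives y shape {2, 4, 3}
//   if (!InferShapeAndTypeTwoInOneOutBroadcast(op, "x1", "x2", "y")) {
//     return GRAPH_FAILED;
//   }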
/*
 * infer the shape of two inputs and one output with broadcast
 * param[in] op          op desc supplied by ge
 * param[in] inputName1  first input name
 * param[in] inputName2  second input name
 * param[in] outputName  output name
 * param[out] is_dynamic whether the output shape is dynamic
 * return SUCCESS: infer success
 *        FAILED : infer failed, e.g. the input shapes cannot be broadcast
 */
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
                                           const string& output_name, bool& is_dynamic);
bool InferShapeRangeTwoInOneOutBroadcase(Operator& op, const string& input_name1, const string& input_name2,
                                         const string& output_name);
bool CheckInputDataType(const Operator& op, std::string* data_type, const std::string& input_name,
                        const std::vector<ge::DataType>& supportList);
bool CheckTwoInputDtypeSame(const Operator& op, const string& input_name1, const string& input_name2);
bool CheckInputDtypeSame(const Operator& op, std::vector<std::string>& input_tensors);
bool CheckInputsShapeDtypeSame(const Operator& op, const std::vector<std::string>& input_names);

bool GetConstValue(const ge::Operator& op, const std::string& key_name, float& attr_value);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, int64_t& attr_value);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, bool& attr_value);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, std::vector<int32_t>& attr_value);
/**
 * Get int type const value from tensor data
 * @param [in] data          const tensor data
 * @param [in] data_type     DT_INT8, DT_INT16, DT_INT32, DT_INT64
 * @param [out] const_values const int values
 * @return true: success, false: failed.
 */
bool GetConstIntData(const Tensor& data, DataType data_type, std::vector<int64_t>& const_values);
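// Illustrative use (hypothetical: the const input is named "shape" and holds
// int32 values):
//   Tensor data;
//   std::vector<int64_t> values;
//   if (op.GetInputConstData("shape", data) == GRAPH_SUCCESS &&
//       GetConstIntData(data, ge::DT_INT32, values)) {
//     // values now holds the widened int64 contents of the tensor
//   }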
bool GetConstValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype,
                   std::vector<int64_t>& const_data);
bool GetConstValue(const Operator& op, const GeTensorPtr& const_tensor, const DataType& dtype,
                   std::vector<int64_t>& const_data);
bool GetScalerValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, std::int64_t& const_data);
/*
 * Check that the dtype or format of each input in [inputNumBeg, inputNumEnd)
 * is contained in supportList
 * param[in] op          op desc supplied by ge
 * param[in] inputNumBeg first input index to check
 * param[in] inputNumEnd one past the last input index to check
 * param[in] supportList supported values, of type ge::DataType or ge::Format
 * return true : check pass
 *        false: check failed
 */
template <typename T>
bool CheckSimilarInputDtypeAndFormat(const Operator& op, std::size_t inputNumBeg, std::size_t inputNumEnd,
                                     const std::vector<T>& supportList) {
  for (std::size_t i = inputNumBeg; i < inputNumEnd; i++) {
    if (std::is_same<typename std::decay<T>::type, ge::DataType>::value) {
      ge::DataType inType = op.GetInputDesc(i).GetDataType();
      const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
      if (findDtype == supportList.end()) {
        return false;
      }
    } else if (std::is_same<typename std::decay<T>::type, ge::Format>::value) {
      ge::Format inType = op.GetInputDesc(i).GetFormat();
      const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
      if (findDtype == supportList.end()) {
        return false;
      }
    }
  }
  return true;
}
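// Illustrative call (the support list and index range are assumptions):
//   std::vector<ge::DataType> support = {ge::DT_FLOAT, ge::DT_FLOAT16};
//   bool ok = CheckSimilarInputDtypeAndFormat(op, 0, 3, support);  // inputs 0..2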
/*
 * Check that the dtype or format of each input listed in indexNeedCheck
 * is contained in supportList
 * param[in] op             op desc supplied by ge
 * param[in] indexNeedCheck input indices to check
 * param[in] supportList    supported values, of type ge::DataType or ge::Format
 * return true : check pass
 *        false: check failed
 */
template <typename T>
bool CheckSimilarInputDtypeAndFormat(const Operator& op, const std::vector<std::size_t>& indexNeedCheck,
                                     const std::vector<T>& supportList) {
  for (auto i : indexNeedCheck) {
    if (std::is_same<typename std::decay<T>::type, ge::DataType>::value) {
      ge::DataType inType = op.GetInputDesc(i).GetDataType();
      const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
      if (findDtype == supportList.end()) {
        return false;
      }
    } else if (std::is_same<typename std::decay<T>::type, ge::Format>::value) {
      ge::Format inType = op.GetInputDesc(i).GetFormat();
      const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
      if (findDtype == supportList.end()) {
        return false;
      }
    }
  }
  return true;
}
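// Illustrative call of the index-list overload (indices are assumptions):
//   bool ok = CheckSimilarInputDtypeAndFormat(op, {0, 2}, support);  // inputs 0 and 2 only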
/*
 * get const attrs
 * param[in] op            op desc supplied by ge
 * param[in] attrNameList  names of the attrs to get
 * param[out] attrVec      the collected attr values
 * return true : get success
 *        false: get failed
 */
template <typename T>
bool GetConstAttr(const Operator& op, const std::vector<std::string>& attrNameList, std::vector<T>& attrVec) {
  T value;
  for (auto name : attrNameList) {
    if (op.GetAttr(name, value) != ge::GRAPH_SUCCESS) {
      return false;
    }
    attrVec.push_back(value);
  }
  return true;
}
/*
 * get const list attrs
 * param[in] op            op desc supplied by ge
 * param[in] attrNameList  names of the list attrs to get
 * param[out] attrListVec  the collected attr value lists
 * return true : get success
 *        false: get failed
 */
template <typename T>
bool GetConstAttr(const Operator& op, const std::vector<std::string>& attrNameList,
                  std::vector<std::vector<T>>& attrListVec) {
  for (auto name : attrNameList) {
    std::vector<T> valueList;
    if (op.GetAttr(name, valueList) != ge::GRAPH_SUCCESS) {
      return false;
    }
    attrListVec.push_back(valueList);
  }
  return true;
}
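// Illustrative calls of the two overloads (the attr names "axis" and "perm"
// are hypothetical):
//   std::vector<int64_t> scalar_attrs;
//   bool ok1 = GetConstAttr(op, {"axis"}, scalar_attrs);  // one value per scalar attr
//   std::vector<std::vector<int64_t>> list_attrs;
//   bool ok2 = GetConstAttr(op, {"perm"}, list_attrs);    // one vector per list attr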
std::string to_string(const vector<int64_t>& shape);
std::string to_string(const ge::Shape& shape);
std::string to_string(const ge::GeShape& shape);
std::string to_string(const vector<pair<int64_t, int64_t>>& ranges);

class DynamicShapeInfer {
 public:
  std::map<std::string, Format> map_format;
  std::map<std::string, DataType> map_dtype;
  std::map<std::string, uint32_t> inputs;
  std::map<std::string, uint32_t> outputs;
  Operator& op;
  OpDescPtr& op_desc;

  DynamicShapeInfer(Operator& op_v, OpDescPtr& opDesc_v) : op(op_v), op_desc(opDesc_v) {}
  bool CatchFormatAndShape();
  bool UpdateFormatAndShape();
  ~DynamicShapeInfer() {
    UpdateFormatAndShape();
  }
};
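// DynamicShapeInfer is RAII-style: the destructor calls UpdateFormatAndShape(),
// so a typical (hypothetical) usage scopes the object around the infer logic:
//   {
//     DynamicShapeInfer infer(op, op_desc);
//     infer.CatchFormatAndShape();                 // snapshot formats/dtypes
//     ...run shape inference in the original format...
//   }                                              // written back here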
#define PREPARE_DYNAMIC_SHAPE(depends_names)             \
  auto op_desc = OpDescUtils::GetOpDescFromOperator(op); \
  do {                                                   \
    if (!depends_names.empty()) {                        \
      op_desc->SetOpInferDepends(depends_names);         \
    }                                                    \
  } while (0)
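// Note the macro deliberately declares op_desc outside the do/while so it stays
// visible in the caller's scope. A hedged sketch of a call site (the depends
// list content is an assumption):
//   std::vector<std::string> depends = {"shape"};
//   PREPARE_DYNAMIC_SHAPE(depends);  // declares op_desc and registers depends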
bool IsEmptyTensor(const std::vector<int64_t>& dims);
bool IsUnknownRank(const Operator& op, const std::string& tensor_name, const std::string& types = "input");
bool IsUnknownRankShape(const std::vector<int64_t>& shape_vec);
bool IsUnKnownShape(const std::vector<int64_t>& shape_vec);
bool IsUnknownShape(const Operator& op, const std::string& tensor_name, const std::string& types = "input");
bool IsUnknownVec(std::vector<int64_t>& shape_vec);
bool IsUnknown(const std::vector<int64_t>& shape_vec);
void MakeUpShapeRange(const std::vector<int64_t>& shape, std::vector<std::pair<int64_t, int64_t>>& range);
std::string DataTypeToStringDesc(const ge::DataType& dataType);

bool OneInOneOutDynamicInfer(const Operator& op,
                             const std::string& input_name,
                             const std::vector<std::string>& output_name_list);
bool TwoInOneOutDynamicInferNoBroadcast(Operator& op,
                                        const string& input1_name,
                                        const string& input2_name,
                                        const std::vector<string>& output_name_list);
void FixShapeRangeWithDims(const std::vector<int64_t>& dims,
                           std::vector<int64_t>& shape_1,
                           std::vector<int64_t>& shape_2,
                           std::vector<std::pair<int64_t, int64_t>>& range_1,
                           std::vector<std::pair<int64_t, int64_t>>& range_2);
bool SetScalarOutputDesc(const string& input,
                         const string& output,
                         OpDescPtr op_desc,
                         GeShape& output_shape);

namespace array_ops {
bool CheckInt64MulOverflow(int64_t a, int64_t b);
void ReshapeRangeInfer(const Operator& op, const std::vector<std::pair<int64_t, int64_t>>& x_range,
                       int64_t& range_max);
void ReshapeRangeInfer(const Operator& op, const std::vector<std::pair<int64_t, int64_t>>& x_range,
                       std::vector<std::pair<int64_t, int64_t>>& y_range, GeShape& output_shape);
}  // namespace array_ops
}  // namespace ge
#endif  // OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
@@ -0,0 +1,17 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph_assertion.h"
@@ -0,0 +1,34 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H
#define GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H

/*
 * Compare graph node size, node_attr
 */
#define ASSERT_GRAPH_EQUAL(g1, g2) \
  do {                             \
  } while (0)

#define ASSERT_GRAPH_CORRECT(g) \
  do {                          \
  } while (0)

#define ASSERT_GRAPH_SHAPE_CONTINOUS(g) \
  do {                                  \
  } while (0)

#endif  // GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H
@@ -0,0 +1,48 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph_builder_utils.h"
#include "inc/external/graph/operator.h"
#include "inc/external/graph/operator_factory.h"
#include "graph/utils/graph_utils.h"

namespace ge {
namespace st {
NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
                                     Format format, DataType data_type, std::vector<int64_t> shape) {
  auto tensor_desc = std::make_shared<GeTensorDesc>();
  tensor_desc->SetShape(GeShape(std::move(shape)));
  tensor_desc->SetFormat(format);
  tensor_desc->SetDataType(data_type);

  auto op_desc = std::make_shared<OpDesc>(name, type);
  for (int i = 0; i < in_cnt; ++i) {
    op_desc->AddInputDesc(tensor_desc->Clone());
  }
  for (int i = 0; i < out_cnt; ++i) {
    op_desc->AddOutputDesc(tensor_desc->Clone());
  }
  return graph_->AddNode(op_desc);
}

void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) {
  GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}

void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
  GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
}
}  // namespace st
}  // namespace ge
@@ -0,0 +1,53 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H

#include <string>
#include <vector>

#include "graph/compute_graph.h"
#include "graph/utils/graph_utils.h"
#include "graph/graph.h"
#include "graph/node.h"

namespace ge {
namespace st {
class ComputeGraphBuilder {
 public:
  explicit ComputeGraphBuilder(const std::string &name) { graph_ = std::make_shared<ComputeGraph>(name); }
  NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
                  Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
                  std::vector<int64_t> shape = {1, 1, 224, 224});
  void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx);
  void AddControlEdge(NodePtr &src_node, NodePtr &dst_node);
  ComputeGraphPtr GetComputeGraph() {
    graph_->TopologicalSorting();
    return graph_;
  }
  Graph GetGraph() {
    graph_->TopologicalSorting();
    return GraphUtils::CreateGraphFromComputeGraph(graph_);
  }

 private:
  ComputeGraphPtr graph_;
};
}  // namespace st
}  // namespace ge
#endif  // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
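A minimal sketch of how ComputeGraphBuilder is meant to be used (node types given here as plain strings; the dummy test below uses the DATA/ADD constants instead):

ge::st::ComputeGraphBuilder builder("demo");
auto data = builder.AddNode("data", "Data", 1, 1);
auto relu = builder.AddNode("relu", "Relu", 1, 1);
builder.AddDataEdge(data, 0, relu, 0);
ge::Graph graph = builder.GetGraph();  // topologically sorted on the way out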
@@ -0,0 +1,17 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tensor_builder_utils.h"
@@ -0,0 +1,22 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H

class tensor_builder_utils {};

#endif  // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
@@ -0,0 +1,15 @@
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++")

add_executable(graph_engine_test ${SOURCES})

target_include_directories(graph_engine_test
    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}
)

set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 11)

target_link_libraries(graph_engine_test PRIVATE gtest gtest_main framework)

include(CTest)
enable_testing()
add_test(NAME test COMMAND graph_engine_test)
@@ -0,0 +1,58 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <gtest/gtest.h>
#include <map>

#include "external/ge/ge_api.h"
#include "framework/common/types.h"
#include "framework.h"
#include "framework/utils/builder/graph_builder_utils.h"

using namespace std;
using namespace ge;

class FrameworkTest : public testing::Test {
 protected:
  void SetUp() override {
    // GE initialize
    map<AscendString, AscendString> options;
    auto ret = ge::GEInitialize(options);
    EXPECT_EQ(ret, SUCCESS);
  }
  void TearDown() override {}
};

TEST_F(FrameworkTest, test_framework_dummy) {
  // build graph
  st::ComputeGraphBuilder graphBuilder("g1");
  auto data1 = graphBuilder.AddNode("data1", DATA, 1, 1);
  auto data2 = graphBuilder.AddNode("data2", DATA, 1, 1);
  auto add = graphBuilder.AddNode("add", ADD, 2, 1);
  graphBuilder.AddDataEdge(data1, 0, add, 0);
  graphBuilder.AddDataEdge(data2, 0, add, 1);
  Graph graph = graphBuilder.GetGraph();

  // new session & add graph
  map<AscendString, AscendString> options;
  Session session(options);
  auto ret = session.AddGraph(1, graph, options);
  EXPECT_EQ(ret, SUCCESS);

  // build input tensor
  std::vector<InputTensorInfo> inputs;

  // build graph through session
  ret = session.BuildGraph(1, inputs);
  // TODO check result
}
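One plausible completion of the trailing TODO, left as a labeled assumption rather than the author's intent: BuildGraph returns a Status, so the minimal check would mirror the earlier assertions:

// assumption: a successful build is all this dummy test needs to verify
EXPECT_EQ(ret, SUCCESS);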