Browse Source

modified: CMakeLists.txt

modified:   build.sh
	modified:   ge/ge_runtime/runtime_model.cc
	modified:   ge/ge_runtime/task/aicpu_task.cc
	modified:   ge/ge_runtime/task/hccl_task.cc
	modified:   ge/ge_runtime/task/label_goto_task.cc
	modified:   ge/ge_runtime/task/label_switch_task.cc
	new file:   tests/st/CMakeLists.txt
	new file:   tests/st/cmake/graphengine.cmake
	new file:   tests/st/framework/CMakeLists.txt
	new file:   tests/st/framework/framework.cc
	new file:   tests/st/framework/framework.h
	new file:   tests/st/framework/stub_engine/CMakeLists.txt
	new file:   tests/st/framework/stub_engine/common/constant/constant.h
	new file:   tests/st/framework/stub_engine/engine/stub_engine.cc
	new file:   tests/st/framework/stub_engine/engine/stub_engine.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h
	new file:   tests/st/framework/stub_engine/proto/task.proto
	new file:   tests/st/framework/stub_op_proto/array_ops.cc
	new file:   tests/st/framework/stub_op_proto/array_ops.h
	new file:   tests/st/framework/stub_op_proto/control_flow_ops.cc
	new file:   tests/st/framework/stub_op_proto/control_flow_ops.h
	new file:   tests/st/framework/stub_op_proto/elewise_calculation_ops.cc
	new file:   tests/st/framework/stub_op_proto/elewise_calculation_ops.h
	new file:   tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc
	new file:   tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h
	new file:   tests/st/framework/stub_op_proto/util/axis_util.cc
	new file:   tests/st/framework/stub_op_proto/util/axis_util.h
	new file:   tests/st/framework/stub_op_proto/util/common_shape_fns.cc
	new file:   tests/st/framework/stub_op_proto/util/common_shape_fns.h
	new file:   tests/st/framework/stub_op_proto/util/error_code.h
	new file:   tests/st/framework/stub_op_proto/util/error_util.cc
	new file:   tests/st/framework/stub_op_proto/util/error_util.h
	new file:   tests/st/framework/stub_op_proto/util/op_common_util.h
	new file:   tests/st/framework/stub_op_proto/util/op_log.h
new file:   tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc
	new file:   tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h
	new file:   tests/st/framework/stub_op_proto/util/util.cc
	new file:   tests/st/framework/stub_op_proto/util/util.h
	new file:   tests/st/framework/utils/assertion/graph_assertion.cc
	new file:   tests/st/framework/utils/assertion/graph_assertion.h
	new file:   tests/st/framework/utils/builder/graph_builder_utils.cc
	new file:   tests/st/framework/utils/builder/graph_builder_utils.h
	new file:   tests/st/framework/utils/builder/tensor_builder_utils.cc
	new file:   tests/st/framework/utils/builder/tensor_builder_utils.h
	new file:   tests/st/test.cc
	new file:   tests/st/testcase/CMakeLists.txt
	new file:   tests/st/testcase/test_framework_dummy.cc

	modified:   CMakeLists.txt
	modified:   build.sh
	modified:   ge/ge_runtime/runtime_model.cc
	modified:   ge/ge_runtime/task/aicpu_task.cc
	modified:   ge/ge_runtime/task/hccl_task.cc
	modified:   ge/ge_runtime/task/label_goto_task.cc
	modified:   ge/ge_runtime/task/label_switch_task.cc
	new file:   tests/st/CMakeLists.txt
	new file:   tests/st/cmake/graphengine.cmake
	new file:   tests/st/framework/CMakeLists.txt
	new file:   tests/st/framework/framework.cc
	new file:   tests/st/framework/framework.h
	new file:   tests/st/framework/stub_engine/CMakeLists.txt
	new file:   tests/st/framework/stub_engine/common/constant/constant.h
	new file:   tests/st/framework/stub_engine/engine/stub_engine.cc
	new file:   tests/st/framework/stub_engine/engine/stub_engine.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h
	new file:   tests/st/framework/stub_engine/proto/task.proto
	new file:   tests/st/framework/stub_op_proto/array_ops.cc
	new file:   tests/st/framework/stub_op_proto/array_ops.h
	new file:   tests/st/framework/stub_op_proto/control_flow_ops.cc
	new file:   tests/st/framework/stub_op_proto/control_flow_ops.h
	new file:   tests/st/framework/stub_op_proto/elewise_calculation_ops.cc
	new file:   tests/st/framework/stub_op_proto/elewise_calculation_ops.h
	new file:   tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc
	new file:   tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h
	new file:   tests/st/framework/stub_op_proto/util/axis_util.cc
	new file:   tests/st/framework/stub_op_proto/util/axis_util.h
	new file:   tests/st/framework/stub_op_proto/util/common_shape_fns.cc
	new file:   tests/st/framework/stub_op_proto/util/common_shape_fns.h
	new file:   tests/st/framework/stub_op_proto/util/error_code.h
	new file:   tests/st/framework/stub_op_proto/util/error_util.cc
	new file:   tests/st/framework/stub_op_proto/util/error_util.h
	new file:   tests/st/framework/stub_op_proto/util/op_common_util.h
	new file:   tests/st/framework/stub_op_proto/util/op_log.h
	new file:   tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc
	new file:   tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h
	new file:   tests/st/framework/stub_op_proto/util/util.cc
	new file:   tests/st/framework/stub_op_proto/util/util.h
	new file:   tests/st/framework/utils/assertion/graph_assertion.cc
	new file:   tests/st/framework/utils/assertion/graph_assertion.h
	new file:   tests/st/framework/utils/builder/graph_builder_utils.cc
	new file:   tests/st/framework/utils/builder/graph_builder_utils.h
	new file:   tests/st/framework/utils/builder/tensor_builder_utils.cc
	new file:   tests/st/framework/utils/builder/tensor_builder_utils.h
	new file:   tests/st/test.cc
	new file:   tests/st/testcase/CMakeLists.txt
	new file:   tests/st/testcase/test_framework_dummy.cc

	modified:   CMakeLists.txt
	modified:   build.sh
	modified:   ge/ge_runtime/runtime_model.cc
	modified:   ge/ge_runtime/task/aicpu_task.cc
	modified:   ge/ge_runtime/task/hccl_task.cc
	modified:   ge/ge_runtime/task/label_goto_task.cc
	modified:   ge/ge_runtime/task/label_switch_task.cc
	new file:   tests/st/CMakeLists.txt
	new file:   tests/st/cmake/graphengine.cmake
	new file:   tests/st/framework/CMakeLists.txt
	new file:   tests/st/framework/framework.cc
	new file:   tests/st/framework/framework.h
	new file:   tests/st/framework/stub_engine/CMakeLists.txt
	new file:   tests/st/framework/stub_engine/common/constant/constant.h
	new file:   tests/st/framework/stub_engine/engine/stub_engine.cc
	new file:   tests/st/framework/stub_engine/engine/stub_engine.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h
	new file:   tests/st/framework/stub_engine/proto/task.proto
	new file:   tests/st/framework/stub_op_proto/array_ops.cc
	new file:   tests/st/framework/stub_op_proto/array_ops.h
	new file:   tests/st/framework/stub_op_proto/control_flow_ops.cc
	new file:   tests/st/framework/stub_op_proto/control_flow_ops.h
	new file:   tests/st/framework/stub_op_proto/elewise_calculation_ops.cc
	new file:   tests/st/framework/stub_op_proto/elewise_calculation_ops.h
	new file:   tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc
	new file:   tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h
	new file:   tests/st/framework/stub_op_proto/util/axis_util.cc
	new file:   tests/st/framework/stub_op_proto/util/axis_util.h
	new file:   tests/st/framework/stub_op_proto/util/common_shape_fns.cc
	new file:   tests/st/framework/stub_op_proto/util/common_shape_fns.h
new file:   tests/st/framework/stub_op_proto/util/error_code.h
	new file:   tests/st/framework/stub_op_proto/util/error_util.cc
	new file:   tests/st/framework/stub_op_proto/util/error_util.h
	new file:   tests/st/framework/stub_op_proto/util/op_common_util.h
	new file:   tests/st/framework/stub_op_proto/util/op_log.h
	new file:   tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc
	new file:   tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h
	new file:   tests/st/framework/stub_op_proto/util/util.cc
	new file:   tests/st/framework/stub_op_proto/util/util.h
	new file:   tests/st/framework/utils/assertion/graph_assertion.cc
	new file:   tests/st/framework/utils/assertion/graph_assertion.h
	new file:   tests/st/framework/utils/builder/graph_builder_utils.cc
	new file:   tests/st/framework/utils/builder/graph_builder_utils.h
	new file:   tests/st/framework/utils/builder/tensor_builder_utils.cc
	new file:   tests/st/framework/utils/builder/tensor_builder_utils.h
	new file:   tests/st/test.cc
	new file:   tests/st/testcase/CMakeLists.txt
	new file:   tests/st/testcase/test_framework_dummy.cc

	modified:   CMakeLists.txt
	modified:   build.sh
	modified:   ge/ge_runtime/runtime_model.cc
	modified:   ge/ge_runtime/task/aicpu_task.cc
	modified:   ge/ge_runtime/task/hccl_task.cc
	modified:   ge/ge_runtime/task/label_goto_task.cc
	modified:   ge/ge_runtime/task/label_switch_task.cc
	new file:   tests/st/CMakeLists.txt
	new file:   tests/st/cmake/graphengine.cmake
	new file:   tests/st/framework/CMakeLists.txt
	new file:   tests/st/framework/framework.cc
	new file:   tests/st/framework/framework.h
	new file:   tests/st/framework/stub_engine/CMakeLists.txt
	new file:   tests/st/framework/stub_engine/common/constant/constant.h
	new file:   tests/st/framework/stub_engine/engine/stub_engine.cc
	new file:   tests/st/framework/stub_engine/engine/stub_engine.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op.h
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc
	new file:   tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h
	new file:   tests/st/framework/stub_engine/proto/task.proto
	new file:   tests/st/framework/stub_op_proto/array_ops.cc
	new file:   tests/st/framework/stub_op_proto/array_ops.h
	new file:   tests/st/framework/stub_op_proto/control_flow_ops.cc
	new file:   tests/st/framework/stub_op_proto/control_flow_ops.h
	new file:   tests/st/framework/stub_op_proto/elewise_calculation_ops.cc
	new file:   tests/st/framework/stub_op_proto/elewise_calculation_ops.h
	new file:   tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc
	new file:   tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h
	new file:   tests/st/framework/stub_op_proto/util/axis_util.cc
	new file:   tests/st/framework/stub_op_proto/util/axis_util.h
	new file:   tests/st/framework/stub_op_proto/util/common_shape_fns.cc
	new file:   tests/st/framework/stub_op_proto/util/common_shape_fns.h
	new file:   tests/st/framework/stub_op_proto/util/error_code.h
	new file:   tests/st/framework/stub_op_proto/util/error_util.cc
	new file:   tests/st/framework/stub_op_proto/util/error_util.h
	new file:   tests/st/framework/stub_op_proto/util/op_common_util.h
	new file:   tests/st/framework/stub_op_proto/util/op_log.h
	new file:   tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc
	new file:   tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h
	new file:   tests/st/framework/stub_op_proto/util/util.cc
	new file:   tests/st/framework/stub_op_proto/util/util.h
	new file:   tests/st/framework/utils/assertion/graph_assertion.cc
	new file:   tests/st/framework/utils/assertion/graph_assertion.h
	new file:   tests/st/framework/utils/builder/graph_builder_utils.cc
	new file:   tests/st/framework/utils/builder/graph_builder_utils.h
	new file:   tests/st/framework/utils/builder/tensor_builder_utils.cc
	new file:   tests/st/framework/utils/builder/tensor_builder_utils.h
	new file:   tests/st/testcase/CMakeLists.txt
	new file:   tests/st/testcase/test_framework_dummy.cc
pull/1696/head
zhaoxinxin 4 years ago
parent
commit
5c750650e1
55 changed files with 18337 additions and 113 deletions
  1. +120
    -106
      CMakeLists.txt
  2. +25
    -1
      build.sh
  3. +2
    -1
      ge/ge_runtime/runtime_model.cc
  4. +1
    -1
      ge/ge_runtime/task/aicpu_task.cc
  5. +2
    -2
      ge/ge_runtime/task/hccl_task.cc
  6. +1
    -1
      ge/ge_runtime/task/label_goto_task.cc
  7. +1
    -1
      ge/ge_runtime/task/label_switch_task.cc
  8. +6
    -0
      tests/st/CMakeLists.txt
  9. +249
    -0
      tests/st/cmake/graphengine.cmake
  10. +16
    -0
      tests/st/framework/CMakeLists.txt
  11. +26
    -0
      tests/st/framework/framework.cc
  12. +33
    -0
      tests/st/framework/framework.h
  13. +259
    -0
      tests/st/framework/stub_engine/CMakeLists.txt
  14. +30
    -0
      tests/st/framework/stub_engine/common/constant/constant.h
  15. +74
    -0
      tests/st/framework/stub_engine/engine/stub_engine.cc
  16. +126
    -0
      tests/st/framework/stub_engine/engine/stub_engine.h
  17. +114
    -0
      tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc
  18. +51
    -0
      tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h
  19. +67
    -0
      tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc
  20. +86
    -0
      tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h
  21. +40
    -0
      tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc
  22. +36
    -0
      tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h
  23. +45
    -0
      tests/st/framework/stub_engine/ops_kernel_store/op/op.h
  24. +55
    -0
      tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc
  25. +94
    -0
      tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h
  26. +179
    -0
      tests/st/framework/stub_engine/proto/task.proto
  27. +1763
    -0
      tests/st/framework/stub_op_proto/array_ops.cc
  28. +711
    -0
      tests/st/framework/stub_op_proto/array_ops.h
  29. +392
    -0
      tests/st/framework/stub_op_proto/control_flow_ops.cc
  30. +407
    -0
      tests/st/framework/stub_op_proto/control_flow_ops.h
  31. +4633
    -0
      tests/st/framework/stub_op_proto/elewise_calculation_ops.cc
  32. +3788
    -0
      tests/st/framework/stub_op_proto/elewise_calculation_ops.h
  33. +234
    -0
      tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc
  34. +42
    -0
      tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h
  35. +195
    -0
      tests/st/framework/stub_op_proto/util/axis_util.cc
  36. +144
    -0
      tests/st/framework/stub_op_proto/util/axis_util.h
  37. +1038
    -0
      tests/st/framework/stub_op_proto/util/common_shape_fns.cc
  38. +417
    -0
      tests/st/framework/stub_op_proto/util/common_shape_fns.h
  39. +60
    -0
      tests/st/framework/stub_op_proto/util/error_code.h
  40. +318
    -0
      tests/st/framework/stub_op_proto/util/error_util.cc
  41. +184
    -0
      tests/st/framework/stub_op_proto/util/error_util.h
  42. +73
    -0
      tests/st/framework/stub_op_proto/util/op_common_util.h
  43. +89
    -0
      tests/st/framework/stub_op_proto/util/op_log.h
  44. +258
    -0
      tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc
  45. +129
    -0
      tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h
  46. +1097
    -0
      tests/st/framework/stub_op_proto/util/util.cc
  47. +363
    -0
      tests/st/framework/stub_op_proto/util/util.h
  48. +17
    -0
      tests/st/framework/utils/assertion/graph_assertion.cc
  49. +34
    -0
      tests/st/framework/utils/assertion/graph_assertion.h
  50. +48
    -0
      tests/st/framework/utils/builder/graph_builder_utils.cc
  51. +53
    -0
      tests/st/framework/utils/builder/graph_builder_utils.h
  52. +17
    -0
      tests/st/framework/utils/builder/tensor_builder_utils.cc
  53. +22
    -0
      tests/st/framework/utils/builder/tensor_builder_utils.h
  54. +15
    -0
      tests/st/testcase/CMakeLists.txt
  55. +58
    -0
      tests/st/testcase/test_framework_dummy.cc

+ 120
- 106
CMakeLists.txt View File

@@ -39,7 +39,7 @@ set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR}

option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE)

if (ENABLE_OPEN_SRC)
if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST)
set(HI_PYTHON python3)

include(cmake/external_libs/protobuf_shared.cmake)
@@ -51,118 +51,132 @@ if (ENABLE_OPEN_SRC)
include(cmake/external_libs/json.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)

# if D_LINK_PATH is set in environment variables, search libraries in given path
if(DEFINED ENV{D_LINK_PATH})
# D_LINK_PATH is set
set(GE_LIB_PATH $ENV{D_LINK_PATH})
set(GE_SYS_ARCH "")
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
# x86 ubuntu
set(GE_SYS_ARCH "x86_64")
elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
# arm euleros
set(GE_SYS_ARCH "aarch64")
add_subdirectory(tests)
else ()
if (ENABLE_OPEN_SRC)
set(HI_PYTHON python3)

include(cmake/external_libs/protobuf_shared.cmake)
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/gflags.cmake)
include(cmake/external_libs/gtest.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/external_libs/json.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)

# if D_LINK_PATH is set in environment variables, search libraries in given path
if(DEFINED ENV{D_LINK_PATH})
# D_LINK_PATH is set
set(GE_LIB_PATH $ENV{D_LINK_PATH})
set(GE_SYS_ARCH "")
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
# x86 ubuntu
set(GE_SYS_ARCH "x86_64")
elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
# arm euleros
set(GE_SYS_ARCH "aarch64")
else()
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
endif()
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
set(STATIC_ACL_LIB ${GE_LIB_PATH})
find_module(slog libalog.so ${GE_LIB_PATH})
find_module(static_mmpa libmmpa.a ${GE_LIB_PATH})
find_module(msprofiler_ext libmsprofiler.a ${GE_LIB_PATH})
find_module(hccl libhccl.so ${GE_LIB_PATH})
find_module(adump_server libadump_server.a ${GE_LIB_PATH})
find_module(runtime libruntime.so ${GE_LIB_PATH})
find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH})
find_module(resource libresource.so ${GE_LIB_PATH})
find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH})
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
else()
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
endif()
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
set(STATIC_ACL_LIB ${GE_LIB_PATH})
find_module(slog libalog.so ${GE_LIB_PATH})
find_module(static_mmpa libmmpa.a ${GE_LIB_PATH})
find_module(msprofiler_ext libmsprofiler.a ${GE_LIB_PATH})
find_module(hccl libhccl.so ${GE_LIB_PATH})
find_module(adump_server libadump_server.a ${GE_LIB_PATH})
find_module(runtime libruntime.so ${GE_LIB_PATH})
find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH})
find_module(resource libresource.so ${GE_LIB_PATH})
find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH})
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
elseif(ENABLE_GE_COV OR ENABLE_GE_UT)
add_subdirectory(tests)
else()
find_module(slog libalog.so ${ASCEND_ATC_DIR})
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
if(PLATFORM STREQUAL "train")
find_module(slog libalog.so ${ASCEND_ATC_DIR})
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR})
if(PLATFORM STREQUAL "train")
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
if(PRODUCT STREQUAL "flr3")
message(FATAL_ERROR "This platform is not supported in train mode, build terminated")
endif()
elseif(PLATFORM STREQUAL "inference")
find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
find_module(runtime libruntime.so ${ASCEND_ACL_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
if(PRODUCT STREQUAL "flr3")
elseif(PRODUCT STREQUAL "flr1")
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
elseif(PRODUCT STREQUAL "flr2")
# flr2 ascend_hal_stub limsprof ?
else()
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
endif()
elseif(PLATFORM STREQUAL "all")
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
if(PRODUCT STREQUAL "flr3")
message(FATAL_ERROR "This platform is not supported in train mode, build terminated")
endif()
elseif(PLATFORM STREQUAL "inference")
find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
find_module(runtime libruntime.so ${ASCEND_ACL_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
if(PRODUCT STREQUAL "flr3")
elseif(PRODUCT STREQUAL "flr1")
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver)
elseif(PRODUCT STREQUAL "flr2")
# flr2 ascend_hal_stub limsprof ?
else()
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
endif()
elseif(PLATFORM STREQUAL "all")
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
else()
message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")
endif()
endif()

set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser)
set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)

add_subdirectory(metadef)
add_subdirectory(parser)
#add_subdirectory(metadef/graph)
#add_subdirectory(metadef/register)
elseif (ENABLE_D OR ENABLE_ACL)
# compiling with MindSpore
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/external_libs/json.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)

# common libraries
find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})

if (ENABLE_D)
# training
find_module(runtime libruntime.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(register libregister.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
endif ()

set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
elseif(ENABLE_MS_TESTCASES)
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)

# common libraries
find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/parser)
set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)

add_subdirectory(metadef)
add_subdirectory(parser)
#add_subdirectory(metadef/graph)
#add_subdirectory(metadef/register)
elseif (ENABLE_D OR ENABLE_ACL)
# compiling with MindSpore
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/external_libs/json.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)

# common libraries
find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})

if (ENABLE_D)
# training
find_module(runtime libruntime.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(register libregister.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
endif ()

set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
elseif(ENABLE_MS_TESTCASES)
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)

# common libraries
find_module(slog libalog.so ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})
find_module(static_mmpa libmmpa.a ${ASCEND_MS_RUNTIME_PATH} ${ATLAS_MS_RUNTIME_PATH})

set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
else()
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef)
set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser)
set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
endif()

set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
else()
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef)
set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser)
set(GE_DEPEND_DIR ${CMAKE_CURRENT_LIST_DIR}/..)
endif()
add_subdirectory(ge)

add_subdirectory(ge)
endif ()

+ 25
- 1
build.sh View File

@@ -177,6 +177,9 @@ build_graphengine()
elif [ "X$ENABLE_GE_UT" = "Xon" ]
then
TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest"
elif [ "X$ENABLE_GE_ST" = "Xon" ]
then
TARGET="graph_engine_test"
elif [ "X$MINDSPORE_MODE" = "Xon" ]
then
TARGET="ge_common graph"
@@ -234,6 +237,27 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then
genhtml coverage.info
fi

if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then
#prepare engine & opskernel so
mkdir -p ${OUTPUT_PATH}/plugin/nnengine
mkdir -p ${OUTPUT_PATH}/plugin/nnengine/ge_config
mkdir -p ${OUTPUT_PATH}/plugin/opskernel
cp ${BUILD_PATH}/tests/st/libnnengine.so ${OUTPUT_PATH}/plugin/nnengine
cp ${BUILD_PATH}/engine_conf.json ${OUTPUT_PATH}/plugin/nnengine/ge_config
cp ${BUILD_PATH}/tests/st/libhost_cpu_engine.so ${OUTPUT_PATH}/plugin/opskernel
#prepare st execution bin
cp ${BUILD_PATH}/tests/st/testcase/graph_engine_test ${OUTPUT_PATH}
#execute st testcase
RUN_TEST_CASE=${OUTPUT_PATH}/graph_engine_test && ${RUN_TEST_CASE}
if [[ "$?" -ne 0 ]]; then
echo "!!! ST FAILED, PLEASE CHECK YOUR CHANGES !!!"
echo -e "\033[31m${RUN_TEST_CASE}\033[0m"
exit 1;
fi
# remove plugin
rm -rf ${OUTPUT_PATH}/plugin
fi

# generate output package in tar form, including ut/st libraries/executables
generate_package()
{
@@ -337,7 +361,7 @@ generate_package()
fi
}

if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then
if [[ "X$ENABLE_GE_UT" = "Xoff" && "X$ENABLE_GE_ST" = "Xoff" && "X$MINDSPORE_MODE" = "Xoff" ]]; then
generate_package
elif [ "X$MINDSPORE_MODE" = "Xon" ]
then


+ 2
- 1
ge/ge_runtime/runtime_model.cc View File

@@ -25,6 +25,7 @@
#include "framework/common/op/op_parser_util.h"
#include "graph/types.h"
#include "task/task_factory.h"
#include "ge/common/math/math_util.h"

namespace ge {
namespace model_runner {
@@ -500,7 +501,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
}
uint64_t *buff = reinterpret_cast<uint64_t *>(const_cast<char *>(constant->weight_data.data()));
uint32_t head_len = kOffsetUnit * kStringHeadElems;
if (ge::CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) {
if (CheckInt64Uint32MulOverflow(elem_num, head_len) != SUCCESS) {
GELOGE(FAILED, "Shape size is invalid");
return false;
}


+ 1
- 1
ge/ge_runtime/task/aicpu_task.cc View File

@@ -83,7 +83,7 @@ bool AicpuTask::Distribute() {
return false;
}

GELOGI("ext info size:", ext_size);
GELOGI("ext info size: %u", ext_size);
aicpu_param_head.extInfoLength = ext_size;
aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
}


+ 2
- 2
ge/ge_runtime/task/hccl_task.cc View File

@@ -130,7 +130,7 @@ bool HcclTask::SetSecondaryStream() {
Status ret;
std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_);
if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) {
GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %ld.", rt_model_handle_, master_stream_id);
GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %u.", rt_model_handle_, master_stream_id);
ret = CreateStream(hccl_secondary_stream_num, master_stream_id);
if (!ret) {
GELOGE(RT_FAILED, "Create hccl stream failed.");
@@ -189,7 +189,7 @@ bool HcclTask::SetSecondaryStream() {
}
GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num);
} else {
GELOGI("Need to create secondary stream for %s with new mainstream %ld.", task_info_->op_name().c_str(),
GELOGI("Need to create secondary stream for %s with new mainstream %u.", task_info_->op_name().c_str(),
master_stream_id);
ret = CreateStream(hccl_secondary_stream_num, master_stream_id);
if (!ret) {


+ 1
- 1
ge/ge_runtime/task/label_goto_task.cc View File

@@ -72,7 +72,7 @@ bool LabelGotoTask::Distribute() {
return false;
}

rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
rt_ret = rtLabelListCpy((void**)label_list.data(), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: %#x", rt_ret);
return false;


+ 1
- 1
ge/ge_runtime/task/label_switch_task.cc View File

@@ -69,7 +69,7 @@ bool LabelSwitchTask::Distribute() {
return false;
}
label_list[i] = all_label_resource_[label_index];
GELOGI("Case %zu: label id %zu.", i, label_index);
GELOGI("Case %zu: label id %zu.", i, (size_t)label_index);
}

uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size();


+ 6
- 0
tests/st/CMakeLists.txt View File

@@ -0,0 +1,6 @@
project(graphengine_st)

include(cmake/graphengine.cmake)

add_subdirectory(framework)
add_subdirectory(testcase)

+ 249
- 0
tests/st/cmake/graphengine.cmake View File

@@ -0,0 +1,249 @@
# ---- Test coverage ----

if (ENABLE_GE_COV)
set(COVERAGE_COMPILER_FLAGS "-g --coverage -fprofile-arcs -fPIC -O0 -ftest-coverage")
set(CMAKE_CXX_FLAGS "${COVERAGE_COMPILER_FLAGS}")
endif()

# ---- Proto generate ----

file(GLOB_RECURSE PROTO_FILES CONFIGURE_DEPENDS "${GE_CODE_DIR}/metadef/proto/*.proto")

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_FILES})

# ---- File glob by group ----

file(GLOB_RECURSE METADEF_SRCS CONFIGURE_DEPENDS
"${GE_CODE_DIR}/metadef/graph/*.cc"
"${GE_CODE_DIR}/metadef/register/*.cc"
"${GE_CODE_DIR}/metadef/register/*.cpp"
"${GE_CODE_DIR}/metadef/ops/*.cc"
"${GE_CODE_DIR}/metadef/third_party/transformer/src/*.cc"
)
file(GLOB_RECURSE METADEF_REGISTER_SRCS CONFIGURE_DEPENDS
"${GE_CODE_DIR}/metadef/register/*.cc"
"${GE_CODE_DIR}/metadef/register/*.cpp"
)

file(GLOB_RECURSE PARSER_SRCS CONFIGURE_DEPENDS
"${GE_CODE_DIR}/parser/parser/common/*.cc"
)

file(GLOB_RECURSE LOCAL_ENGINE_SRC CONFIGURE_DEPENDS
"${GE_CODE_DIR}/ge/ge_local_engine/*.cc"
)

file(GLOB_RECURSE HOST_ENGINE_SRC CONFIGURE_DEPENDS
"${GE_CODE_DIR}/ge/host_cpu_engine/*.cc"
)

file(GLOB_RECURSE NN_ENGINE_SRC CONFIGURE_DEPENDS
"${GE_CODE_DIR}/ge/plugin/*.cc"
)

file(GLOB_RECURSE OFFLINE_SRC CONFIGURE_DEPENDS
"${GE_CODE_DIR}/ge/offline/*.cc"
)

file(GLOB_RECURSE GE_SRCS CONFIGURE_DEPENDS
"${GE_CODE_DIR}/ge/*.cc"
)

list(REMOVE_ITEM GE_SRCS ${LOCAL_ENGINE_SRC} ${HOST_ENGINE_SRC} ${NN_ENGINE_SRC} ${OFFLINE_SRC})

list(APPEND INCLUDE_DIRECTORIES
"${CMAKE_CURRENT_SOURCE_DIR}"
"${GE_CODE_DIR}"
"${GE_CODE_DIR}/inc"
"${GE_CODE_DIR}/metadef/inc"
"${GE_CODE_DIR}/ge"
"${GE_CODE_DIR}/ge/inc"
"${GE_CODE_DIR}/ge/ir_build"
"${GE_CODE_DIR}/metadef"
"${GE_CODE_DIR}/metadef/graph"
"${GE_CODE_DIR}/inc/external"
"${GE_CODE_DIR}/inc/framework/common"
"${GE_CODE_DIR}/metadef/inc/external"
"${GE_CODE_DIR}/metadef/inc/external/graph"
"${GE_CODE_DIR}/metadef/inc/graph"
"${GE_CODE_DIR}/inc/framework"
"${GE_CODE_DIR}/metadef/inc/common"
"${GE_CODE_DIR}/metadef/third_party"
"${GE_CODE_DIR}/metadef/third_party/transformer/inc"
"${GE_CODE_DIR}/parser"
"${GE_CODE_DIR}/parser/parser"
"${GE_CODE_DIR}/third_party/fwkacllib/inc"
"${GE_CODE_DIR}/third_party/fwkacllib/inc/cce"
"${GE_CODE_DIR}/third_party/fwkacllib/inc/ops"
"${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain"
"${GE_CODE_DIR}/tests/ut/ge"
"${GE_CODE_DIR}/tests/ut/common"
"${CMAKE_BINARY_DIR}"
"${CMAKE_BINARY_DIR}/proto/ge"
"${CMAKE_BINARY_DIR}/proto/ge/proto"
)

list(APPEND STUB_LIBS
c_sec
slog_stub
cce_ge_stub
runtime_stub
profiler_stub
#mmpa_stub
hccl_stub
error_manager_stub
ascend_protobuf
json
)

# ---- Target : Local engine ----

add_library(localengine STATIC ${LOCAL_ENGINE_SRC} ${METADEF_REGISTER_SRCS})

target_include_directories(localengine
PUBLIC
"${INCLUDE_DIRECTORIES}"
"${GE_CODE_DIR}/ge/ge_local_engine"
)

target_compile_definitions(localengine PRIVATE
google=ascend_private
)

target_compile_options(localengine PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)

target_link_libraries(localengine PUBLIC
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov
)

set_target_properties(localengine PROPERTIES CXX_STANDARD 11)

# ---- Target : metadef graph ----

add_library(metadef_graph STATIC ${METADEF_SRCS} ${PROTO_SRCS} ${PROTO_HDRS})

target_include_directories(metadef_graph
PUBLIC
"${INCLUDE_DIRECTORIES}"
)

target_compile_definitions(metadef_graph PRIVATE
google=ascend_private
FMK_SUPPORT_DUMP
)

target_compile_options(metadef_graph PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)

target_link_libraries(metadef_graph PUBLIC
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov
)

set_target_properties(metadef_graph PROPERTIES CXX_STANDARD 11)
# ---- Target : Host engine ----

add_library(host_cpu_engine SHARED ${HOST_ENGINE_SRC} ${PROTO_HDRS})

target_include_directories(host_cpu_engine
PUBLIC
"${INCLUDE_DIRECTORIES}"
"${GE_CODE_DIR}/ge/host_cpu_engine"
)

target_compile_definitions(host_cpu_engine PRIVATE
google=ascend_private
FMK_SUPPORT_DUMP
)

target_compile_options(host_cpu_engine PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)

target_link_libraries(host_cpu_engine PUBLIC
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} metadef_graph -lmmpa -L/home/hugo/Code/ge/graphengine/build/tests/depends/mmpa -lrt -ldl -lpthread -lgcov
)

set_target_properties(host_cpu_engine PROPERTIES CXX_STANDARD 11)

# ---- Target : engine plugin----
#

add_library(nnengine SHARED ${NN_ENGINE_SRC})

target_include_directories(nnengine
PUBLIC
"${INCLUDE_DIRECTORIES}"
"${GE_CODE_DIR}/ge/plugin/engine"
)

target_compile_definitions(nnengine PRIVATE
google=ascend_private
)

target_compile_options(nnengine PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)

target_link_libraries(nnengine PUBLIC
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} -lrt -ldl -lpthread -lgcov
)

set_target_properties(nnengine PROPERTIES CXX_STANDARD 11)

# Targe: engine_conf
add_custom_target(
engine_conf.json ALL
DEPENDS ${CMAKE_BINARY_DIR}/engine_conf.json
)
add_custom_command(
OUTPUT ${CMAKE_BINARY_DIR}/engine_conf.json
COMMAND cp ${GE_CODE_DIR}/ge/engine_manager/engine_conf.json ${CMAKE_BINARY_DIR}/
)
# Targe: optimizer priority
add_custom_target(
optimizer_priority.pbtxt ALL
DEPENDS ${CMAKE_BINARY_DIR}/optimizer_priority.pbtxt
)
add_custom_command(
OUTPUT ${CMAKE_BINARY_DIR}/optimizer_priority.pbtxt
COMMAND cp ${GE_CODE_DIR}/ge/opskernel_manager/optimizer_priority.pbtxt ${CMAKE_BINARY_DIR}/
)

# ---- Target : Graph engine ----

add_library(graphengine STATIC ${PARSER_SRCS} ${GE_SRCS} ${PROTO_HDRS})

target_include_directories(graphengine
PUBLIC
"${INCLUDE_DIRECTORIES}"
"${GE_CODE_DIR}/ge/host_cpu_engine"
)

target_compile_definitions(graphengine PRIVATE
google=ascend_private
FMK_SUPPORT_DUMP
)

target_compile_options(graphengine PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)

target_link_libraries(graphengine PUBLIC
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS}
metadef_graph
localengine
host_cpu_engine
nnengine
mmpa -L${GE_CODE_DIR}/third_party/prebuild/x86_64 -lrt -ldl -lpthread -lgcov
)

set_target_properties(graphengine PROPERTIES CXX_STANDARD 11)
add_dependencies(graphengine engine_conf.json optimizer_priority.pbtxt)

+ 16
- 0
tests/st/framework/CMakeLists.txt View File

@@ -0,0 +1,16 @@
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++")
#todo
file(GLOB_RECURSE stub_engine CONFIGURE_DEPENDS
"stub_engine/*.cc"
)
list(REMOVE_ITEM SOURCES ${stub_engine})

add_library(framework STATIC ${SOURCES})

target_include_directories(framework
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}
)

set_target_properties(framework PROPERTIES CXX_STANDARD 11)

target_link_libraries(framework PUBLIC graphengine)

+ 26
- 0
tests/st/framework/framework.cc View File

@@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <stdlib.h>
#include "framework.h"

namespace ge {
namespace st {
Status Framework::SetUp() {

}
} // namespace st
} // namespace ge

+ 33
- 0
tests/st/framework/framework.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GRAPHENGINE_LLT_ST_FRAMEWORK_H_
#define GRAPHENGINE_LLT_ST_FRAMEWORK_H_
#include <string>
#include "common/ge_inner_error_codes.h"

namespace ge {
namespace st {
class Framework {
public:
explicit Framework() {};
Status SetUp();
Status TearDown();
};
} // namespace st
}// namespace ge

#endif // GRAPHENGINE_LLT_ST_FRAMEWORK_H_

+ 259
- 0
tests/st/framework/stub_engine/CMakeLists.txt View File

@@ -0,0 +1,259 @@
set(PROTO_LIST
"${METADEF_DIR}/proto/task.proto"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST})

set(SRC_LIST
"engine/stub_engine.cc"
"ops_kernel_store/host_cpu_ops_kernel_info.cc"
"ops_kernel_store/op/op_factory.cc"
"ops_kernel_store/op/host_op.cc"
)

set(CPU_OPS_KERNEL_LIST
"ops_kernel_store/host_cpu_ops_kernel_builder.cc"
)

############ libfe.so ############
add_library(fe SHARED ${SRC_LIST} ${PROTO_HDRS})

target_compile_options(fe PRIVATE
-Werror
-fno-common
-fvisibility=hidden
)

target_compile_definitions(fe PRIVATE
google=ascend_private
FUNC_VISIBILITY
)

target_include_directories(fe PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
${GE_CODE_DIR}/third_party/fwkacllib/inc
)

target_link_options(fe PRIVATE
-Wl,-Bsymbolic
)

target_link_libraries(fe PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
ascend_protobuf
c_sec
graph
slog
-Wl,--as-needed
)

############ atcstub/libfe.so ############
add_library(atc_fe SHARED ${SRC_LIST} ${PROTO_ATCSTUB_HDRS})

target_compile_options(atc_fe PRIVATE
-Werror
-fno-common
-fvisibility=hidden
)

target_compile_definitions(atc_fe PRIVATE
google=ascend_private
FUNC_VISIBILITY
)

target_include_directories(atc_fe PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge_atcstub
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
${GE_CODE_DIR}/third_party/fwkacllib/inc
)

target_link_options(atc_fe PRIVATE
-Wl,-Bsymbolic
)

target_link_libraries(atc_fe PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
ascend_protobuf
c_sec
graph
slog
-Wl,--as-needed
)

set_target_properties(atc_fe PROPERTIES
OUTPUT_NAME fe
LIBRARY_OUTPUT_DIRECTORY atclib
)

############ libhost_cpu_opskernel_builder.so ############
add_library(host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST})

target_compile_options(host_cpu_opskernel_builder PRIVATE
-Werror
-fno-common
-fvisibility=hidden
)

target_compile_definitions(host_cpu_opskernel_builder PRIVATE
google=ascend_private
FUNC_VISIBILITY
)

target_include_directories(host_cpu_opskernel_builder PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
${GE_CODE_DIR}/third_party/fwkacllib/inc
)

target_link_options(host_cpu_opskernel_builder PRIVATE
-Wl,-Bsymbolic
)

target_link_libraries(host_cpu_opskernel_builder PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
ascend_protobuf
c_sec
slog
graph
register
-Wl,--as-needed
)

############ atclib/libhost_cpu_opskernel_builder.so ############
add_library(atc_host_cpu_opskernel_builder SHARED ${CPU_OPS_KERNEL_LIST})

target_compile_options(atc_host_cpu_opskernel_builder PRIVATE
-Werror
-fno-common
-fvisibility=hidden
)

target_compile_definitions(atc_host_cpu_opskernel_builder PRIVATE
google=ascend_private
FUNC_VISIBILITY
)

target_include_directories(atc_host_cpu_opskernel_builder PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
${GE_CODE_DIR}/third_party/fwkacllib/inc
)

target_link_options(atc_host_cpu_opskernel_builder PRIVATE
-Wl,-Bsymbolic
)

target_link_libraries(atc_host_cpu_opskernel_builder PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
ascend_protobuf
c_sec
slog
graph
register
-Wl,--as-needed
)

set_target_properties(atc_host_cpu_opskernel_builder PROPERTIES
OUTPUT_NAME host_cpu_opskernel_builder
LIBRARY_OUTPUT_DIRECTORY atclib
)

############ libhost_cpu_opskernel_builder.a ############
add_library(host_cpu_opskernel_builder_static STATIC ${CPU_OPS_KERNEL_LIST})

target_compile_options(host_cpu_opskernel_builder_static PRIVATE
-Werror
-fno-common
-fvisibility=hidden
)

target_compile_definitions(host_cpu_opskernel_builder_static PRIVATE
google=ascend_private
LOG_CPP
FUNC_VISIBILITY
)

target_include_directories(host_cpu_opskernel_builder_static PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
${GE_CODE_DIR}/third_party/fwkacllib/inc
)

target_link_libraries(host_cpu_opskernel_builder_static PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf
c_sec
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(TARGETS fe host_cpu_opskernel_builder OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)

install(TARGETS atc_fe atc_host_cpu_opskernel_builder OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/atclib
)

+ 30
- 0
tests/st/framework/stub_engine/common/constant/constant.h View File

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_
#define GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_

#include <string>

namespace ge {
namespace host_cpu {
// engine name
const char kHostCpuEngineName[] = "DNN_VM_HOST_CPU";
const char kHostCpuOpKernelLibName[] = "DNN_VM_HOST_CPU_OP_STORE";
} // namespace host_cpu
} // namespace ge

#endif // GE_HOST_CPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_

+ 74
- 0
tests/st/framework/stub_engine/engine/stub_engine.cc View File

@@ -0,0 +1,74 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "stub_engine.h"
#include <map>
#include <memory>
#include <string>
#include <securec.h>
#include "framework/common/debug/ge_log.h"
#include "common/ge/ge_util.h"
#include "host_cpu_engine/common/constant/constant.h"
#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h"

namespace fe {
AICEngine &AICEngine::Instance() {
static AICEngine instance;
return instance;
}

Status AICEngine::Initialize(const std::map<string, string> &options) {
if (ops_kernel_store_ == nullptr) {
ops_kernel_store_ = MakeShared<HostCpuOpsKernelInfoStore>();
if (ops_kernel_store_ == nullptr) {
GELOGE(FAILED, "[Create][AICEngine]Make HostCpuOpsKernelInfoStore failed.");
REPORT_INNER_ERROR("E19999", "AICEngine::Initialize failed for new AICEngine.");
return FAILED;
}
}
return SUCCESS;
}

void AICEngine::GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map) {
if (ops_kernel_store_ != nullptr) {
// add buildin opsKernel to opsKernelInfoMap
ops_kernel_map[kHostCpuOpKernelLibName] = ops_kernel_store_;
}
}

void AICEngine::GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &) {
// no optimizer for host cpu engine
}

Status AICEngine::Finalize() {
ops_kernel_store_ = nullptr;
return SUCCESS;
}
} // namespace fe

ge::Status Initialize(const std::map<string, string> &options) {
return fe::AICEngine::Instance().Initialize(options);
}

void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map) {
fe::AICEngine::Instance().GetOpsKernelInfoStores(ops_kernel_map);
}

void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers) {
fe::AICEngine::Instance().GetGraphOptimizerObjs(graph_optimizers);
}

ge::Status Finalize() { return fe::AICEngine::Instance().Finalize(); }

+ 126
- 0
tests/st/framework/stub_engine/engine/stub_engine.h View File

@@ -0,0 +1,126 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_
#define GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY _declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <map>
#include <vector>
#include <memory>
#include <string>
#include "common/opskernel/ops_kernel_info_store.h"
#include "common/optimizer/graph_optimizer.h"

using OpsKernelInfoStorePtr = std::shared_ptr<ge::OpsKernelInfoStore>;
using GraphOptimizerPtr = std::shared_ptr<ge::GraphOptimizer>;

namespace ge {
namespace {
std::vector<string> extern_engine_name_vec = {"fe","rts_engine","aicpu_ascend_engine","aicpu_tf_engine",}
} // namespace
/**
* host cpu engine.
* Used for the ops which executes on host.
*/
class GE_FUNC_VISIBILITY StubEngine {
public:
/**
* get HostCpuEngine instance.
* @return HostCpuEngine instance.
*/
static StubEngine &Instance();

virtual ~StubEngine() = default;

/**
* When Ge start, GE will invoke this interface
* @return The status whether initialize successfully
*/
Status Initialize(const std::map<string, string> &options);

/**
* After the initialize, GE will invoke this interface
* to get the Ops kernel Store.
* @param ops_kernel_map The host cpu's ops kernel info
*/
void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map);

/**
* After the initialize, GE will invoke this interface
* to get the Graph Optimizer.
* @param graph_optimizers The host cpu's Graph Optimizer objs
*/
void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers);

/**
* When the graph finished, GE will invoke this interface
* @return The status whether initialize successfully
*/
Status Finalize();

StubEngine(const StubEngine &StubEngine) = delete;
StubEngine(const StubEngine &&StubEngine) = delete;
StubEngine &operator=(const StubEngine &StubEngine) = delete;
StubEngine &operator=(StubEngine &&StubEngine) = delete;

private:
StubEngine() = default;
OpsKernelInfoStorePtr ops_kernel_store_ = nullptr;
};
} // namespace ge

extern "C" {

/**
* When Ge start, GE will invoke this interface
* @return The status whether initialize successfully
*/
GE_FUNC_VISIBILITY ge::Status Initialize(const map<string, string> &options);

/**
* After the initialize, GE will invoke this interface to get the Ops kernel Store
* @param ops_kernel_map The host cpu's ops kernel info
*/
GE_FUNC_VISIBILITY void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map);

/**
* After the initialize, GE will invoke this interface to get the Graph Optimizer
* @param graph_optimizers The host cpu's Graph Optimizer objs
*/
GE_FUNC_VISIBILITY void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers);

/**
* When the graph finished, GE will invoke this interface
* @return The status whether initialize successfully
*/
GE_FUNC_VISIBILITY ge::Status Finalize();
}

#endif // GE_ST_STUB_ENGINE_HOST_CPU_ENGINE_H_

+ 114
- 0
tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.cc View File

@@ -0,0 +1,114 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "host_cpu_ops_kernel_builder.h"
#include <memory>
#include "common/ge_inner_error_codes.h"
#include "ge/ge_api_types.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include <securec.h>
#include "framework/common/debug/ge_log.h"
#include "host_cpu_engine/common/constant/constant.h"
#include "register/ops_kernel_builder_registry.h"

namespace ge {
namespace host_cpu {
REGISTER_OPS_KERNEL_BUILDER(kHostCpuOpKernelLibName, HostCpuOpsKernelBuilder);

Status HostCpuOpsKernelBuilder::Finalize() {
return SUCCESS;
}
Status HostCpuOpsKernelBuilder::Initialize(const map<std::string, std::string> &options) {
return SUCCESS;
}

Status HostCpuOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) {
OpDescPtr op_desc = ge_node.GetOpDesc();
if (op_desc == nullptr) {
GELOGE(FAILED, "[Get][OpDesc]CalcOpRunningParam failed, as op desc is null");
REPORT_INNER_ERROR("E19999", "GetOpDesc failed.");
return FAILED;
}

bool is_shape_unknown = false;
if (NodeUtils::GetNodeUnknownShapeStatus(ge_node, is_shape_unknown) == GRAPH_SUCCESS) {
if (is_shape_unknown) {
GELOGI("op:%s is unknown shape, does not need to calc output size.", ge_node.GetName().c_str());
return SUCCESS;
}
}

const string name = ge_node.GetName();
const string type = ge_node.GetType();
GELOGD("Calc op[%s:%s] running param, output size=%zu.", name.c_str(), type.c_str(), op_desc->GetOutputsSize());

for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) {
GeTensorDesc output_tensor = op_desc->GetOutputDesc(static_cast<uint32_t>(i));
Format format = output_tensor.GetFormat();
DataType data_type = output_tensor.GetDataType();

int64_t mem_size = 0;
// If mem size has been set, no need reset.
if ((TensorUtils::GetSize(output_tensor, mem_size) == GRAPH_SUCCESS) && (mem_size > 0)) {
GELOGD("Op[%s:%s] out[%zu] mem size has been set, no need calc again, format=%s, data_type=%s, mem_size=%ld.",
name.c_str(), type.c_str(), i, TypeUtils::FormatToSerialString(format).c_str(),
TypeUtils::DataTypeToSerialString(data_type).c_str(), mem_size);
continue;
}

int64_t output_mem_size = 0;
GeShape output_shape = output_tensor.GetShape();
if ((TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_mem_size) != GRAPH_SUCCESS) ||
(output_mem_size < 0)) {
GELOGE(FAILED,
"[Calc][TensorMemSize] fail for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.",
name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(),
TypeUtils::DataTypeToSerialString(data_type).c_str());
REPORT_CALL_ERROR("E19999",
"CalcTensorMemSize failed for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.",
name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(),
TypeUtils::DataTypeToSerialString(data_type).c_str());
return FAILED;
}
GELOGI("Calc op[%s:%s] out[%zu] mem size is %ld, format=%s, data_type=%s.",
name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(),
TypeUtils::DataTypeToSerialString(data_type).c_str());

TensorUtils::SetSize(output_tensor, output_mem_size);
if (op_desc->UpdateOutputDesc(static_cast<uint32_t>(i), output_tensor) != GRAPH_SUCCESS) {
GELOGE(FAILED,
"[Update][OutputDesc] fail for op[%s:%s] out[%zu] desc , format=%s, data_type=%s.",
name.c_str(), type.c_str(), i,
TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str());
REPORT_CALL_ERROR("E19999", "UpdateOutputDesc failed for op[%s:%s] out[%zu] desc , format=%s, data_type=%s.",
name.c_str(), type.c_str(), i,
TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str());
return FAILED;
}
}

GELOGD("Calc op[%s:%s] running param success.", name.c_str(), type.c_str());
return SUCCESS;
}

Status HostCpuOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, vector<domi::TaskDef> &tasks) {
// no need to generate device task
return SUCCESS;
}
} // namespace host_cpu
} // namespace ge

+ 51
- 0
tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_builder.h View File

@@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY _declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include "common/opskernel/ops_kernel_builder.h"

namespace ge {
namespace host_cpu {
class GE_FUNC_VISIBILITY HostCpuOpsKernelBuilder : public OpsKernelBuilder {
public:
Status Initialize(const map<std::string, std::string> &options) override;

Status Finalize() override;

Status CalcOpRunningParam(Node &node) override;

Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override;
};
} // namespace host_cpu
} // namespace ge

#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_BUILDER_H_

+ 67
- 0
tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.cc View File

@@ -0,0 +1,67 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "host_cpu_engine/ops_kernel_store/host_cpu_ops_kernel_info.h"
#include <memory>
#include "common/constant/constant.h"
#include "ge/ge_api_types.h"
#include "framework/common/debug/ge_log.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "op/op_factory.h"

namespace ge {
namespace host_cpu {
using domi::TaskDef;
using std::map;
using std::string;
using std::vector;

Status HostCpuOpsKernelInfoStore::Initialize(const map<string, string> &options) {
GELOGI("HostCpuOpsKernelInfoStore init start.");
OpInfo default_op_info = {.engine = kHostCpuEngineName,
.opKernelLib = kHostCpuOpKernelLibName,
.computeCost = 0,
.flagPartial = false,
.flagAsync = false,
.isAtomic = false};
// Init op_info_map_
auto all_ops = OpFactory::Instance().GetAllOps();
for (auto &op : all_ops) {
op_info_map_[op] = default_op_info;
}

GELOGI("HostCpuOpsKernelInfoStore inited success. op num=%zu", op_info_map_.size());

return SUCCESS;
}

Status HostCpuOpsKernelInfoStore::Finalize() {
op_info_map_.clear();
return SUCCESS;
}

void HostCpuOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, OpInfo> &infos) const { infos = op_info_map_; }

bool HostCpuOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const {
if (op_desc == nullptr) {
return false;
}
return op_info_map_.count(op_desc->GetType()) > 0;
}
} // namespace host_cpu
} // namespace ge

+ 86
- 0
tests/st/framework/stub_engine/ops_kernel_store/host_cpu_ops_kernel_info.h View File

@@ -0,0 +1,86 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY _declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <map>
#include <string>
#include <vector>

#include "common/opskernel/ops_kernel_info_store.h"

namespace ge {
namespace host_cpu {
class GE_FUNC_VISIBILITY HostCpuOpsKernelInfoStore : public OpsKernelInfoStore {
public:
HostCpuOpsKernelInfoStore() {}
~HostCpuOpsKernelInfoStore() override = default;

/**
* Initialize related resources of the host cpu kernelinfo store
* @return status whether this operation success
*/
Status Initialize(const std::map<std::string, std::string> &options) override;

/**
* Release related resources of the host cpu kernel info store
* @return status whether this operation success
*/
Status Finalize() override;

/**
* Check to see if an operator is fully supported or partially supported.
* @param op_desc OpDesc information
* @param reason unsupported reason
* @return bool value indicate whether the operator is fully supported
*/
bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override;

/**
* Returns the full operator information.
* @param infos reference of a map,
* contain operator's name and detailed information
*/
void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override;

HostCpuOpsKernelInfoStore(const HostCpuOpsKernelInfoStore &ops_kernel_store) = delete;
HostCpuOpsKernelInfoStore(const HostCpuOpsKernelInfoStore &&ops_kernel_store) = delete;
HostCpuOpsKernelInfoStore &operator=(const HostCpuOpsKernelInfoStore &ops_kernel_store) = delete;
HostCpuOpsKernelInfoStore &operator=(HostCpuOpsKernelInfoStore &&ops_kernel_store) = delete;

private:
// store op name and OpInfo key-value pair
std::map<std::string, ge::OpInfo> op_info_map_;
};
} // namespace host_cpu
} // namespace ge

#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_

+ 40
- 0
tests/st/framework/stub_engine/ops_kernel_store/op/host_op.cc View File

@@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "host_cpu_engine/ops_kernel_store/op/host_op.h"
#include "framework/common/util.h"
#include "host_cpu_engine/ops_kernel_store/op/op_factory.h"

namespace ge {
namespace host_cpu {
Status HostOp::Run() {
// no need to generate device task
return SUCCESS;
}

REGISTER_OP_CREATOR(NoOp, HostOp);
REGISTER_OP_CREATOR(Variable, HostOp);
REGISTER_OP_CREATOR(Constant, HostOp);
REGISTER_OP_CREATOR(Assign, HostOp);
REGISTER_OP_CREATOR(RandomUniform, HostOp);
REGISTER_OP_CREATOR(Add, HostOp);
REGISTER_OP_CREATOR(Mul, HostOp);
REGISTER_OP_CREATOR(ConcatV2, HostOp);
REGISTER_OP_CREATOR(Data, HostOp);
REGISTER_OP_CREATOR(Fill, HostOp);
REGISTER_OP_CREATOR(NetOutput, HostOp);
} // namespace host_cpu
} // namespace ge

+ 36
- 0
tests/st/framework/stub_engine/ops_kernel_store/op/host_op.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_

#include "host_cpu_engine/ops_kernel_store/op/op.h"

namespace ge {
namespace host_cpu {
class GE_FUNC_VISIBILITY HostOp : public Op {
public:
HostOp(const Node &node, RunContext &run_context) : Op(node, run_context) {}
~HostOp() override = default;
HostOp &operator=(const HostOp &op) = delete;
HostOp(const HostOp &op) = delete;

Status Run() override;
};
} // namespace host_cpu
} // namespace ge

#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_

+ 45
- 0
tests/st/framework/stub_engine/ops_kernel_store/op/op.h View File

@@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_

#include <climits>
#include <string>
#include <vector>
#include "common/ge_inner_error_codes.h"
#include "common/opskernel/ops_kernel_info_types.h"
#include "graph/node.h"

namespace ge {
namespace host_cpu {
/**
* The base class for all op.
*/
class GE_FUNC_VISIBILITY Op {
public:
Op(const Node &node, RunContext &run_context) : run_context_(run_context), node_(node) {}
virtual ~Op() = default;
virtual Status Run() = 0;

protected:
const RunContext &run_context_;
const Node &node_;
};
} // namespace host_cpu
} // namespace ge

#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_

+ 55
- 0
tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.cc View File

@@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "host_cpu_engine/ops_kernel_store/op/op_factory.h"
#include "framework/common/debug/ge_log.h"
#include "common/ge_inner_error_codes.h"
#include "graph/op_desc.h"

namespace ge {
namespace host_cpu {
OpFactory &OpFactory::Instance() {
static OpFactory instance;
return instance;
}

std::shared_ptr<Op> OpFactory::CreateOp(const Node &node, RunContext &run_context) {
auto iter = op_creator_map_.find(node.GetType());
if (iter != op_creator_map_.end()) {
return iter->second(node, run_context);
}

GELOGE(FAILED, "Not supported OP, type = %s, name = %s", node.GetType().c_str(), node.GetName().c_str());
return nullptr;
}

void OpFactory::RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func) {
if (func == nullptr) {
GELOGW("Func is NULL.");
return;
}

auto iter = op_creator_map_.find(type);
if (iter != op_creator_map_.end()) {
GELOGW("%s creator already exist", type.c_str());
return;
}

op_creator_map_[type] = func;
all_ops_.emplace_back(type);
}
} // namespace host_cpu
} // namespace ge

+ 94
- 0
tests/st/framework/stub_engine/ops_kernel_store/op/op_factory.h View File

@@ -0,0 +1,94 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "common/ge/ge_util.h"
#include "host_cpu_engine/ops_kernel_store/op/op.h"

namespace ge {
namespace host_cpu {
using OP_CREATOR_FUNC = std::function<std::shared_ptr<Op>(const Node &, RunContext &)>;

/**
* manage all the op, support create op.
*/
class GE_FUNC_VISIBILITY OpFactory {
public:
static OpFactory &Instance();

/**
* @brief create Op.
* @param [in] node share ptr of node
* @param [in] run_context run context
* @return not nullptr success
* @return nullptr fail
*/
std::shared_ptr<Op> CreateOp(const Node &node, RunContext &run_context);

/**
* @brief Register Op create function.
* @param [in] type Op type
* @param [in] func Op create func
*/
void RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func);

const std::vector<std::string> &GetAllOps() const { return all_ops_; }

bool CheckSupported(const std::string &type) { return op_creator_map_.find(type) != op_creator_map_.end(); }

OpFactory(const OpFactory &) = delete;
OpFactory &operator=(const OpFactory &) = delete;
OpFactory(OpFactory &&) = delete;
OpFactory &operator=(OpFactory &&) = delete;

private:
OpFactory() = default;
~OpFactory() = default;

// the op creator function map
std::map<std::string, OP_CREATOR_FUNC> op_creator_map_;
std::vector<std::string> all_ops_;
};

class GE_FUNC_VISIBILITY OpRegistrar {
public:
OpRegistrar(const std::string &type, const OP_CREATOR_FUNC &func) {
OpFactory::Instance().RegisterCreator(type, func);
}
~OpRegistrar() = default;

OpRegistrar(const OpRegistrar &) = delete;
OpRegistrar &operator=(const OpRegistrar &) = delete;
OpRegistrar(OpRegistrar &&) = delete;
OpRegistrar &operator=(OpRegistrar &&) = delete;
};

#define REGISTER_OP_CREATOR(type, clazz) \
std::shared_ptr<Op> Creator_##type##Op(const Node &node, RunContext &run_context) { \
return MakeShared<clazz>(node, run_context); \
} \
OpRegistrar g_##type##Op_creator(#type, Creator_##type##Op)
} // namespace host_cpu
} // namespace ge

#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_

+ 179
- 0
tests/st/framework/stub_engine/proto/task.proto View File

@@ -0,0 +1,179 @@
/* Copyright 2021. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
KernelDefWithHandle kernel_with_handle = 40;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelDefWithHandle {
KernelContext context = 1;

uint64 handle = 10;
string dev_func = 11;
uint32 block_dim = 12;
uint32 args_size = 13;
bytes args = 14;
bytes sm_desc = 15;
string original_kernel_key = 16;
string node_info = 17;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+ 1763
- 0
tests/st/framework/stub_op_proto/array_ops.cc
File diff suppressed because it is too large
View File


+ 711
- 0
tests/st/framework/stub_op_proto/array_ops.h View File

@@ -0,0 +1,711 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file array_ops.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_
#include "graph/operator_reg.h"
#include "graph/operator.h"
namespace ge {
/**
*@brief Finds unique elements in a 1D tensor. \n
*@par Inputs:
*x: 1D tensor.
*Input "x" is a k-dimensional tensor. Inputs "num_lower" and "num_upper"
are 0D scalars. \n
*@par Attributes:
*out_idx: An optional DType from: "int32, int64". Defaults to "int32". \n
*@par Outputs:
*@li y: "x" in the unique output "y".
*@li idx: A tensor the same size as "x". The index of each value of "x". \n
*@attention Constraints:
*Unique runs on the Ascend AI CPU, which delivers poor performance. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Unique.
*/
REG_OP(Unique)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \
DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \
DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE}))
.OUTPUT(idx, TensorType({DT_INT32, DT_INT64}))
.ATTR(out_idx, Type, DT_INT32)
.OP_END_FACTORY_REG(Unique)
/**
*@brief Creates a constant tensor from a tensor-like object. This operator is used for inference.
Operator Const has the same definition as operator Constant. \n
*@par Attributes:
*value: Required. The value and type of the resulting tensor, and no restrictions on type. \n
*@par Outputs:
*y: A constant tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Const.
*/
REG_OP(Const)
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.ATTR(value, Tensor, Tensor())
.OP_END_FACTORY_REG(Const)
/**
*@brief Creates a constant tensor for training. \n
*@par Attributes:
*value: Required. The value and type of the resulting tensor, and no restrictions on type. \n
*@par Outputs:
*y: The constant tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Const.
*/
REG_OP(Constant)
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.ATTR(value, Tensor, Tensor())
.OP_END_FACTORY_REG(Constant)
/**
*@brief Returns a copy of the input tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Snapshot.
*/
REG_OP(Snapshot)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OP_END_FACTORY_REG(Snapshot)
/**
*@brief Gives a guarantee to the runtime that the input tensor is a constant. \n
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: The input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator GuaranteeConst.
*/
REG_OP(GuaranteeConst)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OP_END_FACTORY_REG(GuaranteeConst)
/**
*@brief Returns the target shape for broadcasting shapes "x1" and "x2". \n
*@par Inputs:
*@li x1: A tensor of type int32 or int64. A shape.
*@li x2: A tensor of the same type as "x1". The other shape. \n
*@par Outputs:
*y: A tensor. The broadcasted shape. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator BroadcastArgs.
*/
REG_OP(BroadcastArgs)
.INPUT(x1, TensorType({DT_INT32, DT_INT64}))
.INPUT(x2, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
.OP_END_FACTORY_REG(BroadcastArgs)
/**
*@brief Outputs its input tensor as is and triggers an error if a gradient is requested. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*message: Will be printed in the error at the attempt to request a gradient. \n
*@par Outputs:
*y: The input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator PreventGradient.
*/
REG_OP(PreventGradient)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.ATTR(message, String, "")
.OP_END_FACTORY_REG(PreventGradient)
/**
*@brief Returns the reduction indices for computing gradients of "x1" and "x2" with broadcast. \n
*@par Inputs:
*@li x1: A tensor of type int32 or int64.
*@li x2: A tensor of type int32 or int64.
"x2" has the same type as "x1". \n
*@par Outputs:
*@li y1: A tensor. Reduction indices of "x1".
*@li y2: A tensor. Reduction indices of "x2". \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator BroadcastGradientArgs.
*/
REG_OP(BroadcastGradientArgs)
.INPUT(x1, TensorType({DT_INT32, DT_INT64}))
.INPUT(x2, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y1, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y2, TensorType({DT_INT32, DT_INT64}))
.OP_END_FACTORY_REG(BroadcastGradientArgs)
/**
*@brief Stops gradient computation. None is returned for the node where the gradient computation is stopped.
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: The input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator StopGradient.
*/
REG_OP(StopGradient)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OP_END_FACTORY_REG(StopGradient)
/**
*@brief Return a tensor with the same shape and contents as input. \n
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Identity.
*/
REG_OP(Identity)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OP_END_FACTORY_REG(Identity)
/**
*@brief Returns a list of tensors with the same shapes and contents as the input tensors. \n
*@par Inputs:
*x: A list of input tensors. It's a dynamic input \n
*@par Outputs:
*y: A list of Tensor objects, with the same length as the input tensor list.
It's a dynamic output. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator IdentityN.
*/
REG_OP(IdentityN)
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OP_END_FACTORY_REG(IdentityN)
/**
*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n
*@par Inputs:
*@li x: A tensor.
*@li axis: The dimension index at which to expand. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ExpandDims.
*/
REG_OP(ExpandDims)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32,
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.INPUT(axis, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32,
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OP_END_FACTORY_REG(ExpandDims)
/**
*@brief Inserts a dimension of 1 into a tensor's shape. Only the tensor shape is changed, without changing the data. \n
*@par Inputs:
*@li x: Original tensor.
*@li axis: List of ints. \n
*@par Outputs:
*y: Reshape tensor with same data as input. \n
*@par Third-party framework compatibility
*Compatible with the Onnx operator Unsqueeze.
*/
REG_OP(Unsqueeze)
.INPUT(x, TensorType({DT_FLOAT32, DT_INT32, DT_UINT8, DT_BOOL}))
.OUTPUT(y, TensorType({DT_FLOAT32, DT_INT32, DT_UINT8, DT_BOOL}))
.ATTR(axes, ListInt, {})
.OP_END_FACTORY_REG(Unsqueeze)
/**
*@brief Reshapes a tensor. Only the tensor shape is changed, without changing the data. \n
*@par Inputs:
*@li x: A tensor.
*@li shape: A tensor. Defines the shape of the output tensor. \n
*@par Attributes:
*@li axis: An optional int32 or int64. The first dimension to reshape. Defaults to "0".
*@li num_axes: An optional int32 or int64. The extent of the reshape. Defaults to "-1". \n
*@par Outputs:
*y: A tensor. \n
*@par Attention:
*This operator cannot be directly called by the acllopExecute API. \n
*@par Third-party framework compatibility
*@li Compatible with the TensorFlow operator Reshape.
*@li Compatible with the Caffe operator Reshape.
*/
REG_OP(Reshape)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32,
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.INPUT(shape, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32,
DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.ATTR(axis, Int, 0)
.ATTR(num_axes, Int, -1)
.OP_END_FACTORY_REG(Reshape)
/**
*@brief Removes dimensions of size 1 from the shape of a tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*axis: An optional list of int32 or int64. If not specified, squeezes all dimensions of size 1. If specified, only squeezes the dimensions listed. It is an error to squeeze a dimension that is not 1. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Squeeze.
*/
REG_OP(Squeeze)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(axis, ListInt, {})
.OP_END_FACTORY_REG(Squeeze)
/**
*@brief Returns an integer representing the rank of input tensor. The rank of a tensor is the number of indices required to uniquely select each element of the tensor, that is, the dimension size of the tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: A tensor. The rank of input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Rank.
*/
REG_OP(Rank)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_INT32}))
.OP_END_FACTORY_REG(Rank)
/**
*@brief Returns the size of a tensor, that is, an integer of the number of elements of the tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*out_type: An optional int32 or int64. The output data type. Defaults to "int32". \n
*@par Outputs:
*y: A tensor. The size of the input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Size.
*/
REG_OP(Size)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_INT32,DT_INT64}))
.ATTR(dtype, Int, DT_INT32)
.OP_END_FACTORY_REG(Size)
/**
*@brief Input data for other operators. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*index: Index of the input tensor.The data type must be int32 or int64.
Assume that net has three data nodes, one should be set 0, another should
be set 1, and the left should be set 2. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the Caffe operator Data.
*/
REG_OP(Data)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(index, Int, 0)
.OP_END_FACTORY_REG(Data)
/**
*@brief Inserts a placeholder for a tensor that will be always fed. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*@li peerIndex: An integer type. The index of the corresponding "end" node connected to.
*@li parentId: A string, used to check if the nodes are from the saved parent node.
*@li parentOpType: A string. Op type of the original node.
*@li anchorIndex: An integer, used to check if the node is from the saved anchor. \n
*@par Outputs:
*y: The created placeholder tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator PlaceHolder.
*/
REG_OP(PlaceHolder)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(peerIndex, Int, 0) // the index of the corresponding 'end' node it's connected to
.ATTR(parentId, String, "") // check if these node are from save parent node
.ATTR(parentOpType, String, "") // op type of original node
.ATTR(anchorIndex, Int, 0) // check if these node are from save anchor
.OP_END_FACTORY_REG(PlaceHolder)
/**
*@brief Inserts a placeholder with default value for a tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*@li dtype: data type of tensor.
*@li shape: tensor shape. \n
*@par Outputs:
*y: The created placeholder tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator PlaceholderWithDefault.
*/
REG_OP(PlaceholderWithDefault)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.REQUIRED_ATTR(shape, ListInt)
.OP_END_FACTORY_REG(PlaceholderWithDefault)
/**
*@brief Reads and returns the value of the input variable tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ReadVariableOp.
*/
REG_OP(ReadVariableOp)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.ATTR(dtype, Int, DT_INT32)
.OP_END_FACTORY_REG(ReadVariableOp)
/**
*@brief Mark outputs of one sub graph which partitioned by engine type.
*@par Inputs:
*x: A tensor. \n
*@par Outputs:
*y: A tensor. \n
*@par Attributes:
*@li peerIndex: The index of the corresponding 'placeholder' node it's connected to.
*@li parentOpType: Op type of original node.
*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(End)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(peerIndex, Int, 0)
.ATTR(parentOpType, String, "")
.OP_END_FACTORY_REG(End)
/**
*@brief Operations for writing summary data, for use in analysis and visualization.
*@par Inputs:
* One input:
*x: Collections of summary data.
*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(Summary)
.INPUT(x, TensorType::ALL())
.OP_END_FACTORY_REG(Summary)
/**
*@brief Returns the shape of a tensor. \n
*@par Inputs:
*x: A tensor. \n
*@par Attributes:
*dtype: An optional int32 or int64. The output data type. Defaults to int32. \n
*@par Outputs:
*y: A tensor. The shape of the input tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Size.
*/
REG_OP(Shape)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
.ATTR(dtype, Int, DT_INT32)
.OP_END_FACTORY_REG(Shape)
/**
*@brief Returns shape of tensors. \n
*@par Inputs:
*x: A list of input tensors. It's a dynamic input. \n
*@par Attributes:
*dtype: An optional int32 or int64. The output data type. Defaults to "int32". \n
*@par Outputs:
*y: A list of tensors with the same length as the input list of tensors.
It's a dynamic output. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ShapeN.
*/
REG_OP(ShapeN)
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.DYNAMIC_OUTPUT(y, TensorType({DT_INT32, DT_INT64}))
.ATTR(dtype, Int, DT_INT32)
.OP_END_FACTORY_REG(ShapeN)
/**
*@brief Creates a tensor with the given "shape" and "dtype". \n
*@par Inputs:
*shape: The shape of the output tensor. \n
*@par Attributes:
*@li dtype: Optional. The data type of the output tensor. Defaults to "int32".
*@li init: An optional bool. If true, initializes the returned tensor with the default value of "dtype". Defaults to "false". \n
*@par Outputs:
*y: A tensor. \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Empty.
*/
REG_OP(Empty)
.INPUT(shape, TensorType({DT_INT32}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.ATTR(dtype, Int, DT_INT32)
.ATTR(init, Bool, 0)
.OP_END_FACTORY_REG(Empty)
/**
*@brief Returns locations of nonzero / true values in a tensor. \n
*@par Inputs:
*Including:
*x: A Tensor. Must be one of the following types:
DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16,
DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL. \n
*@par Outputs:
*y: A Tensor of type DT_INT64. \n
*@attention Constraints:
*Where runs on the Ascend AI CPU, which delivers poor performance.\n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator Where.
*/
REG_OP(Where)
.INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \
DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL}))
.OUTPUT(y, TensorType({DT_INT64}))
.OP_END_FACTORY_REG(Where)
/**
*@brief Change the shape of output according to the attr outShape
*
*@par Inputs:
*x: A Tensor. \n
*@par Outputs:
*y: A Tensor. Has the same type as "x".It's required and the value should equal to output_num. \n
*@par Attributes:
*outShape: The shape of output will be inferred according to the attribute
*/
REG_OP(TransShape)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(outShape,ListInt ,{})
.OP_END_FACTORY_REG(TransShape);
/**
* @brief sort_v2.
* @par Inputs:
* @li x: An ND tensor of type float16.
* @par Attributes:
* @li axis: An optional int. The dimension to sort along. This value defaults to -1.
* @li descending: An optional bool. Controls the sorting order (ascending or descending). This value defaults to False.
* @par Outputs:
* @li y: An ND tensor of type float16.
* @attention Constraints:
* @li Axis should select the last dim.
* @li When the sorting data is less than 150K, it is recommended to use this tbe ops,
and the descending performance is better than the ascending.
* @li The upper limit of data on Ascend910 is 2000K.
*/
REG_OP(SortV2)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
.ATTR(axis, Int, -1)
.ATTR(descending, Bool, false)
.OP_END_FACTORY_REG(SortV2)
/**
* @brief Expand the input tensor to a compatible shape. \n
* @par Inputs:
* One inputs, including:
* @li x: A Tensor. Must be one of the following types:
* float16, float32, int32, int8 ,uint8. \n
* @li shape: A Tensor to specify the shape that the input tensor expanded to. \n
* @par Outputs:
* @li y: A Tensor. Has the same type as "x", and the shape specified by input and attr shape \n
* @par Third-party framework compatibility
* Compatible with the ONNX operator Expand.
*/
REG_OP(Expand)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
.INPUT(shape, TensorType({DT_INT16, DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
.OP_END_FACTORY_REG(Expand)
/**
* @brief Expand the input tensor to a compatible shape. \n
* @par Inputs:
* One inputs, including:
* @li x: A Tensor. Must be one of the following types:
* float16, float32, int32, int8 ,uint8. \n
* @par Attributes:
* @li shape: A required listInt to specify the shape that the input tensor expanded to. \n
* @par Outputs:
* @li y: A Tensor. Has the same type as "x", and the shape specified by input and attr shape \n
* @par Third-party framework compatibility
* Compatible with the ONNX operator Expand.
*/
REG_OP(ExpandD)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
.REQUIRED_ATTR(shape, ListInt)
.OP_END_FACTORY_REG(ExpandD)
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_INC_ARRAY_OPS_H_

+ 392
- 0
tests/st/framework/stub_op_proto/control_flow_ops.cc View File

@@ -0,0 +1,392 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file control_flow_ops.cpp
* \brief
*/
#include "control_flow_ops.h"
#include "./util/common_shape_fns.h"
#include "./util/error_util.h"
#include "util/util.h"
namespace ge {
namespace {
graphStatus MergeInferImpl(Operator& op) {
TensorDesc td = op.GetOutputDesc("value_index");
TensorDesc td_y = op.GetOutputDesc("y");
td.SetShape(ge::Shape());
td.SetDataType(DT_INT32);
auto ret = op.UpdateOutputDesc("value_index", td);
if (ret != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
// check N of "x" >= 1
size_t in_num = op.GetInputsSize();
if (in_num < 1) {
string reason = "inputs size[" + std::to_string(in_num) + "] must be greater than or equal to 1";
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "input", reason);
return GRAPH_FAILED;
} else if (in_num == 2) {
// Check is loop_merge, order of InferShape: Enter->Merge->NextIteration
// So when processing InferShape on Merge op, shape & datatype of NextIteration op is set as default.
// Therefore, shape & datatype of Merge op should be set as the Enter op.
auto x0_type = op.GetDynamicInputDesc("x", 0).GetDataType();
auto x0_dims = op.GetDynamicInputDesc("x", 0).GetShape().GetDims();
bool not_handle_flag0 = (x0_type == DT_FLOAT) && (x0_dims.size() == 0);
auto x1_type = op.GetDynamicInputDesc("x", 1).GetDataType();
auto x1_dims = op.GetDynamicInputDesc("x", 1).GetShape().GetDims();
bool not_handle_flag1 = (x1_type == DT_FLOAT) && (x1_dims.size() == 0);
if ((x0_type != x1_type) && (not_handle_flag0 || not_handle_flag1)) {
if (not_handle_flag0) {
td_y.SetShape(ge::Shape(x1_dims));
td_y.SetDataType(x1_type);
} else {
td_y.SetShape(ge::Shape(x0_dims));
td_y.SetDataType(x0_type);
}
(void)op.UpdateOutputDesc("y", td_y);
return GRAPH_SUCCESS;
}
}
// check "x" be same type
auto x0_type = op.GetDynamicInputDesc("x", 0).GetDataType();
for (size_t i = 1; i < op.GetInputsSize(); i++) {
auto xi_type = op.GetDynamicInputDesc("x", i).GetDataType();
if (xi_type != x0_type) {
string reason = "x[0]'s dtype[" + std::to_string(x0_type) + "] must be equal to x[" + std::to_string(i) +
"]'s dtype[" + std::to_string(xi_type) + "]";
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", reason);
return GRAPH_FAILED;
}
}
// infer "y" be unknown shape
auto x0_dims = op.GetDynamicInputDesc("x", 0).GetShape().GetDims();
bool x0_unknown = (x0_dims.size() == 1) && (x0_dims[0] == 0);
if (x0_unknown) {
Shape unknown_shape(ge::UNKNOWN_SHAPE);
td_y.SetShape(unknown_shape);
td_y.SetDataType(x0_type);
(void)op.UpdateOutputDesc("y", td_y);
return GRAPH_SUCCESS;
}
// find the input with the max size from all inputs, and set it's data type/shape to the output
std::map<int64_t, size_t> size_to_index;
for (size_t i = 0; i < op.GetInputsSize(); i++) {
auto xi_dims = op.GetDynamicInputDesc("x", i).GetShape().GetDims();
bool xi_unknown = (xi_dims.size() == 1) && (xi_dims[0] == 0);
if (xi_unknown) {
continue;
}
int64_t size = static_cast<int64_t>(GetSizeByDataType(op.GetDynamicInputDesc("x", i).GetDataType()));
if (size < 0) {
continue;
}
if (!xi_dims.empty()) {
for (auto& dim : xi_dims) {
if (dim <= 0) {
size = -1;
break;
}
if (size != 0 && INT64_MAX / size < dim) {
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dim", "the dim size is overflow");
return GRAPH_FAILED;
}
size *= dim;
}
if (size < 0) {
continue;
}
}
if (size_to_index.count(size) == 0) {
size_to_index[size] = i;
}
}
if (size_to_index.empty()) {
return GRAPH_FAILED;
}
auto index = size_to_index.rbegin()->second;
td_y.SetShape(ge::Shape(op.GetDynamicInputDesc("x", index).GetShape().GetDims()));
td_y.SetDataType(op.GetDynamicInputDesc("x", index).GetDataType());
(void)op.UpdateOutputDesc("y", td_y);
return GRAPH_SUCCESS;
}
graphStatus SwitchInferImpl(Operator& op) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto data_desc = op_desc->MutableInputDesc("data");
auto pred_desc = op_desc->MutableInputDesc("pred");
auto output_false_desc = op_desc->MutableOutputDesc("output_false");
auto output_true_desc = op_desc->MutableOutputDesc("output_true");
std::vector<std::pair<int64_t, int64_t>> data_range;
data_desc->GetShapeRange(data_range);
// check "pred" scalar type be bool
auto pred_dims = pred_desc->GetShape().GetDims();
if (pred_dims.size() != 0) {
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "pred dims", "pred should be a scalar");
return GRAPH_FAILED;
}
DataType pred_type = pred_desc->GetDataType();
if (pred_type != DT_BOOL) {
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "pred should be bool type");
return GRAPH_FAILED;
}
DataType data_type = data_desc->GetDataType();
auto data_dims = data_desc->GetShape().GetDims();
output_false_desc->SetShapeRange(data_range);
output_true_desc->SetShapeRange(data_range);
output_false_desc->SetShape(GeShape(data_dims));
output_false_desc->SetOriginShape(GeShape(data_dims));
output_true_desc->SetShape(GeShape(data_dims));
output_true_desc->SetOriginShape(GeShape(data_dims));
output_false_desc->SetDataType(data_type);
output_true_desc->SetDataType(data_type);
auto context = op.GetInferenceContext();
std::vector<std::vector<ShapeAndType>> in_shapes_and_types = context->GetInputHandleShapesAndTypes();
if ((!in_shapes_and_types.empty()) && (!in_shapes_and_types.at(0).empty())) {
ShapeAndType shape_and_type = in_shapes_and_types.at(0).at(0);
std::vector<ShapeAndType> grad_handle_shape_and_type;
grad_handle_shape_and_type.reserve(1);
grad_handle_shape_and_type.emplace_back(shape_and_type);
std::vector<std::vector<ShapeAndType>> shapes_and_types(2);
shapes_and_types[0] = grad_handle_shape_and_type;
shapes_and_types[1] = grad_handle_shape_and_type;
context->SetOutputHandleShapesAndTypes(shapes_and_types);
}
return GRAPH_SUCCESS;
}
graphStatus EnterInferImpl(Operator& op) {
auto op_desc = OpDescUtils::GetOpDescFromOperator(op);
auto input_desc_x = op_desc->MutableInputDesc("x");
auto output_desc_y = op_desc->MutableOutputDesc("y");
std::vector<std::pair<int64_t, int64_t>> x_range;
std::vector<std::pair<int64_t, int64_t>> y_range;
input_desc_x->GetShapeRange(x_range);
auto input_dims = input_desc_x->MutableShape().GetDims();
DataType input_type = input_desc_x->GetDataType();
output_desc_y->SetShape(ge::GeShape(input_dims));
output_desc_y->SetOriginShape(ge::GeShape(input_dims));
output_desc_y->SetDataType(input_type);
if (!x_range.empty()) {
output_desc_y->SetShapeRange(x_range);
}
auto context = op.GetInferenceContext();
std::vector<std::vector<ShapeAndType>> in_shapes_and_types = context->GetInputHandleShapesAndTypes();
if ((!in_shapes_and_types.empty()) && (!in_shapes_and_types.at(0).empty())) {
ShapeAndType shape_and_type = in_shapes_and_types.at(0).at(0);
std::vector<ShapeAndType> grad_handle_shape_and_type;
grad_handle_shape_and_type.reserve(1);
grad_handle_shape_and_type.emplace_back(shape_and_type);
std::vector<std::vector<ShapeAndType>> shapes_and_types(1);
shapes_and_types[0] = grad_handle_shape_and_type;
context->SetOutputHandleShapesAndTypes(shapes_and_types);
}
return GRAPH_SUCCESS;
}
graphStatus PassThroughInferImpl(Operator& op, const std::string& in_name, const std::string& out_name) {
auto input_dims = op.GetInputDesc(in_name).GetShape().GetDims();
DataType input_type = op.GetInputDesc(in_name).GetDataType();
TensorDesc tensordesc_output = op.GetOutputDesc(out_name);
tensordesc_output.SetShape(ge::Shape(input_dims));
tensordesc_output.SetDataType(input_type);
(void)op.UpdateOutputDesc(out_name, tensordesc_output);
return GRAPH_SUCCESS;
}
graphStatus LoopCondInferImpl(Operator& op) {
auto input_dims = op.GetInputDesc("x").GetShape().GetDims();
if (input_dims.size() != 0) {
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "x dims", "x should be a scalar");
return GRAPH_FAILED;
}
TensorDesc tensordesc_output = op.GetOutputDesc("y");
tensordesc_output.SetShape(ge::Shape(input_dims));
DataType input_type = op.GetInputDesc("x").GetDataType();
if (input_type != DT_BOOL) {
GeInfershapeErrReport(op.GetName(), op.GetOpType(), "dtype", "x should be bool type");
return GRAPH_FAILED;
}
tensordesc_output.SetDataType(input_type);
(void)op.UpdateOutputDesc("y", tensordesc_output);
return GRAPH_SUCCESS;
}
} // namespace
IMPLEMT_INFERFUNC(Merge, MergeInfer) {
return MergeInferImpl(op);
}
INFER_FUNC_REG(Merge, MergeInfer);
IMPLEMT_INFERFUNC(RefMerge, RefMergeInfer) {
return MergeInferImpl(op);
}
INFER_FUNC_REG(RefMerge, RefMergeInfer);
IMPLEMT_INFERFUNC(Switch, SwitchInfer) {
return SwitchInferImpl(op);
}
INFER_FUNC_REG(Switch, SwitchInfer);
IMPLEMT_INFERFUNC(RefSwitch, RefSwitchInfer) {
return SwitchInferImpl(op);
}
INFER_FUNC_REG(RefSwitch, RefSwitchInfer);
IMPLEMT_INFERFUNC(SwitchN, SwitchNInfer) {
return GRAPH_SUCCESS;
}
INFER_FUNC_REG(SwitchN, SwitchNInfer);
IMPLEMT_INFERFUNC(Enter, EnterInfer) {
return EnterInferImpl(op);
}
INFER_FUNC_REG(Enter, EnterInfer);
IMPLEMT_INFERFUNC(RefEnter, RefEnterInfer) {
return PassThroughInferImpl(op, "x", "y");
}
INFER_FUNC_REG(RefEnter, RefEnterInfer);
IMPLEMT_INFERFUNC(LoopCond, LoopCondInfer) {
return LoopCondInferImpl(op);
}
INFER_FUNC_REG(LoopCond, LoopCondInfer);
IMPLEMT_INFERFUNC(NextIteration, NextIterationInfer) {
return PassThroughInferImpl(op, "x", "y");
}
INFER_FUNC_REG(NextIteration, NextIterationInfer);
IMPLEMT_INFERFUNC(RefNextIteration, RefNextIterationInfer) {
return PassThroughInferImpl(op, "x", "y");
}
INFER_FUNC_REG(RefNextIteration, RefNextIterationInfer);
IMPLEMT_INFERFUNC(Exit, ExitInfer) {
return PassThroughInferImpl(op, "x", "y");
}
INFER_FUNC_REG(Exit, ExitInfer);
IMPLEMT_INFERFUNC(RefExit, RefExitInfer) {
return PassThroughInferImpl(op, "x", "y");
}
INFER_FUNC_REG(RefExit, RefExitInfer);
// ----------------MapIndex-------------------
IMPLEMT_VERIFIER(MapIndex, MapIndexVerify) {
return GRAPH_SUCCESS;
}
IMPLEMT_COMMON_INFERFUNC(MapIndexInferShape) {
OP_LOGI("MapIndex", "infer shape begin---");
auto x_shape = op.GetInputDesc("x").GetShape().GetDims();
if (x_shape.empty()) {
OP_LOGE(op.GetName().c_str(), "x_shape is empty");
OpsOneInputShapeErrReport(op.GetName().c_str(), "x", "x_shape is empty");
return GRAPH_FAILED;
}
int64_t x_length = x_shape[0];
auto data_seq_shape = op.GetInputDesc("data_seq").GetShape().GetDims();
if (data_seq_shape.empty()) {
OP_LOGE(op.GetName().c_str(), "data_seq_shape is empty");
OpsOneInputShapeErrReport(op.GetName().c_str(), "data_seq", "data_seq_shape is empty");
return GRAPH_FAILED;
}
int64_t data_seq_length = data_seq_shape[0];
if (x_length > 8 || x_length == 0) {
OP_LOGE(op.GetName().c_str(), "the length of x should be less than or equal to 8");
OpsOneInputShapeErrReport(op.GetName().c_str(), "x", "the length of x should be less than or equal to 8 and not 0");
return GRAPH_FAILED;
}
if (data_seq_length % x_length != 0) {
OP_LOGE(op.GetName().c_str(), "the length of data_seq must be multiple of the length of x");
OpsTwoInputShapeErrReport(op.GetName().c_str(), "data_seq", "x",
"the length of data_seq must be multiple of the length of x");
return GRAPH_FAILED;
}
if (data_seq_length / x_length > 100) {
OP_LOGE(op.GetName().c_str(), "data_seq_length / x_length should be be less than or equal to 100");
OpsTwoInputShapeErrReport(op.GetName().c_str(), "data_seq", "x",
"data_seq_length / x_length should be be less than or equal to 100");
return GRAPH_FAILED;
}
auto level_index_shape = op.GetInputDesc("level_index").GetShape().GetDims();
if (!level_index_shape.empty()) {
int64_t level_index_length = level_index_shape[0];
if (level_index_length != (data_seq_length / x_length)) {
OP_LOGE(op.GetName().c_str(),
"the length of level_index must be equal to "
"the length of data_seq divided by the length of x");
OpsOneInputShapeErrReport(op.GetName().c_str(), "level_index",
"the length of level_index must be equal to "
"the length of data_seq divided by the length of x");
return GRAPH_FAILED;
}
}
TensorDesc y_desc = op.GetOutputDesc("y");
y_desc.SetShape(ge::Shape());
y_desc.SetDataType(ge::DT_INT32);
(void)op.UpdateOutputDesc("y", y_desc);
return GRAPH_SUCCESS;
}
COMMON_INFER_FUNC_REG(MapIndex, MapIndexInferShape);
VERIFY_FUNC_REG(MapIndex, MapIndexVerify);
} // namespace ge

+ 407
- 0
tests/st/framework/stub_op_proto/control_flow_ops.h View File

@@ -0,0 +1,407 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file control_flow_ops.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_
#include "graph/operator_reg.h"
#include "graph/operator.h"
namespace ge {
/**
*@brief Forwards the value of an available tensor from input "x" to output "y".
* Merge waits for at least one of the input tensors to become available.
* It is usually combined with Switch to implement branching.
* Merge forwards the first tensor to become available to output "y",
* and sets "value_index" the index of the tensor in inputs . \n
*@par Inputs:
*x: The input tensors, one of which will become available.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n
*@par Outputs:
*@li y: The available tensor. Has the same type as "x".
*@li value_index: A scalar of type int32, for the index of the chosen input
* tensor . \n
*@see Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Merge.
*/
REG_OP(Merge)
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(value_index, TensorType({DT_INT32}))
.OP_END_FACTORY_REG(Merge)
/**
*@brief Forwards the value of an available tensor from input "x" to output "y".
* Merge waits for at least one of the input tensors to become available.
* It is usually combined with Switch to implement branching.
* Merge forwards the first tensor to become available to output "y",
* and sets "value_index" the index of the tensor in inputs . \n
*@par Inputs:
*x: The input tensors, one of which will become available.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n
*@par Outputs:
*@li y: The available tensor. Has the same type as "x".
*@li value_index: A scalar of type int32, for the index of the chosen input
* tensor . \n
*@see Switch() | Merge()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefMerge.
*/
REG_OP(RefMerge)
.DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(value_index, TensorType({DT_INT32}))
.OP_END_FACTORY_REG(RefMerge)
/**
*@brief Forwards "data" to the output port determined by "pred".
* If "pred" is "true", the data input is forwarded to "output_true".
* Otherwise, the data is forwarded to "output_false" . \n
*@par Inputs:
*@li data: The tensor to be forwarded. \ n
* Must be one of the following types: float16, float32, float64,
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred: A boolean scalar. The output port that will receive data . \n
*@par Outputs:
*@li output_false: If "pred" is "false", data will be forwarded to this output.
* Has the same type as "data".
*@li output_true: If "pred" is "true", data will be forwarded to this output.
* Has the same type as "data" . \n
*@see Merge()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Switch.
*/
REG_OP(Switch)
.INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.INPUT(pred, TensorType({DT_BOOL}))
.OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OP_END_FACTORY_REG(Switch)
/**
*@brief Forwards "data" to the output port determined by "pred".
* If "pred" is "true", the data input is forwarded to "output_true".
* Otherwise, the data is forwarded to "output_false" . \n
*@par Inputs:
*@li data: The ref tensor to be forwarded.
* Must be one of the following types: float16, float32, float64,
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred: A boolean scalar. The output port that will receive data . \n
*@par Outputs:
*@li output_false: If "pred" is "false", data will be forwarded to this output.
* Has the same type as "data".
*@li output_true: If "pred" is "true", data will be forwarded to this output.
* Has the same type as "data" . \n
*@see Merge() | Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefSwitch.
*/
REG_OP(RefSwitch)
.INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.INPUT(pred, TensorType({DT_BOOL}))
.OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OP_END_FACTORY_REG(RefSwitch)
/**
*@brief Forwards "data" to the output port determined by "pred_value" . \n
*@par Inputs:
*@li data: The tensor to be forwarded. \ n
* Must be one of the following types: float16, float32, float64,
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred_value: A int64 tensor which determines the output port that will receive data . \n
*@par Outputs:
*output: The output tensors, one of which will become available.
* Has the same type as "data".
*/
REG_OP(SwitchN)
.INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.INPUT(pred_value, TensorType({DT_INT64}))
.DYNAMIC_OUTPUT(output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OP_END_FACTORY_REG(SwitchN)
/**
*@brief Creates or finds a child frame, and makes "x" available to the child
* frame. This op is used together with Exit to create loops in the graph.
* The Executor uses the unique "frame_name" to identify frames.
* If "is_constant" is "true", output "y" is a constant in the child
* frame; otherwise it may be changed in the child frame . \n
*@par Inputs:
*x: The tensor to be made available to the child frame.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Attributes:
*@li frame_name: A required string. The name of the child frame.
*@li is_constant: A required bool. If true, the output is constant in
* the child frame . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@see Exit()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Enter.
*/
REG_OP(Enter)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.REQUIRED_ATTR(frame_name, String)
.REQUIRED_ATTR(is_constant, Bool)
.OP_END_FACTORY_REG(Enter)
/**
*@brief Creates or finds a child frame, and makes "x" available to the child
* frame. This op is used together with Exit to create loops in the graph.
* The Executor uses the unique "frame_name" to identify frames.
* If "is_constant" is "true", output "y" is a constant in the child
* frame; otherwise it may be changed in the child frame . \n
*@par Inputs:
*x: The tensor to be made available to the child frame.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Attributes:
*@li frame_name: A required string. The name of the child frame.
*@li is_constant: A required bool. If true, the output is constant in
* the child frame . \n
*@par Outputs:
*y: A tensor. Has the same type as "x" . \n
*@see Exit() | Enter()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefEnter.
*/
REG_OP(RefEnter)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.REQUIRED_ATTR(frame_name, String)
.REQUIRED_ATTR(is_constant, Bool)
.OP_END_FACTORY_REG(RefEnter)
/**
*@brief Forwards the input to the output. This op represents the loop
* termination condition . \n
*@par Inputs:
*x: A boolean scalar. The condition of the Switch op . \n
*@par Outputs:
*y: The tensor "x" . \n
*@see Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator LoopCond.
*/
REG_OP(LoopCond)
.INPUT(x, TensorType({DT_BOOL}))
.OUTPUT(y, TensorType({DT_BOOL}))
.OP_END_FACTORY_REG(LoopCond)
/**
*@brief Makes the input available to the next iteration . \n
*@par Inputs:
*x: The tensor to be made available to the next iteration.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator NextIteration.
*/
REG_OP(NextIteration)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OP_END_FACTORY_REG(NextIteration)
/**
*@brief Makes the input available to the next iteration . \n
*@par Inputs:
*x: The tensor to be made available to the next iteration.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Outputs:
*y: A tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefNextIteration.
*/
REG_OP(RefNextIteration)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OP_END_FACTORY_REG(RefNextIteration)
/**
*@brief Exits the current frame to its parent frame . \n
*@par Inputs:
*x: The tensor to be made available to the parent frame.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@see Enter()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Exit.
*/
REG_OP(Exit)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OP_END_FACTORY_REG(Exit)
/**
*@brief Exits the current frame to its parent frame . \n
*@par Inputs:
*x: The tensor to be made available to the parent frame.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Outputs:
*y: A tensor. Has the same type as "x" . \n
*@see Enter() | Exit()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefExit.
*/
REG_OP(RefExit)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
DT_UINT64, DT_BOOL}))
.OP_END_FACTORY_REG(RefExit)
/**
*@brief Only useful as a placeholder for control edges.
* It is similar to a no-op that always produces a live control output
* even when some control inputs are dead . \n
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator ControlTrigger.
*/
REG_OP(ControlTrigger)
.OP_END_FACTORY_REG(ControlTrigger)
/**
*@brief Returns index of shape in the map.
*@par Inputs:
* Three inputs, including:
*@li x: One dimensional tensore of type int32, specifying queried shape, max size is 8.
*@li data_seq: One dimensional tensore of type int32, specifying the mapped table is queried.
*@li level_index: One dimensional tensore of type int32, specifying secondary index. \n
*@par Outputs:
*@li y: A Tensor with shape [batch, 8], of type int32, specifying index of shape in the map.
*@par Third-party framework compatibility
* It is a custom operator. It has no corresponding operator in Caffe.
*/
REG_OP(MapIndex)
.INPUT(x, TensorType({DT_INT32}))
.INPUT(data_seq, TensorType({DT_INT32}))
.OPTIONAL_INPUT(level_index, TensorType({DT_INT32}))
.OUTPUT(y, TensorType({DT_INT32}))
.OP_END_FACTORY_REG(MapIndex)
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_

+ 4633
- 0
tests/st/framework/stub_op_proto/elewise_calculation_ops.cc
File diff suppressed because it is too large
View File


+ 3788
- 0
tests/st/framework/stub_op_proto/elewise_calculation_ops.h
File diff suppressed because it is too large
View File


+ 234
- 0
tests/st/framework/stub_op_proto/util/array_ops_shape_fns.cc View File

@@ -0,0 +1,234 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file array_ops_shape_fns.cpp
* \brief
*/
#include "array_ops_shape_fns.h"
#include "graph/types.h"
#include "op_log.h"
#include "error_util.h"
#include "common_shape_fns.h"
#include "axis_util.h"
namespace ge {
static graphStatus PadKnown(Operator& op, const Tensor& paddings_tensor, const int64_t input_dim_num) {
TensorDesc paddings_tensor_desc = paddings_tensor.GetTensorDesc();
DataType data_type = paddings_tensor_desc.GetDataType();
std::vector<int64_t> data;
// every dim has 2 element
int64_t element_num = input_dim_num * 2;
data.reserve(element_num);
if (data_type == DT_INT32) {
const int32_t* paddings_data = reinterpret_cast<const int32_t*>(paddings_tensor.GetData());
CHECK(paddings_tensor.GetSize() / sizeof(int32_t) < element_num,
OP_LOGE(op.GetName().c_str(), "invalid padding data."), return GRAPH_FAILED);
for (int64_t i = 0; i < element_num; ++i) {
data.push_back(static_cast<int64_t>(paddings_data[i]));
}
} else if (data_type == DT_INT64) {
const int64_t* paddings_data = reinterpret_cast<const int64_t*>(paddings_tensor.GetData());
CHECK(paddings_tensor.GetSize() / sizeof(int64_t) < element_num,
OP_LOGE(op.GetName().c_str(), "invalid padding data."), return GRAPH_FAILED);
for (int64_t i = 0; i < element_num; ++i) {
data.push_back(paddings_data[i]);
}
} else {
string err_msg = ConcatString("paddings data type invalid, ", "should be DT_INT32 or DT_INT64");
InferShapeOtherErrReport(op.GetName(), err_msg);
OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str());
return GRAPH_FAILED;
}
auto dims = op.GetInputDesc(0).GetShape().GetDims();
std::vector<int64_t> output_dims(input_dim_num, UNKNOWN_DIM);
if (dims != UNKNOWN_SHAPE) {
output_dims.assign(dims.begin(), dims.end());
}
for (size_t i = 0; i < data.size(); i += 2) {
if ((data[i] < 0) || (data[i + 1] < 0)) {
std::string err_msg = ConcatString("paddings", DebugString(data), " must be non-negative");
InferShapeOtherErrReport(op.GetName(), err_msg);
OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str());
return GRAPH_FAILED;
}
graphStatus status = Add(output_dims[i / 2], data[i] + data[i + 1], output_dims[i / 2]);
if (status != GRAPH_SUCCESS) {
std::string err_msg = ConcatString("the sum input[0] shape", DebugString(dims), " and input[1] value",
DebugString(data), " must be non-negative");
OP_LOGE(op.GetName().c_str(), "%s", err_msg.c_str());
return GRAPH_FAILED;
}
}
auto output_desc = op.GetOutputDesc("y");
output_desc.SetShape(Shape(output_dims));
return op.UpdateOutputDesc("y", output_desc);
}
graphStatus PadShapeFn(Operator& op) {
Shape paddings;
int64_t input_dim_num;
graphStatus status = WithRank(op.GetInputDesc(1), 2, paddings, op.GetName().c_str());
if (status != GRAPH_SUCCESS) {
ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), "2D");
return GRAPH_FAILED;
}
status = WithValue(paddings.GetDim(1), 2, input_dim_num, op.GetName().c_str());
if (status != GRAPH_SUCCESS) {
ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()),
ConcatString(2, " of dim[1]"));
return GRAPH_FAILED;
}
Shape input;
int64_t dim0 = paddings.GetDim(0);
if (dim0 != UNKNOWN_DIM) {
status = WithRank(op.GetInputDesc(0), dim0, input, op.GetName().c_str());
if (status != GRAPH_SUCCESS) {
ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(dim0, "D"));
return GRAPH_FAILED;
}
} else if (op.GetInputDesc(0).GetShape().GetDim(0) != 0) {
status = WithValue(dim0, op.GetInputDesc(0).GetShape().GetDimNum(), input_dim_num, op.GetName().c_str());
if (status != GRAPH_SUCCESS) {
ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(dim0, "D"));
return GRAPH_FAILED;
}
}
TensorDesc output_desc = op.GetOutputDesc("y");
Tensor paddings_tensor;
status = op.GetInputConstData("paddings", paddings_tensor);
if (status != GRAPH_SUCCESS) {
if (dim0 != UNKNOWN_DIM) {
std::vector<int64_t> output_shape(dim0, UNKNOWN_DIM);
output_desc.SetShape(Shape(output_shape));
} else {
output_desc.SetShape(Shape(UNKNOWN_SHAPE));
}
return op.UpdateOutputDesc("y", output_desc);
}
input_dim_num = paddings_tensor.GetTensorDesc().GetShape().GetDim(0);
status = WithRank(op.GetInputDesc(0), input_dim_num, input, op.GetName().c_str());
if (status == GRAPH_FAILED) {
OP_LOGE(op.GetName().c_str(), "WithRank fail");
return GRAPH_FAILED;
}
status = WithValue(dim0, input_dim_num, dim0, op.GetName().c_str());
if (status == GRAPH_FAILED) {
OP_LOGE(op.GetName().c_str(), "WithValue fail");
return GRAPH_FAILED;
}
return PadKnown(op, paddings_tensor, input_dim_num);
}
static graphStatus CalcPadGradOutDims(const Shape& input_shape, const Tensor& paddings_tensor,
std::vector<int64_t>& output_dims, const char* op_name) {
graphStatus status;
size_t input_rank = input_shape.GetDimNum();
if (output_dims.size() < input_rank) {
return GRAPH_FAILED;
}
DataType padding_type = paddings_tensor.GetTensorDesc().GetDataType();
if (padding_type == DT_INT32) {
const int32_t* paddings_data = reinterpret_cast<const int32_t*>(paddings_tensor.GetData());
CHECK(paddings_tensor.GetSize() / sizeof(int32_t) < input_rank,
OP_LOGE(op_name, "invalid padding data."), return GRAPH_FAILED);
for (size_t i = 0; i < input_rank; ++i) {
const int64_t pad0 = static_cast<int64_t>(paddings_data[2 * i]);
const int64_t pad1 = static_cast<int64_t>(paddings_data[(2 * i) + 1]);
if ((pad0 < 0) || (pad1 < 0)) {
OP_LOGE(op_name, "Paddings must be non-negative, pad0= %lld, pad1=%lld.", pad0, pad1);
return GRAPH_FAILED;
}
status = Subtract(input_shape.GetDim(i), pad0 + pad1, output_dims[i], op_name);
if (status != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
}
} else if (padding_type == DT_INT64) {
const int64_t* paddings_data = reinterpret_cast<const int64_t*>(paddings_tensor.GetData());
CHECK(paddings_tensor.GetSize() / sizeof(int64_t) < input_rank,
OP_LOGE(op_name, "invalid padding data."), return GRAPH_FAILED);
for (size_t i = 0; i < input_rank; ++i) {
const int64_t pad0 = paddings_data[2 * i];
const int64_t pad1 = paddings_data[(2 * i) + 1];
if ((pad0 < 0) || (pad1 < 0)) {
OP_LOGE(op_name, "Paddings must be non-negative, pad0=%lld, pad1=%lld.", pad0, pad1);
return GRAPH_FAILED;
}
status = Subtract(input_shape.GetDim(i), pad0 + pad1, output_dims[i], op_name);
if (status != GRAPH_SUCCESS) {
return GRAPH_FAILED;
}
}
} else {
OP_LOGE(op_name, "Data type invalid, should be DT_INT32 or DT_INT64");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}
graphStatus PadGradShapeFn(Operator& op) {
Shape paddings;
graphStatus status = WithRank(op.GetInputDesc(1), 2, paddings, op.GetName().c_str());
if (status != GRAPH_SUCCESS) {
ShapeErrReport(1, op.GetName(), DebugString(op.GetInputDesc(1).GetShape().GetDims()), "2D");
return GRAPH_FAILED;
}
int64_t input_rank = paddings.GetDim(0);
TensorDesc output_desc = op.GetOutputDesc("y");
output_desc.SetDataType(op.GetInputDesc(0).GetDataType());
if (input_rank == UNKNOWN_DIM) {
OP_LOGE(op.GetName().c_str(), "paddings inputShape of 0 dims is unknown, set out shape unknown.");
output_desc.SetShape(Shape(UNKNOWN_SHAPE));
return op.UpdateOutputDesc("y", output_desc);
}
Shape input_shape;
if (WithRank(op.GetInputDesc(0), input_rank, input_shape, op.GetName().c_str()) != GRAPH_SUCCESS) {
ShapeErrReport(0, op.GetName(), DebugString(op.GetInputDesc(0).GetShape().GetDims()), ConcatString(input_rank));
return GRAPH_FAILED;
}
Shape check_shape({input_rank, 2});
if (Merge(paddings, check_shape, paddings, op.GetName().c_str())) {
string err_msg = ConcatString("merge 1th input shape", DebugString(paddings.GetDims()), " and shape",
DebugString(check_shape.GetDims()), " failed");
InferShapeOtherErrReport(op.GetName(), err_msg);
OP_LOGE(op.GetName().c_str(), "Input dimension mismatch, inputRank=%lld.", input_rank);
return GRAPH_FAILED;
}
Tensor paddings_tensor;
if (op.GetInputConstData("paddings", paddings_tensor) != GRAPH_SUCCESS) {
std::vector<int64_t> unknow_dim_vec(input_rank, UNKNOWN_DIM);
OP_LOGE(op.GetName().c_str(), "Get paddings input tensor fail, set outPut shape unknown.");
output_desc.SetShape(Shape(unknow_dim_vec));
return op.UpdateOutputDesc("y", output_desc);
}
std::vector<int64_t> output_dims(input_rank);
auto result = CalcPadGradOutDims(input_shape, paddings_tensor, output_dims, op.GetName().c_str());
if (result != GRAPH_SUCCESS) {
string err_msg = ConcatString("calculate out dims failed,", "please check the validity of input and attribute");
InferShapeOtherErrReport(op.GetName(), err_msg);
OP_LOGE(op.GetName().c_str(), "Calculation PadGrad out dimensions failed.");
return GRAPH_FAILED;
}
output_desc.SetShape(Shape(output_dims));
return op.UpdateOutputDesc("y", output_desc);
}
} // namespace ge

+ 42
- 0
tests/st/framework/stub_op_proto/util/array_ops_shape_fns.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file array_ops_shape_fns.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_
#include "graph/operator.h"
namespace ge {
/* *
* infer pad op shape
* @param op Operator which need to infershape
* @return status whether infershape success
*/
graphStatus PadShapeFn(Operator& op);
/* *
* infer pad grad op shape
* @param op Operator which need to infershape
* @return status whether infershape success
*/
graphStatus PadGradShapeFn(Operator& op);
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ARRAY_OPS_SHAPE_FNS_H_

+ 195
- 0
tests/st/framework/stub_op_proto/util/axis_util.cc View File

@@ -0,0 +1,195 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file axis_util.cpp
* \brief get the axis value
*/
#include "axis_util.h"
#include "framework/omg/omg_inner_types.h"
#include "framework/common/types.h"
namespace ge {
AxisUtil::AxisUtil() {
getAxisValueFuncMap = {{FORMAT_NCHW, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNCHW)},
{FORMAT_NHWC, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNHWC)},
{FORMAT_NC1HWC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNC1HWC0)},
{FORMAT_HWCN, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByHWCN)},
{FORMAT_ND, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByND)},
{FORMAT_C1HWNCoC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByC1HWNCoC0)}};
}
int64_t DivisionCeiling(int64_t dividend, int64_t divisor) {
if (divisor == 0) {
return 0;
} else {
return (dividend + divisor - 1) / divisor;
}
}
bool AxisUtil::GetAxisValueByOriginFormat(const Format& format, const vector<int64_t>& dimVec, const uint32_t& c0,
vector<int64_t>& axisValue, vector<int64_t>& ndValue) {
auto iterGetAxisFunc = getAxisValueFuncMap.find(format);
if (iterGetAxisFunc == getAxisValueFuncMap.end()) {
LOG_INFO("Can not get axis value of old format %u!", format);
return false;
}
GetAxisValueInfoByFormatPtr getAxisFunc = iterGetAxisFunc->second;
CHECK_NOTNULL(getAxisFunc);
return (*getAxisFunc)(dimVec, c0, axisValue, ndValue);
}
bool AxisUtil::HasAxisValueFunc(const Format& format) {
auto iterGetAxisFunc = getAxisValueFuncMap.find(format);
if (iterGetAxisFunc == getAxisValueFuncMap.end()) {
LOG_INFO("Can not get axis value of format %u!", format);
return false;
}
return true;
}
bool AxisUtil::CheckParams(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue,
vector<int64_t>& ndValue) {
ndValue = originalDimVec;
auto dimSize = originalDimVec.size();
if (dimSize < ge::DIM_DEFAULT_SIZE) {
/* Before this funcion, we should call function PadDimensionTo4. */
LOG_INFO("Dimension size %zu is invalid.", dimSize);
return false;
}
if (c0 == 0) {
LOG_ERROR("[ERROR]c0 is zero!");
return false;
}
return true;
}
bool AxisUtil::GetAxisValueByND(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue,
vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true);
ndValue = originalDimVec;
/* To differentiate the input datatype of int8 and others */
axisValue[AXIS_C0] = c0;
if (originalDimVec.size() == NCHW_DIMENSION_NUM) {
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N];
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C];
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W];
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0);
axisValue[AXIS_Co] = c0;
}
return true;
}
bool AxisUtil::GetAxisValueByNCHW(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue,
vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true);
/* C0 Must be set for case ND or 2D-NCHW to NZ */
axisValue[AXIS_C0] = c0;
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"),
return false);
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N];
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C];
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W];
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0);
axisValue[AXIS_Co] = c0;
return true;
}
bool AxisUtil::GetAxisValueByNHWC(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue,
vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true);
/* C0 Must be set for case ND or 2D-NHWC to NZ */
axisValue[AXIS_C0] = c0;
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"),
return false);
axisValue[AXIS_N] = originalDimVec[AXIS_NHWC_DIM_N];
axisValue[AXIS_C] = originalDimVec[AXIS_NHWC_DIM_C];
axisValue[AXIS_H] = originalDimVec[AXIS_NHWC_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_NHWC_DIM_W];
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NHWC_DIM_C], (int64_t)c0);
axisValue[AXIS_Co] = c0;
return true;
}
bool AxisUtil::GetAxisValueByNC1HWC0(const vector<int64_t>& originalDimVec, const uint32_t& c0,
vector<int64_t>& axisValue, vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true);
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"),
return false);
auto dimSize = originalDimVec.size();
if (dimSize == ge::DIM_DEFAULT_SIZE + 1) {
axisValue[AXIS_C1] = originalDimVec[AXIS_NC1HWC0_DIM_C1];
axisValue[AXIS_C0] = originalDimVec[AXIS_NC1HWC0_DIM_C0];
axisValue[AXIS_C] = axisValue[AXIS_C1] * axisValue[AXIS_C0];
} else {
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0);
axisValue[AXIS_C0] = c0;
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C];
}
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N];
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W];
return true;
}
bool AxisUtil::GetAxisValueByHWCN(const vector<int64_t>& originalDimVec, const uint32_t& c0, vector<int64_t>& axisValue,
vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true);
/* C0 Must be set for case ND or 2D-NHWC to NZ */
axisValue[AXIS_C0] = c0;
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"),
return false);
axisValue[AXIS_N] = originalDimVec[AXIS_HWCN_DIM_N];
axisValue[AXIS_C] = originalDimVec[AXIS_HWCN_DIM_C];
axisValue[AXIS_H] = originalDimVec[AXIS_HWCN_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_HWCN_DIM_W];
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_HWCN_DIM_C], (int64_t)c0);
axisValue[AXIS_Co] = c0;
return true;
}
bool AxisUtil::GetAxisValueByC1HWNCoC0(const vector<int64_t>& originalDimVec, const uint32_t& c0,
vector<int64_t>& axisValue, vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), LOG_INFO("Original dim vector is empty!"), return true);
/* C0 Must be set for case ND or 2D-NHWC to NZ */
axisValue[AXIS_C0] = c0;
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, LOG_ERROR("[ERROR]Parameter is invalid!"),
return false);
axisValue[AXIS_N] = originalDimVec[AXIS_C1HWNCoC0_DIM_N];
axisValue[AXIS_C] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1] * c0;
axisValue[AXIS_H] = originalDimVec[AXIS_C1HWNCoC0_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_C1HWNCoC0_DIM_W];
axisValue[AXIS_C1] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1];
axisValue[AXIS_Co] = originalDimVec[AXIS_C1HWNCoC0_DIM_Co];
return true;
}
}; // namespace ge

+ 144
- 0
tests/st/framework/stub_op_proto/util/axis_util.h View File

@@ -0,0 +1,144 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file axis_util.h
* \brief get the axis value
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_
#include <memory.h>
#include <functional>
#include <vector>
#include "framework/omg/omg_inner_types.h"
#include "operator.h"
#include "graph/operator_reg.h"
#include "op_log.h"
#define LOG_ERROR(format, args...) printf(format, ##args)
#define LOG_INFO(format, args...) printf(format, ##args)
namespace ge {
const uint32_t NCHW_DIMENSION_NUM = 4;
const int32_t AXIS_NCHW_DIM_N = 0;
const int32_t AXIS_NCHW_DIM_C = 1;
const int32_t AXIS_NCHW_DIM_H = 2;
const int32_t AXIS_NCHW_DIM_W = 3;
const int32_t AXIS_NHWC_DIM_N = 0;
const int32_t AXIS_NHWC_DIM_H = 1;
const int32_t AXIS_NHWC_DIM_W = 2;
const int32_t AXIS_NHWC_DIM_C = 3;
const int32_t AXIS_NC1HWC0_DIM_N = 0;
const int32_t AXIS_NC1HWC0_DIM_C1 = 1;
const int32_t AXIS_NC1HWC0_DIM_C0 = 4;
const int32_t AXIS_NC1HWC0_DIM_H = 2;
const int32_t AXIS_NC1HWC0_DIM_W = 3;
const int32_t AXIS_HWCN_DIM_H = 0;
const int32_t AXIS_HWCN_DIM_W = 1;
const int32_t AXIS_HWCN_DIM_C = 2;
const int32_t AXIS_HWCN_DIM_N = 3;
const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0;
const int32_t AXIS_C1HWNCoC0_DIM_H = 1;
const int32_t AXIS_C1HWNCoC0_DIM_W = 2;
const int32_t AXIS_C1HWNCoC0_DIM_N = 3;
const int32_t AXIS_C1HWNCoC0_DIM_Co = 4;
const int32_t AXIS_C1HWNCoC0_DIM_C0 = 5;
#define CHECK_NOTNULL(val) \
do { \
if ((val) == nullptr) { \
LOG_ERROR("[ERROR]Parameter[%s] must not be null.", #val); \
return false; \
} \
} while (0)
#define CHECK(cond, log_func, return_expr) \
do { \
if (cond) { \
log_func; \
return_expr; \
} \
} while (0)
enum AxisValueType {
AXIS_N = 0,
AXIS_C = 1,
AXIS_H = 2,
AXIS_W = 3,
AXIS_C1 = 4,
AXIS_C0 = 5,
AXIS_Co = 6,
AXIS_D = 7,
AXIS_BOTTOM = 8
};
int64_t DivisionCeiling(int64_t dividend, int64_t divisor);
/* Axis value is arranged as {N,C,H,W,C1,C0,...} */
/* The first parameter is old shape's dimension,
* second is c0 and third is axis value. */
using GetAxisValueInfoByFormat =
std::function<bool(const std::vector<int64_t>&, const uint32_t&, std::vector<int64_t>&, std::vector<int64_t>&)>;
using GetAxisValueInfoByFormatPtr = std::shared_ptr<GetAxisValueInfoByFormat>;
class AxisUtil {
public:
AxisUtil();
~AxisUtil(){};
bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector<int64_t>& dimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
bool HasAxisValueFunc(const ge::Format& format);
private:
static bool CheckParams(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByNCHW(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByNHWC(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByNC1HWC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByFz(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByHWCN(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByND(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByC1HWNCoC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
/* map of GetAxisValueInfoByFormat, get axis value by different original
* formats. */
std::map<ge::Format, GetAxisValueInfoByFormatPtr> getAxisValueFuncMap;
};
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_AXIS_UTIL_H_

+ 1038
- 0
tests/st/framework/stub_op_proto/util/common_shape_fns.cc
File diff suppressed because it is too large
View File


+ 417
- 0
tests/st/framework/stub_op_proto/util/common_shape_fns.h View File

@@ -0,0 +1,417 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file common_shape_fns.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_
#include <string>
#include <vector>
#include "graph/tensor.h"
#include "graph/operator.h"
#include "graph/op_desc.h"
#include "graph/ge_tensor.h"
#include "error_code.h"
namespace ge {
/**
* Check whether Shape's rank is at least rank
* @param tensor Input tensor
* @param rank expect val of Shape
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRankAtLeast(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name);
/**
* Check whether Shape's rank is at least rank
* @param tensor Input tensor
* @param rank expect val of Shape
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRankAtLeast(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape);
/**
* Check whether Shape's rank is equal to rank
* @param tensor Input tensor
* @param rank expect val of Shape
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRank(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name);
/**
* Check whether Shape's rank is equal to rank
* @param tensor Input tensor
* @param rank expect val of Shape
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape);
/**
* Check whether Shape's rank is equal to rank
* @param tensor Input tensor
* @param rank expect val of Shape
* @param out Output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRank(const GeTensorDescPtr& tensorDesc, int64_t rank, Shape& out_shape);
/**
* Check whether dim is equal to value
* @param dim Input dim
* @param value expect val of dim
* @param out Output dim
* @return status whether Dim is equal to value
*/
graphStatus WithValue(int64_t dim, int64_t value, int64_t& out, const char* op_name);
/**
* Merge two dims of Shape
* @param dim0 first dim val
* @param dim1 second dim val
* @param out merged dim val
* @return status whether this operation success
*/
graphStatus Merge(int64_t dim1, int64_t dim2, int64_t& out);
/**
* Merge two shapes
* @param s0 first shape val
* @param s1 second shape val
* @param out merged shape val
* @return status whether this operation success
*/
graphStatus Merge(const Shape& s0, const Shape& s1, Shape& out, const char* op_name);
/**
* Merge two shapes
* @param s0 first Geshape val
* @param s1 second Geshape val
* @param out merged Geshape val
* @return status whether this operation success
*/
graphStatus Merge(const GeShape& s0, const GeShape& s1, GeShape& out, const char* op_name);
/**
* Replace one dim in a given shape
* @param s original shape
* @param dim_index_in dim index
* @param new_dim new dim value
* @param out new shape
* @return status whether this operation success
*/
graphStatus ReplaceDim(const Shape& s, int64_t dim_index_in, int64_t new_dim, Shape& out, const char* op_name);
/**
* Replace one dim in a given shape
* @param s original shape
* @param dim_index_in dim index
* @param new_dim new dim value
* @param out new shape
* @return status whether this operation success
*/
graphStatus ReplaceDim(const GeShape& s, int64_t dim_index_in, int64_t new_dim, GeShape& out, const char* op_name);
/**
* Check if it satisfies 0 <= index < limit
* @param index first input
* @param limit second input
* @return status whether this operation success
*/
template <typename Ta, typename Tb>
bool FastBoundsCheck(const Ta index, const Tb limit);
/**
* Add two dims
* @param dim0 first dim val
* @param dim1 second dim val
* @param out sum dim val
* @return status whether this operation success
*/
graphStatus Add(int64_t dim1, int64_t dim2, int64_t& out);
/**
* Subtract two dims
* @param dim0 first dim val
* @param dim1 second dim val
* @param out Subtract dim val
* @return status whether this operation success
*/
graphStatus Subtract(int64_t dim1, int64_t dim2, int64_t& out, const char* op_name);
/**
* Get SubShape according to start end index and step size stride
* @param s input Shape
* @param start sub start index
* @param end sub end index
* @param stride sub step size
* @param out sub shape output
* @return status whether this operation success
*/
graphStatus SubShape(const Shape& s, int64_t start, int64_t end, int64_t stride, Shape& out, const char* op_name);
/**
* Get SubShape according to start end index and step size stride
* @param s input Shape
* @param start sub start index
* @param end sub end index
* @param stride sub step size
* @param out sub shape output
* @return status whether this operation success
*/
graphStatus SubShape(const GeShape& s, size_t start, size_t end, size_t stride, GeShape& out);
/**
* Get SubShape according to start end index and step size stride
* @param s input Shape
* @param start sub start index
* @param end sub end index
* @param stride sub step size
* @param out sub shape output
* @return status whether this operation success
*/
graphStatus SubShape(const GeShape& s, int64_t start, int64_t end, int64_t stride, GeShape& out, const char* op_name);
/**
* Concatenate two shape
* @param s1 first shape
* @param s2 second shape
* @param out concatenated shape
* @return status whether this operation success
*/
graphStatus Concatenate(const Shape& s1, const Shape& s2, Shape& out);
/**
* Concatenate two shape
* @param s1 first shape
* @param s2 second shape
* @param out concatenated shape
* @return status whether this operation success
*/
graphStatus Concatenate(const GeShape& s1, const GeShape& s2, GeShape& out);
/**
* Gen matrix shape according d1 and d2
* @param dim1 first dim val
* @param dim2 first dim val
* @param out matrix shape
* @return status whether this operation success
*/
graphStatus Matrix(int64_t dim1, int64_t dim2, Shape& out);
/**
* Gen vector shape according d
* @param dim dim val
* @param out vector shape
* @return status whether this operation success
*/
graphStatus Vector(int64_t dim, Shape& out);
/**
* Make shape from shape tensor
* @param tensor shape tensor
* @param out shape
* @return status whether this operation success
*/
graphStatus MakeShapeFromShapeTensor(const Tensor& tensor, Shape& out, const char* op_name);
/**
* Make shape from shape tensor
* @param op Operator
* @param dst_name const string &
* @param out GeShape
* @param op_name const char *
* @return status whether this operation success
*/
graphStatus MakeShapeFromShapeTensor(Operator& op, const string& dst_name, GeShape& out, const char* op_name);
/**
* Make dim from scalar tensor
* @param tensor shape tensor
* @param out shape
* @return status whether this operation success
*/
graphStatus MakeDimForScalarInput(const Tensor& tensor, int64_t& out, const char* op_name);
/**
* Check whether Shape's rank is at most rank
* @param tensor input tensor
* @param rank expect val of Shape
* @param out output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRankAtMost(const TensorDesc& tensor, int64_t rank, Shape& out, const char* op_name);
/**
* Check whether Shape's rank is at most rank
* @param tensor input tensor
* @param rank expect val of Shape
* @param out output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus WithRankAtMost(const GeTensorDescPtr& tensorDesc, int64_t rank, GeShape& out_shape);
/**
* make a empty dim shape
* @param out output Shape
* @return status whether Shape's condition Satisfied
*/
graphStatus Scalar(Shape& out);
/**
* set input_name shape to output_name shape
* @param op Operator which need to infershape
* @param input_name input name of Operator
* @param output_name ouput name of Operator
* @return status whether infershape success
*/
graphStatus UnchangedShape(Operator& op, const string input_name, const string output_name);
/**
* Devide dim
* @param dividend
* @param divisor
* @param evenlyDivisible if to be divisible
* @param out dims
* @return status whether this operation success
*/
graphStatus Divide(const int64_t dividend, const int64_t divisor, const bool evenlyDivisible, int64_t& out,
const char* op_name);
/**
* check shape fully defined or not
* @param shape Shape is checked
* @return whether shape is fully defined
*/
bool ShapeFullDefined(const Shape& shape);
/**
* check shape fully defined or not
* @param shape Shape is checked
* @return whether shape is fully defined
*/
bool ShapeFullyDefined(const GeShape& shape);
/**
* check shape known or not
* @param shape Shape is checked
* @return whether rank is known
*/
bool RankKnown(const Shape& shape);
/**
* check ge_shape known or not
* @param shape GeShape is checked
* @return whether rank is known
*/
bool RankKnown(const GeShape& shape);
/**
* make a unknown shape with rank
* @return unknown shape
*/
Shape UnknownShapeOfRank(int64_t rank);
/**
* check dim value known or not
* @param shape which Shape need check dim value
* @param dimIndex the index of dim
* @return whether dim value is known
*/
bool ValueKnown(const Shape& shape, const size_t& dim_index);
/**
* Validates the 3 component tensors of a sparse tensor
* have the proper shapes.
* @param sparse indices shape
* @param sparse values shape
* @param sparse shape
* @return status whether this operation success
*/
graphStatus ValidateSparseTensor(const TensorDesc& indices, const TensorDesc& values, const TensorDesc& shape,
const char* op_name);
/**
* DecodeWavShapeFn, infereshape funtion of DecodeWav op
* @param op Operator
* @return status whether Shape's condition Satisfied
*/
graphStatus DecodeWavShapeFn(Operator& op);
/**
* EncodeWavShapeFn, infereshape funtion of EncodeWav op
* @param op Operator
* @return status whether Shape's condition Satisfied
*/
graphStatus EncodeWavShapeFn(Operator& op);
/**
* EncodeWavShapeFn, infereshape funtion of EncodeWav op
* @param op Operator
* @return status whether Shape's condition Satisfied
*/
graphStatus EncodeWavShapeFn(Operator& op);
/**
* Infereshape funtion of SparseSegmentReduction op
* @param op Operator
* @return status whether Shape's condition Satisfied
*/
graphStatus SparseSegmentReductionShapeFn(Operator& op);
/**
* Infereshape funtion of SparseSegmentReductionGrad op
* @param op Operator
* @return status whether Shape's condition Satisfied
*/
graphStatus SparseSegmentReductionGradShapeFn(Operator& op);
/**
* Validates variable resource handle
* @param op Operator
* @param shape_and_type ShapeAndType vector
* @return status whether this operation success
*/
graphStatus ValidateVariableResourceHandle(Operator& op, std::vector<ShapeAndType>& shape_and_type);
/**
* Fill op_desc with input shape
* @param op_desc Operator desc ptr
* @param shape input tensor shape
* @param shape input tensor datatype
*/
void FillOpDesc(GeTensorDescPtr& op_desc, const GeShape& shape, const DataType& data_type = DT_FLOAT);
/**
* InferShapeErrorReport info
* @param op_name Operator name
* @param op_type Operator type
* @param value Operator value
* @param reason error reason
*/
void InferShapeErrorReport(const std::string& op_name, const std::string& op_type,
const std::string& value, const std::string& reason);
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_COMMON_SHAPE_FNS_H_

+ 60
- 0
tests/st/framework/stub_op_proto/util/error_code.h View File

@@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file error_code.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_
namespace ge {
// error code for report purpose.
// 30000~34999 for aicpu engine error
// and 35000~39999 for infershape error of aicpu op
enum ViewErrorCode {
INVALID_INFER_SHAPE = 14001,
INVALID_INPUT_SHAPE = 35000,
INVALID_ATTR_VALUE = 35001,
INVALID_ATTR_SIZE = 35002,
OTHER_ERROR = 35003,
INVALID_CONV_ATTR_VALUE = 50029,
INVALID_CONV_SET_ATTR = 50057,
INVALID_CONV_SHAPE = 50058,
INVALID_MISS_INPUT = 70001,
INVALID_INPUT_FORMAT = 70002,
INVALID_INPUT_DTYPE = 70003,
INVALID_INPUT_TYPE = 70004,
INVALID_GET_ATTR = 70005,
INVALID_SET_ATTR = 70006,
INVALID_OPS_ATTR_VALUE = 70007,
FAILED_UPDATE_OP = 70008,
INVALID_SHAPE = 70009,
INVALID_SHAPE_SIZE = 70010,
INVALID_SHAPE_DIM = 70011,
INVALID_BROADCAST_SHAPE = 70012,
INVALID_TWO_INPUT_DTYPE = 70013,
INVALID_AIPP_ERROR = 70014,
INVALID_ONE_INPUT_SHAPE = 70015,
INVALID_TWO_INPUT_SHAPE = 70016,
INVALID_ONE_OUTPUT_SHAPE = 70017,
FAILED_GET_COMPILIE_PARAMS = 70018,
};
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_CODE_H_

+ 318
- 0
tests/st/framework/stub_op_proto/util/error_util.cc View File

@@ -0,0 +1,318 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file error_util.cpp
* \brief
*/
#include <map>
#include "common/util/error_manager/error_manager.h"
#include "error_util.h"
#include "error_code.h"
#include "op_log.h"
using namespace std;
using namespace ge;
namespace ge {
inline static std::string GetViewErrorCodeStr(ge::ViewErrorCode errCode) {
return "E" + std::to_string(errCode);
}
void ShapeErrReport(uint32_t index, const std::string& opname, const std::string& wrong_shape,
const std::string& correct_shape) {
map<string, string> err_map;
err_map["index"] = std::to_string(index);
err_map["opname"] = opname;
err_map["wrong_shape"] = wrong_shape;
err_map["correct_shape"] = correct_shape;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_SHAPE);
(void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void AttrValueErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_value,
const std::string& correct_value) {
map<string, string> err_map;
err_map["attrname"] = attrName;
err_map["opname"] = opname;
err_map["wrong_value"] = wrong_value;
err_map["correct_value"] = correct_value;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ATTR_VALUE);
(void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void AttrSizeErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_size,
const std::string& correct_size) {
map<string, string> err_map;
err_map["attrname"] = attrName;
err_map["opname"] = opname;
err_map["wrong_size"] = wrong_size;
err_map["correct_size"] = correct_size;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ATTR_SIZE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void InferShapeOtherErrReport(const std::string& opname, const std::string& err_msg) {
map<string, string> err_map;
err_map["opname"] = opname;
err_map["err_msg"] = err_msg;
string report_error_code = GetViewErrorCodeStr(ViewErrorCode::OTHER_ERROR);
(void)ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsMissInputErrReport(const std::string& op_name, const std::string& param_name) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_MISS_INPUT);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsInputFormatErrReport(const std::string& op_name, const std::string& param_name,
const std::string& expected_format_list, const std::string& data_format) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
err_map["expected_format_list"] = expected_format_list;
err_map["format"] = data_format;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_FORMAT);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsInputDtypeErrReport(const std::string& op_name, const std::string& param_name,
const std::string& expected_data_type_list, const std::string& data_type) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
err_map["expected_data_type_list"] = expected_data_type_list;
err_map["data_type"] = data_type;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_DTYPE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsInputTypeErrReport(const std::string& op_name, const std::string& param_name, const std::string& param_type,
const std::string& actual_type) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
err_map["param_type"] = param_type;
err_map["actual_type"] = actual_type;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INPUT_TYPE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsGetAttrErrReport(const std::string& op_name, const std::string& param_name) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_GET_ATTR);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsSetAttrErrReport(const std::string& op_name, const std::string& param_name) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SET_ATTR);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& excepted_value,
const std::string& input_value) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
err_map["excepted_value"] = excepted_value;
err_map["input_value"] = input_value;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_OPS_ATTR_VALUE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsOPUpdateErrReport(const std::string& op_name, const std::string& param_name) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::FAILED_UPDATE_OP);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsInputShapeErrReport(const std::string& op_name, const std::string& rule_desc, const std::string& param_name,
const std::string& param_value) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["rule_desc"] = rule_desc;
err_map["param_name"] = param_name;
err_map["param_value"] = param_value;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsOneInputShapeErrReport(const std::string& op_name, const std::string& param_name,
const std::string& error_detail) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
err_map["error_detail"] = error_detail;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ONE_INPUT_SHAPE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsTwoInputShapeErrReport(const std::string& op_name, const std::string& param_name1,
const std::string& param_name2, const std::string& error_detail) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name1"] = param_name1;
err_map["param_name2"] = param_name2;
err_map["error_detail"] = error_detail;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_TWO_INPUT_SHAPE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsOneOutputShapeErrReport(const std::string& op_name, const std::string& param_name,
const std::string& error_detail) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
err_map["error_detail"] = error_detail;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_ONE_OUTPUT_SHAPE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsGetCompileParamsErrReport(const std::string& op_name, const std::string& param_name) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::FAILED_GET_COMPILIE_PARAMS);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsInputShapeSizeErrReport(const std::string& op_name, const std::string& input_name, const std::string& max_value,
const std::string& real_value) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["input_name"] = input_name;
err_map["max_value"] = max_value;
err_map["real_value"] = real_value;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE_SIZE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsInputShapeDimErrReport(const std::string& op_name, const std::string& param_name, const std::string& max_value,
const std::string& min_value, const std::string& real_value) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
err_map["max_value"] = max_value;
err_map["min_value"] = min_value;
err_map["real_value"] = real_value;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_SHAPE_DIM);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsInputShapeBroadcastErrReport(const std::string& op_name, const std::string& input1_name,
const std::string& input2_name, const std::string& input1_shape,
const std::string& input2_shape) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["input1_name"] = input1_name;
err_map["input2_name"] = input2_name;
err_map["input1_shape"] = input1_shape;
err_map["input2_shape"] = input2_shape;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_BROADCAST_SHAPE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void TbeInputDataTypeErrReport(const std::string& op_name, const std::string& param_name,
const std::string& expected_dtype_list, const std::string& dtype) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
err_map["expected_dtype_list"] = expected_dtype_list;
err_map["dtype"] = dtype;
std::string report_error_code = "E50034";
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsTwoInputDtypeErrReport(const std::string& op_name, const std::string& input1_name,
const std::string& input2_name, const std::string& input1_dtype,
const std::string& input2_dtype) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["input1_name"] = input1_name;
err_map["input2_name"] = input2_name;
err_map["input1_dtype"] = input1_dtype;
err_map["input2_dtype"] = input2_dtype;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_TWO_INPUT_DTYPE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsAippErrReport(const std::string& aipp_output_H, const std::string& aipp_output_W, const std::string& data_H,
const std::string& data_W) {
map<string, string> err_map;
err_map["aipp_output_H"] = aipp_output_H;
err_map["aipp_output_W"] = aipp_output_W;
err_map["data_H"] = data_H;
err_map["data_W"] = data_W;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_AIPP_ERROR);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsConvAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& expected_value,
const std::string& input_value) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param_name"] = param_name;
err_map["expected_value"] = expected_value;
err_map["input_value"] = input_value;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_ATTR_VALUE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsConvSetAttrErrReport(const std::string& op_name, const std::string& param1_name,
const std::string& param2_name) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["param1_name"] = param1_name;
err_map["param2_name"] = param2_name;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_SET_ATTR);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void OpsConvShapeErrReport(const std::string& op_name, const std::string& description) {
map<string, string> err_map;
err_map["op_name"] = op_name;
err_map["description"] = description;
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_CONV_SHAPE);
ErrorManager::GetInstance().ReportErrMessage(report_error_code, err_map);
}
void GeInfershapeErrReport(const std::string& op_name, const std::string& op_type, const std::string& value,
const std::string& reason) {
std::string report_error_code = GetViewErrorCodeStr(ViewErrorCode::INVALID_INFER_SHAPE);
ErrorManager::GetInstance().ATCReportErrMessage(report_error_code, {"opname", "optype", "value", "reason"},
{op_name, op_type, value, reason});
}
void CommonRuntimeErrLog(const std::string& opname, const std::string& description){
map<string, string> err_map;
err_map["op_name"] = opname;
err_map["description"] = description;
OP_LOGE(opname.c_str(), description);
(void)ErrorManager::GetInstance().ReportErrMessage("E50058", err_map);
}
} // namespace ge

+ 184
- 0
tests/st/framework/stub_op_proto/util/error_util.h View File

@@ -0,0 +1,184 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file error_util.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_
#include <sstream>
#include <string>
#include <vector>
#include "operator.h"
namespace ge {
/*
* get debug string of vector
* param[in] v vector
* return vector's debug string
*/
template <typename T>
std::string DebugString(const std::vector<T>& v) {
std::ostringstream oss;
oss << "[";
if (v.size() > 0) {
for (size_t i = 0; i < v.size() - 1; ++i) {
oss << v[i] << ", ";
}
oss << v[v.size() - 1];
}
oss << "]";
return oss.str();
}
/*
* str cat util function
* param[in] params need concat to string
* return concatted string
*/
template <typename T>
std::string ConcatString(T arg) {
std::ostringstream oss;
oss << arg;
return oss.str();
}
template <typename T, typename... Ts>
std::string ConcatString(T arg, Ts... arg_left) {
std::ostringstream oss;
oss << arg;
oss << ConcatString(arg_left...);
return oss.str();
}
/*
* report input shape error of infer shape
* param[in] index the index of input
* param[in] opname op name
* param[in] wrong_shape wrong input shape
* param[in] correct_shape correct input shape
* return void
*/
void ShapeErrReport(uint32_t index, const std::string& opname, const std::string& wrong_shape,
const std::string& correct_shape);
/*
* report attr value error of infer shape
* param[in] attrname the attr name
* param[in] opname op name
* param[in] wrong_value wrong attr value
* param[in] correct_value correct attr value
* return void
*/
void AttrValueErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_value,
const std::string& correct_value);
/*
* report attr size error of infer shape
* param[in] attrname the attr name
* param[in] opname op name
* param[in] wrong_size wrong attr size
* param[in] correct_size correct attr size
* return void
*/
void AttrSizeErrReport(const std::string& attrName, const std::string& opname, const std::string& wrong_size,
const std::string& correct_size);
/*
* report common error of infer shape
* param[in] opname op name
* param[in] err_msg error message
* return void
*/
void InferShapeOtherErrReport(const std::string& opname, const std::string& err_msg);
void OpsMissInputErrReport(const std::string& op_name, const std::string& param_name);
void OpsInputFormatErrReport(const std::string& op_name, const std::string& param_name,
const std::string& expected_format_list, const std::string& data_format);
void OpsInputDtypeErrReport(const std::string& op_name, const std::string& param_name,
const std::string& expected_data_type_list, const std::string& data_type);
void OpsInputTypeErrReport(const std::string& op_name, const std::string& param_name, const std::string& param_type,
const std::string& actual_type);
void OpsGetAttrErrReport(const std::string& op_name, const std::string& param_name);
void OpsSetAttrErrReport(const std::string& op_name, const std::string& param_name);
void OpsAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& excepted_value,
const std::string& input_value);
void OpsOPUpdateErrReport(const std::string& op_name, const std::string& param_name);
void OpsInputShapeErrReport(const std::string& op_name, const std::string& rule_desc, const std::string& param_name,
const std::string& param_value);
void OpsOneInputShapeErrReport(const std::string& op_name, const std::string& param_name,
const std::string& error_detail);
void OpsTwoInputShapeErrReport(const std::string& op_name, const std::string& param_name1,
const std::string& param_name2, const std::string& error_detail);
void OpsOneOutputShapeErrReport(const std::string& op_name, const std::string& param_name,
const std::string& error_detail);
void OpsGetCompileParamsErrReport(const std::string& op_name, const std::string& param_name);
void OpsInputShapeSizeErrReport(const std::string& op_name, const std::string& input_name, const std::string& max_value,
const std::string& real_value);
void OpsInputShapeDimErrReport(const std::string& op_name, const std::string& param_name, const std::string& max_value,
const std::string& min_value, const std::string& real_value);
void OpsInputShapeBroadcastErrReport(const std::string& op_name, const std::string& input1_name,
const std::string& input2_name, const std::string& input1_shape,
const std::string& input2_shape);
void TbeInputDataTypeErrReport(const std::string& op_name, const std::string& param_name,
const std::string& expected_dtype_list, const std::string& dtype);
void OpsTwoInputDtypeErrReport(const std::string& op_name, const std::string& input1_name,
const std::string& input2_name, const std::string& input1_dtype,
const std::string& input2_dtype);
void OpsAippErrReport(const std::string& aipp_output_H, const std::string& aipp_output_W, const std::string& data_H,
const std::string& data_W);
void OpsConvAttrValueErrReport(const std::string& op_name, const std::string& param_name, const std::string& expected_value,
const std::string& input_value);
void OpsConvSetAttrErrReport(const std::string& op_name, const std::string& param1_name,
const std::string& param2_name);
void OpsConvShapeErrReport(const std::string& op_name, const std::string& description);
void GeInfershapeErrReport(const std::string& op_name, const std::string& op_type, const std::string& value,
const std::string& reason);
/*
* log common runtime error
* param[in] opname op name
* param[in] error description
* return void
*/
void CommonRuntimeErrLog(const std::string& opname, const std::string& description);
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_ERROR_UTIL_H_

+ 73
- 0
tests/st/framework/stub_op_proto/util/op_common_util.h View File

@@ -0,0 +1,73 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file op_common_util.h
* \brief common util for op, in this file only original type or class in C++ allowed
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_
#include <set>
#include <string>
#include <vector>
#include <iostream>
#include <sstream>
template <typename T1, typename T2>
std::ostream& operator<<(std::ostream& os, const std::pair<T1, T2>& values) {
os << "[" << values.first << ", " << values.second << "]";
return os;
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
os << "[";
for (const auto& item : values) {
os << item << ", ";
}
os << "]";
return os;
}
namespace ops {
template<typename T>
std::string to_string(const std::vector<T> &items) {
std::ostringstream oss;
oss << "[";
for (const auto &item: items) {
oss << item << ", ";
}
oss << "]";
return oss.str();
}
template<typename T>
std::string to_string(const std::set<T> &items) {
std::ostringstream oss;
oss << "[";
for (const auto &item: items) {
oss << item << ", ";
}
oss << "]";
return oss.str();
}
} // namespace ops
#endif //OPS_BUILT_IN_OP_PROTO_UTIL_OP_COMMON_UTIL_H_

+ 89
- 0
tests/st/framework/stub_op_proto/util/op_log.h View File

@@ -0,0 +1,89 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file op_log.h
* \brief
*/
#ifndef GE_OP_LOG_H
#define GE_OP_LOG_H
#if !defined( __ANDROID__) && !defined(ANDROID)
#include "toolchain/slog.h"
#else
#include <utils/Log.h>
#endif
#define OPPROTO_SUBMOD_NAME "OP_PROTO"
#if !defined( __ANDROID__) && !defined(ANDROID)
#define OP_LOGI(opname, ...) D_OP_LOGI(opname, __VA_ARGS__)
#define OP_LOGW(opname, ...) D_OP_LOGW(opname, __VA_ARGS__)
#define OP_LOGE(opname, ...) D_OP_LOGE(opname, __VA_ARGS__)
#define OP_LOGD(opname, ...) D_OP_LOGD(opname, __VA_ARGS__)
#define GE_OP_LOGI(opname, ...) GE_D_OP_LOGI(opname, __VA_ARGS__)
#define GE_OP_LOGW(opname, ...) GE_D_OP_LOGW(opname, __VA_ARGS__)
#define GE_OP_LOGE(opname, ...) GE_D_OP_LOGE(opname, __VA_ARGS__)
#define GE_OP_LOGD(opname, ...) GE_D_OP_LOGD(opname, __VA_ARGS__)
#define FUSION_PASS_LOGI(...) D_FUSION_PASS_LOGI(__VA_ARGS__)
#define FUSION_PASS_LOGW(...) D_FUSION_PASS_LOGW(__VA_ARGS__)
#define FUSION_PASS_LOGE(...) D_FUSION_PASS_LOGE(__VA_ARGS__)
#define FUSION_PASS_LOGD(...) D_FUSION_PASS_LOGD(__VA_ARGS__)
#else
#define OP_LOGI(opname, ...)
#define OP_LOGW(opname, ...)
#define OP_LOGE(opname, ...)
#define OP_LOGD(opname, ...)
#define FUSION_PASS_LOGI(...)
#define FUSION_PASS_LOGW(...)
#define FUSION_PASS_LOGE(...)
#define FUSION_PASS_LOGD(...)
#endif
#if !defined( __ANDROID__) && !defined(ANDROID)
#define D_OP_LOGI(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define D_OP_LOGW(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define D_OP_LOGE(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define D_OP_LOGD(opname, fmt, ...) DlogSub(TBE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define GE_D_OP_LOGI(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define GE_D_OP_LOGW(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define GE_D_OP_LOGE(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define GE_D_OP_LOGD(opname, fmt, ...) DlogSub(GE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d OpName:[%s] "#fmt, __FUNCTION__, __LINE__, opname, ##__VA_ARGS__)
#define D_FUSION_PASS_LOGI(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_INFO, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#define D_FUSION_PASS_LOGW(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_WARN, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#define D_FUSION_PASS_LOGE(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_ERROR, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#define D_FUSION_PASS_LOGD(fmt, ...) DlogSub(FE, OPPROTO_SUBMOD_NAME, DLOG_DEBUG, " %s:%d "#fmt, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#else
#define D_OP_LOGI(opname, fmt, ...)
#define D_OP_LOGW(opname, fmt, ...)
#define D_OP_LOGE(opname, fmt, ...)
#define D_OP_LOGD(opname, fmt, ...)
#define D_FUSION_PASS_LOGI(fmt, ...)
#define D_FUSION_PASS_LOGW(fmt, ...)
#define D_FUSION_PASS_LOGE(fmt, ...)
#define D_FUSION_PASS_LOGD(fmt, ...)
#endif
#define OP_CHECK(condition, log_func, do_expr) \
static_assert(std::is_same<bool, std::decay<decltype(condition)>::type>::value, "condition should be bool"); \
do { \
if (condition) { \
log_func; \
do_expr; \
} \
} while (0)
#endif //GE_OP_LOG_H

+ 258
- 0
tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.cc View File

@@ -0,0 +1,258 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file transfer_shape_according_to_format.cpp
* \brief set shape according to original format and current format
*/
#include "transfer_shape_according_to_format.h"
#include "framework/omg/omg_inner_types.h"
namespace ge {
ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat(void) {
getNewShapeFuncMap = {
{ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)},
{ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)},
{ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)},
{ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)},
{ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)},
{ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)},
{ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}};
mapOfDtypeAndC0 = {
{ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32},
{ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16},
{ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16},
{ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}};
}
bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true);
/* axisValue is initialized as a size 6 vector. */
std::vector<int64_t> newDimVec;
newDimVec.push_back(axisValue[AXIS_N]);
newDimVec.push_back(axisValue[AXIS_C]);
newDimVec.push_back(axisValue[AXIS_H]);
newDimVec.push_back(axisValue[AXIS_W]);
newShape = ge::GeShape(newDimVec);
return true;
}
bool ShapeTransferAccordingToFormat::GetNHWCShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true);
/* axisValue is initialized as a size 6 vector. */
std::vector<int64_t> newDimVec;
newDimVec.push_back(axisValue[AXIS_N]);
newDimVec.push_back(axisValue[AXIS_H]);
newDimVec.push_back(axisValue[AXIS_W]);
newDimVec.push_back(axisValue[AXIS_C]);
newShape = ge::GeShape(newDimVec);
return true;
}
bool ShapeTransferAccordingToFormat::GetNC1HWC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true);
/* axisValue is initialized as a size 6 vector. */
std::vector<int64_t> newDimVec;
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) {
CHECK(axisValue.size() <= AXIS_C0, LOG_INFO("AxisValue is not correct!"), return true);
newDimVec.push_back(axisValue[AXIS_N]);
newDimVec.push_back(axisValue[AXIS_C1]);
newDimVec.push_back(axisValue[AXIS_H]);
newDimVec.push_back(axisValue[AXIS_W]);
newDimVec.push_back(axisValue[AXIS_C0]);
newShape = ge::GeShape(newDimVec);
} else {
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true);
newDimVec.push_back(axisValue[AXIS_N]);
newDimVec.push_back(axisValue[AXIS_C]);
newDimVec.push_back(axisValue[AXIS_H]);
newDimVec.push_back(axisValue[AXIS_W]);
newShape = ge::GeShape(newDimVec);
}
return true;
}
bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), LOG_INFO("AxisValue is empty!"), return true);
/* axisValue is initialized as a size 6 vector. */
std::vector<int64_t> newDimVec;
if (ndValue.size() == SIZE_OF_CN) {
CHECK(axisValue.size() <= AXIS_C0, LOG_INFO("AxisValue is not correct!"), return true);
auto sizeOfOriginalVec = ndValue.size();
std::vector<int64_t> newDimVec = ndValue;
/* sizeOfOriginalVec - 1 mean the last value of original vec
* sizeOfOriginalVec - 2 mean the second last value of original vec */
newDimVec[sizeOfOriginalVec - MINUS_VALUE_ONE] =
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16);
newDimVec[sizeOfOriginalVec - MINUS_VALUE_TWO] =
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]);
newDimVec.push_back(SHAPE_NUMBER_16);
newDimVec.push_back(axisValue[AXIS_C0]);
newShape = ge::GeShape(newDimVec);
} else {
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) {
CHECK(axisValue.size() <= AXIS_C1, LOG_INFO("AxisValue is not correct!"), return true);
int64_t hwc1 = axisValue[AXIS_C1] * axisValue[AXIS_H] * axisValue[AXIS_W];
newDimVec.push_back(hwc1);
newDimVec.push_back(DivisionCeiling(axisValue[AXIS_N], NI));
newDimVec.push_back(NI);
newDimVec.push_back(axisValue[AXIS_C0]);
newShape = ge::GeShape(newDimVec);
} else {
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true);
newDimVec.push_back(axisValue[AXIS_N]);
newDimVec.push_back(axisValue[AXIS_C]);
newDimVec.push_back(axisValue[AXIS_H]);
newDimVec.push_back(axisValue[AXIS_W]);
newShape = ge::GeShape(newDimVec);
}
}
return true;
}
bool ShapeTransferAccordingToFormat::GetHWCNShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.size() <= AXIS_W, LOG_INFO("AxisValue is not correct!"), return true);
/* axisValue is initialized as a size 6 vector. */
std::vector<int64_t> newDimVec;
newDimVec.push_back(axisValue[AXIS_H]);
newDimVec.push_back(axisValue[AXIS_W]);
newDimVec.push_back(axisValue[AXIS_C]);
newDimVec.push_back(axisValue[AXIS_N]);
newShape = ge::GeShape(newDimVec);
return true;
}
bool ShapeTransferAccordingToFormat::GetC1HWNCoC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.size() <= AXIS_Co, LOG_INFO("AxisValue is not correct!"), return true);
/* axisValue is initialized as a size 6 vector. */
std::vector<int64_t> newDimVec;
newDimVec.push_back(axisValue[AXIS_C1]);
newDimVec.push_back(axisValue[AXIS_H]);
newDimVec.push_back(axisValue[AXIS_W]);
newDimVec.push_back(axisValue[AXIS_N]);
newDimVec.push_back(axisValue[AXIS_Co]);
newDimVec.push_back(axisValue[AXIS_C0]);
newShape = ge::GeShape(newDimVec);
return true;
}
bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(ndValue.empty(), LOG_INFO("ndValue is empty!"), return true);
CHECK(axisValue.empty() || axisValue.size() <= AXIS_C0,
LOG_INFO("AxisValue is empty or its size %zu <= AXIS_C0[%u]", axisValue.size(), AXIS_C0), return true);
uint32_t sizeOfOriginalVec = ndValue.size();
if (sizeOfOriginalVec < MINIMUM_NZ_SHAPE_DIM_NUM) {
LOG_INFO("ndValue's dim num is less than 2!");
return true;
}
/* axisValue is initialized as a size 6 vector. */
std::vector<int64_t> newDimVec = ndValue;
/* sizeOfOriginalVec - 1 mean the last value of original vec
* sizeOfOriginalVec - 2 mean the second last value of original vec */
newDimVec[sizeOfOriginalVec - MINUS_VALUE_ONE] =
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16);
newDimVec[sizeOfOriginalVec - MINUS_VALUE_TWO] =
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]);
newDimVec.push_back(SHAPE_NUMBER_16);
newDimVec.push_back(axisValue[AXIS_C0]);
newShape = ge::GeShape(newDimVec);
return true;
}
bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& shapeAndFormatInfo, int64_t* c) {
/* The default new shape is old shape */
shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape;
if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) {
LOG_ERROR("Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat);
return false;
}
if (shapeAndFormatInfo.currentDataType >= ge::DT_UNDEFINED) {
LOG_ERROR("currentDataType %u is invalid!", shapeAndFormatInfo.currentDataType);
return false;
}
AxisUtil* axisutil_object = new AxisUtil();
if (!axisutil_object->HasAxisValueFunc(shapeAndFormatInfo.oldFormat)) {
delete axisutil_object;
return true;
}
auto iterGetNewShapeFunc = getNewShapeFuncMap.find(shapeAndFormatInfo.newFormat);
if (iterGetNewShapeFunc == getNewShapeFuncMap.end()) {
LOG_INFO("Can not get new shape of new format %u!", shapeAndFormatInfo.newFormat);
delete axisutil_object;
return true;
}
LOG_INFO("Original format %u, new format %u", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat);
GetNewShapeByAxisValueAndFormatPtr getNewShapeFunc = iterGetNewShapeFunc->second;
CHECK_NOTNULL(getNewShapeFunc);
std::vector<int64_t> axisValue;
for (uint32_t i = 0; i < AXIS_BOTTOM; i++) {
axisValue.push_back(1);
}
std::vector<int64_t> ndValue;
uint32_t c0;
if (mapOfDtypeAndC0.empty()) {
c0 = SHAPE_NUMBER_16;
} else {
auto iterGetC0 = mapOfDtypeAndC0.find(shapeAndFormatInfo.currentDataType);
if (iterGetC0 == mapOfDtypeAndC0.end()) {
LOG_ERROR("Dtype is not support.");
delete axisutil_object;
return true;
}
c0 = iterGetC0->second;
}
// The value of C0 should be 4 while format is 5HD-4 or FRAZ-4
if (shapeAndFormatInfo.newFormat == ge::FORMAT_NC1HWC0_C04) {
c0 = SHAPE_DIM_VALUE_C04;
}
bool status = axisutil_object->GetAxisValueByOriginFormat(
shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape.GetDims(), c0, axisValue, ndValue);
if (status != true && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) {
delete axisutil_object;
return true;
}
delete axisutil_object;
(*getNewShapeFunc)(shapeAndFormatInfo.newShape, shapeAndFormatInfo.opImplType, axisValue, ndValue);
if (c != nullptr) {
*c = axisValue[AXIS_C];
}
return true;
}
}; // namespace ge

+ 129
- 0
tests/st/framework/stub_op_proto/util/transfer_shape_according_to_format.h View File

@@ -0,0 +1,129 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file transfer_shape_according_to_format.h
* \brief set shape according to original format and current format
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#include "axis_util.h"
#include <memory.h>
#include <functional>
#include <vector>
#include "framework/omg/omg_inner_types.h"
#include "operator.h"
#include "graph/operator_reg.h"
#include "graph/tensor.h"
#include "graph/utils/op_desc_utils.h"
#include "op_log.h"
#define LOG_ERROR(format, args...) printf(format, ##args)
#define LOG_INFO(format, args...) printf(format, ##args)
namespace ge {
enum OpImplType {
EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op
EN_IMPL_CUSTOM_TIK, // custom tik op
EN_IMPL_CUSTOM_TBE, // custom tbe op
EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op
EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op
EN_IMPL_HW_TIK, // Huawei built-in tik op
EN_IMPL_HW_TBE, // Huawei built-in tbe op
EN_IMPL_RL, // RL op
EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op
EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op
EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op
EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op
EN_RESERVED // reserved value
};
const uint32_t SHAPE_NUMBER_16 = 16;
const uint32_t SHAPE_NUMBER_32 = 32;
const uint32_t SHAPE_DIM_VALUE_C04 = 4;
const uint32_t NI = 16;
const uint32_t MINUS_VALUE_ONE = 1;
const uint32_t MINUS_VALUE_TWO = 2;
const uint32_t SIZE_OF_CN = 2;
const uint32_t MINIMUM_NZ_SHAPE_DIM_NUM = 2;
/* The first parameter is axis value, second is new shape and third is
* op implementation type. */
using GetNewShapeByAxisValueAndFormat =
std::function<bool(ge::GeShape&, const int64_t&, vector<int64_t>&, vector<int64_t>&)>;
using GetNewShapeByAxisValueAndFormatPtr = std::shared_ptr<GetNewShapeByAxisValueAndFormat>;
struct ShapeAndFormatInfo {
const ge::GeShape& oldShape;
ge::GeShape& newShape;
const ge::Format& oldFormat;
const ge::Format& newFormat;
const ge::DataType& currentDataType;
const int64_t& opImplType;
};
using ShapeAndFormat = struct ShapeAndFormatInfo;
class ShapeTransferAccordingToFormat {
public:
ShapeTransferAccordingToFormat();
~ShapeTransferAccordingToFormat(){};
ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat&) = delete;
ShapeTransferAccordingToFormat& operator=(const ShapeTransferAccordingToFormat&) = delete;
bool GetShapeAccordingToFormat(ShapeAndFormat& inputAndOutputInfo, int64_t* c = nullptr);
/* ----------Below is the function of getting new shape---------------------- */
static bool GetNCHWShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue);
static bool GetNHWCShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue);
static bool GetNC1HWC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
const vector<int64_t>& axisValue, const vector<int64_t>& ndValue);
static bool GetFzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue);
static bool GetHWCNShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue);
static bool GetC1HWNCoC0ShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType,
const vector<int64_t>& axisValue, const vector<int64_t>& ndValue);
static bool GetNzShapeByAxisValue(ge::GeShape& newShape, const int64_t& implType, const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue);
private:
/* map of GetAxisValueInfoByFormat, get axis value by different original
* formats. */
std::map<ge::Format, GetNewShapeByAxisValueAndFormatPtr> getNewShapeFuncMap;
std::map<ge::DataType, uint32_t> mapOfDtypeAndC0;
};
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_

+ 1097
- 0
tests/st/framework/stub_op_proto/util/util.cc
File diff suppressed because it is too large
View File


+ 363
- 0
tests/st/framework/stub_op_proto/util/util.h View File

@@ -0,0 +1,363 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file util.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
#define OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_
#include <memory.h>
#include <string>
#include <vector>
#include <map>
#include <algorithm>
#include "framework/omg/omg_inner_types.h"
#include "operator.h"
#include "graph/operator_reg.h"
#include "graph/operator_reg.h"
#include "transfer_shape_according_to_format.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/tensor.h"
#include "graph/node.h"
#include "graph/ge_tensor.h"
#include "op_log.h"
#define LOG_ERROR(format, args...) printf(format, ##args)
namespace ge {
// enum type and string type mapping
static const std::map<ge::DataType, std::string> DTYPE_STR_MAP{
{ge::DT_FLOAT16, "float16"}, {ge::DT_FLOAT, "float32"}, {ge::DT_INT8, "int8"}, {ge::DT_INT16, "int16"},
{ge::DT_INT32, "int32"}, {ge::DT_INT64, "int64"}, {ge::DT_UINT8, "uint8"}, {ge::DT_UINT16, "uint16"},
{ge::DT_UINT32, "uint32"}, {ge::DT_UINT64, "uint64"}, {ge::DT_BOOL, "bool"}};
// define the input num of shape
const size_t INPUT_NUM0 = 0;
const size_t INPUT_NUM1 = 1;
const size_t INPUT_NUM2 = 2;
const size_t INPUT_NUM3 = 3;
const size_t INPUT_NUM4 = 4;
const size_t INPUT_NUM5 = 5;
const size_t INPUT_NUM6 = 6;
const size_t INPUT_NUM7 = 7;
const size_t INPUT_NUM8 = 8;
const size_t INPUT_NUM9 = 9;
// define the dims size of shape
const size_t DIM_SIZE0 = 0;
const size_t DIM_SIZE1 = 1;
const size_t DIM_SIZE2 = 2;
const size_t DIM_SIZE3 = 3;
const size_t DIM_SIZE4 = 4;
const size_t DIM_SIZE5 = 5;
const size_t DIM_SIZE6 = 6;
const size_t DIM_SIZE7 = 7;
const size_t DIM_SIZE8 = 8;
// define the index of shape dim
const size_t DIM_INDEX0 = 0;
const size_t DIM_INDEX1 = 1;
const size_t DIM_INDEX2 = 2;
const size_t DIM_INDEX3 = 3;
const size_t DIM_INDEX4 = 4;
const size_t DIM_INDEX5 = 5;
const size_t DIM_INDEX6 = 6;
const size_t DIM_INDEX7 = 7;
const size_t DIM_INDEX8 = 8;
/*
* get the datatype of input
* param[in] dataType input datatype of enum value
* param[in] supportList the support range of op
* return true :get type success
* false:get type failed
*/
bool GetInputDataType(const ge::DataType& data_type, const std::vector<ge::DataType>& supportList);
bool GetInputDataType(const ge::DataType& dataType, const std::vector<ge::DataType>& supportList, std::string& dType);
/* infer shape of two input and on output with broadcast
* param[in] op op desc supply by ge
* param[in] inputName1 first input name
* param[in] inputName2 second input name
* param[in] outputName output name
* return SUCCESS:infer success
* FAILED:infer failed like unsupported broadcast input shape
*/
bool CheckInputDataType(const Operator& op, const std::string& input_name,
const std::vector<ge::DataType>& support_list);
/*
* check the datatype and shape of input
* param[in] op the operator
* param[in] inputTensorMap the map of input name and support datatype
* param[in] paramType the mode of input param, tensor or scalar
* return true
* false
*/
bool CheckInputDtypeAndShape(const Operator& op, const std::map<std::string, std::vector<DataType>>& inputTensorMap);
/*
* infer shape of two input and on output with broadcast
* param[in] op op desc supply by ge
* param[in] inputName1 first input name
* param[in] inputName2 second input name
* param[in] outputName output name
* return SUCCESS:infer success
* FAILED:infer failed like unsupported broadcast input shape
*/
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
const string& output_name);
/*
* infer shape of two input and on output with broadcast
* param[in] op op desc supply by ge
* param[in] inputName1 first input name
* param[in] inputName2 second input name
* param[in] outputName output name
* param[in] is_dynamic whether the shape of output is dynamic shape
* return SUCCESS:infer success
* FAILED:infer failed like unsupported broadcast input shape
*/
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
const string& output_name, bool& is_dynamic);
bool InferShapeRangeTwoInOneOutBroadcase(Operator& op, const string& input_name1, const string& input_name2,
const string& output_name);
bool CheckInputDataType(const Operator& op, std::string* data_type, const std::string& input_name,
const std::vector<ge::DataType>& supportList);
bool CheckTwoInputDtypeSame(const Operator& op, const string& input_name1, const string& input_name2);
bool CheckInputDtypeSame(const Operator& op, std::vector<std::string>& input_tensors);
bool CheckInputsShapeDtypeSame(const Operator& op, const std::vector<std::string>& input_names);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, float& attr_value);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, int64_t& attr_value);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, bool& attr_value);
bool GetConstValue(const ge::Operator& op, const std::string& key_name, std::vector<int32_t>& attr_value);
/**
* Get int type const value from tensor data
* @param [in] data const tensor data
* @param [in] data_type DT_INT8, DT_INT16, DT_INT32, DT_INT64
* @param [out] const_values const int values
* @return true:success, false:failed.
*/
bool GetConstIntData(const Tensor& data, DataType data_type, std::vector<int64_t>& const_values);
bool GetConstValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype,
std::vector<int64_t>& const_data);
bool GetConstValue(const Operator& op, const GeTensorPtr& const_tensor, const DataType& dtype,
std::vector<int64_t>& const_data);
bool GetScalerValue(const Operator& op, const Tensor& const_tensor, const DataType& dtype, std::int64_t& const_data);
bool InferShapeAndTypeTwoInOneOutBroadcast(Operator& op, const string& input_name1, const string& input_name2,
const string& output_name);
/*
* Check input dtype and format is supported in supportList from inputNumBeg to inputNumEnd
* param[in] op op desc supply by ge
* param[in] inputNumBeg input index begin, [0, N]
* param[in] inputNumEnd input index end need to be checked
* param[in] supportList, support type of ge::DataType and ge::Format
* return true: check pass
* false: check failed
*/
template <typename T>
bool CheckSimilarInputDtypeAndFormat(const Operator& op, std::size_t inputNumBeg, std::size_t inputNumEnd,
const std::vector<T>& supportList) {
for (std::size_t i = inputNumBeg; i < inputNumEnd; i++) {
if (std::is_same<typename std::decay<T>::type, ge::DataType>::value) {
ge::DataType inType = op.GetInputDesc(i).GetDataType();
const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
if (findDtype == supportList.end()) {
return false;
}
} else if (std::is_same<typename std::decay<T>::type, ge::Format>::value) {
ge::Format inType = op.GetInputDesc(i).GetFormat();
const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
if (findDtype == supportList.end()) {
return false;
}
}
}
return true;
}
/*
* Check input dtype and format is supported in supportList from inputNumBeg to inputNumEnd
* param[in] op op desc supply by ge
* param[in] indexNeedCheck input index need to be checked
* param[in] supportList, support type of ge::DataType and ge::Format
* return true: check pass
* false: check failed
*/
template <typename T>
bool CheckSimilarInputDtypeAndFormat(const Operator& op, const std::vector<std::size_t>& indexNeedCheck,
const std::vector<T>& supportList) {
for (auto i : indexNeedCheck) {
if (std::is_same<typename std::decay<T>::type, ge::DataType>::value) {
ge::DataType inType = op.GetInputDesc(i).GetDataType();
const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
if (findDtype == supportList.end()) {
return false;
}
} else if (std::is_same<typename std::decay<T>::type, ge::Format>::value) {
ge::Format inType = op.GetInputDesc(i).GetFormat();
const auto& findDtype = std::find(supportList.begin(), supportList.end(), inType);
if (findDtype == supportList.end()) {
return false;
}
}
}
return true;
}
/*
* get const attr
* param[in] op op desc supply by ge
* param[in] attrName list need to be get
* param[out] attr vector
* return true: get success
* false: get failed
*/
template <typename T>
bool GetConstAttr(const Operator& op, const std::vector<std::string>& attrNameList, std::vector<T>& attrVec) {
T value;
for (auto name : attrNameList) {
if (op.GetAttr(name, value) != ge::GRAPH_SUCCESS) {
return false;
}
attrVec.push_back(value);
}
return true;
}
/*
* get const attr list
* param[in] op op desc supply by ge
* param[in] attrName list need to be get
* param[out] attr vector
* return true: get success
* false: get failed
*/
template <typename T>
bool GetConstAttr(const Operator& op, const std::vector<std::string>& attrNameList,
std::vector<std::vector<T>>& attrListVec) {
for (auto name : attrNameList) {
std::vector<T> valueList;
if (op.GetAttr(name, valueList) != ge::GRAPH_SUCCESS) {
return false;
}
attrListVec.push_back(valueList);
}
return true;
}
std::string to_string(const vector<int64_t>& shape);
std::string to_string(const ge::Shape& shape);
std::string to_string(const ge::GeShape& shape);
std::string to_string(const vector<pair<int64_t, int64_t>>& ranges);
class DynamicShapeInfer {
public:
std::map<std::string, Format> map_format;
std::map<std::string, DataType> map_dtype;
std::map<std::string, uint32_t> inputs;
std::map<std::string, uint32_t> outputs;
Operator& op;
OpDescPtr& op_desc;
DynamicShapeInfer(Operator& op_v, OpDescPtr& opDesc_v) : op(op_v), op_desc(opDesc_v) {
}
bool CatchFormatAndShape();
bool UpdateFormatAndShape();
~DynamicShapeInfer() {
UpdateFormatAndShape();
}
};
#define PREPARE_DYNAMIC_SHAPE(depends_names) auto op_desc = OpDescUtils::GetOpDescFromOperator(op);\
do { \
if (!depends_names.empty()) { \
op_desc->SetOpInferDepends(depends_names); \
} \
} while(0)
bool IsEmptyTensor(const std::vector<int64_t>& dims);
bool IsUnknownRank(const Operator& op, const std::string& tensor_name, const std::string& types = "input");
bool IsUnknownRankShape(const std::vector<int64_t>& shape_vec);
bool IsUnKnownShape(const std::vector<int64_t>& shape_vec);
bool IsUnknownShape(const Operator& op, const std::string& tensor_name, const std::string& types = "input");
bool IsUnknownVec(std::vector<int64_t>& shape_vec);
bool IsUnknown(const std::vector<int64_t>& shape_vec);
void MakeUpShapeRange(const std::vector<int64_t>& shape, std::vector<std::pair<int64_t, int64_t>>& range);
std::string DataTypeToStringDesc(const ge::DataType& dataType);
bool OneInOneOutDynamicInfer(const Operator& op,
const std::string& input_name,
const std::vector<std::string>& output_name_list);
bool TwoInOneOutDynamicInferNoBroadcast(Operator& op,
const string& input1_name,
const string& input2_name,
const std::vector<string>& output_name_list);
void FixShapeRangeWithDims(const std::vector<int64_t>& dims,
std::vector<int64_t>& shape_1,
std::vector<int64_t>& shape_2,
std::vector<std::pair<int64_t, int64_t>>& range_1,
std::vector<std::pair<int64_t, int64_t>>& range_2);
bool SetScalarOutputDesc(const string& input,
const string& output,
OpDescPtr op_desc,
GeShape& output_shape);
namespace array_ops {
bool CheckInt64MulOverflow(int64_t a, int64_t b);
void ReshapeRangeInfer(const Operator &op, const std::vector<std::pair<int64_t, int64_t>>& x_range,
int64_t& range_max);
void ReshapeRangeInfer(const Operator &op, const std::vector<std::pair<int64_t, int64_t>>& x_range,
std::vector<std::pair<int64_t, int64_t>>& y_range, GeShape& output_shape);
}
} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_UTIL_UTIL_H_

+ 17
- 0
tests/st/framework/utils/assertion/graph_assertion.cc View File

@@ -0,0 +1,17 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph_assertion.h"

+ 34
- 0
tests/st/framework/utils/assertion/graph_assertion.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H
#define GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H
/*
* Compare graph node size, node_attr
*/
#define ASSERT_GRAPH_EQUAL(g1,g2) \
do { \
} while (0)
#define ASSERT_GRAPH_CORRECT(g) \
do { \
} while (0)
#define ASSERT_GRAPH_SHAPE_CONTINOUS(g) \
do { \
} while (0)
#endif // GRAPHENGINE_LLT_ST_GRAPH_ASSERTION_H

+ 48
- 0
tests/st/framework/utils/builder/graph_builder_utils.cc View File

@@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph_builder_utils.h"
#include "inc/external/graph/operator.h"
#include "inc/external/graph/operator_factory.h"
#include "graph/utils/graph_utils.h"

namespace ge {
namespace st {
NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format,
DataType data_type, std::vector<int64_t> shape) {
auto tensor_desc = std::make_shared<GeTensorDesc>();
tensor_desc->SetShape(GeShape(std::move(shape)));
tensor_desc->SetFormat(format);
tensor_desc->SetDataType(data_type);

auto op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < in_cnt; ++i) {
op_desc->AddInputDesc(tensor_desc->Clone());
}
for (int i = 0; i < out_cnt; ++i) {
op_desc->AddOutputDesc(tensor_desc->Clone());
}

return graph_->AddNode(op_desc);
}
void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) {
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
}
void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
}
} // namespace st
} // namespace ge

+ 53
- 0
tests/st/framework/utils/builder/graph_builder_utils.h View File

@@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H
#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H

#include <string>
#include <vector>

#include "graph/compute_graph.h"
#include "graph/utils/graph_utils.h"
#include "graph/graph.h"
#include "graph/node.h"

namespace ge {
namespace st {
class ComputeGraphBuilder {
public:
explicit ComputeGraphBuilder(const std::string &name) { graph_ = std::make_shared<ComputeGraph>(name); }
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
std::vector<int64_t> shape = {1, 1, 224, 224});
void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx);
void AddControlEdge(NodePtr &src_node, NodePtr &dst_node);
ComputeGraphPtr GetComputeGraph() {
graph_->TopologicalSorting();
return graph_;
}
Graph GetGraph() {
graph_->TopologicalSorting();
return GraphUtils::CreateGraphFromComputeGraph(graph_);
}

private:
ComputeGraphPtr graph_;
};
} // namespace st
} // namespace ge

#endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H

+ 17
- 0
tests/st/framework/utils/builder/tensor_builder_utils.cc View File

@@ -0,0 +1,17 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensor_builder_utils.h"

+ 22
- 0
tests/st/framework/utils/builder/tensor_builder_utils.h View File

@@ -0,0 +1,22 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H
class tensor_builder_utils {};
#endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H

+ 15
- 0
tests/st/testcase/CMakeLists.txt View File

@@ -0,0 +1,15 @@
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++")

add_executable(graph_engine_test ${SOURCES})

target_include_directories(graph_engine_test
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}
)

set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 11)

target_link_libraries(graph_engine_test PRIVATE gtest gtest_main framework)

include(CTest)
enable_testing()
add_test(NAME test COMMAND graph_engine_test)

+ 58
- 0
tests/st/testcase/test_framework_dummy.cc View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <map>
#include "external/ge/ge_api.h"
#include "framework/common/types.h"
#include "framework.h"
#include "framework/utils/builder/graph_builder_utils.h"

using namespace std;
using namespace ge;
class FrameworkTest : public testing::Test {
protected:
void SetUp() {
// ge initialize
map<AscendString, AscendString> options;
auto ret = ge::GEInitialize(options);
EXPECT_EQ(ret, SUCCESS);
}
void TearDown() {}
};

TEST_F(FrameworkTest, test_framework_dummy) {
// build graph
st::ComputeGraphBuilder graphBuilder("g1");
auto data1 = graphBuilder.AddNode("data1",DATA,1,1);
auto data2 = graphBuilder.AddNode("data2",DATA,1,1);
auto add = graphBuilder.AddNode("add",ADD,2,1);
graphBuilder.AddDataEdge(data1, 0, add,0);
graphBuilder.AddDataEdge(data2, 0, add,1);
Graph graph = graphBuilder.GetGraph();

// new session & add graph
map<AscendString, AscendString> options;
Session session(options);
auto ret = session.AddGraph(1, graph, options);
EXPECT_EQ(ret, SUCCESS);
// build input tensor
std::vector<InputTensorInfo> inputs;
// build_graph through session
ret = session.BuildGraph(1, inputs);

// TODO check result
}

Loading…
Cancel
Save