Browse Source

!320 update ge_runtime, modify cmake files for compiling with MindSpore

From: @ljl0711
Reviewed-by: @youui,@liujunzhu
Signed-off-by: @liujunzhu
tags/v1.1.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
22730223f8
37 changed files with 790 additions and 168 deletions
  1. +44
    -13
      CMakeLists.txt
  2. +3
    -0
      cmake/external_libs/protobuf_static.cmake
  3. +147
    -74
      ge/CMakeLists.txt
  4. +52
    -0
      ge/common/CMakeLists.txt
  5. +17
    -8
      ge/ge_runtime/CMakeLists.txt
  6. +8
    -3
      ge/ge_runtime/model_context.h
  7. +2
    -1
      ge/ge_runtime/model_runner.cc
  8. +2
    -1
      ge/ge_runtime/output.cc
  9. +3
    -3
      ge/ge_runtime/output.h
  10. +35
    -21
      ge/ge_runtime/runtime_model.cc
  11. +5
    -6
      ge/ge_runtime/runtime_model.h
  12. +47
    -5
      ge/ge_runtime/task/aicpu_task.cc
  13. +1
    -0
      ge/ge_runtime/task/aicpu_task.h
  14. +5
    -10
      ge/ge_runtime/task/cce_task.cc
  15. +0
    -0
      ge/ge_runtime/task/cce_task.h
  16. +1
    -1
      ge/ge_runtime/task/event_record_task.h
  17. +1
    -1
      ge/ge_runtime/task/event_wait_task.cc
  18. +1
    -1
      ge/ge_runtime/task/event_wait_task.h
  19. +0
    -3
      ge/ge_runtime/task/hccl_task.cc
  20. +0
    -0
      ge/ge_runtime/task/hccl_task.h
  21. +70
    -0
      ge/ge_runtime/task/label_goto_task.cc
  22. +41
    -0
      ge/ge_runtime/task/label_goto_task.h
  23. +70
    -0
      ge/ge_runtime/task/label_set_task.cc
  24. +41
    -0
      ge/ge_runtime/task/label_set_task.h
  25. +131
    -0
      ge/ge_runtime/task/label_switch_task.cc
  26. +44
    -0
      ge/ge_runtime/task/label_switch_task.h
  27. +0
    -0
      ge/ge_runtime/task/memcpy_async_task.h
  28. +0
    -0
      ge/ge_runtime/task/profiler_task.h
  29. +0
    -0
      ge/ge_runtime/task/stream_active_task.h
  30. +1
    -0
      ge/ge_runtime/task/stream_switch_task.h
  31. +1
    -1
      ge/ge_runtime/task/task.h
  32. +1
    -0
      ge/ge_runtime/task/task_factory.h
  33. +0
    -0
      ge/ge_runtime/task/tbe_task.cc
  34. +0
    -0
      ge/ge_runtime/task/tbe_task.h
  35. +7
    -7
      inc/framework/ge_runtime/davinci_model.h
  36. +2
    -5
      inc/framework/ge_runtime/model_runner.h
  37. +7
    -4
      inc/framework/ge_runtime/task_info.h

+ 44
- 13
CMakeLists.txt View File

@@ -8,6 +8,19 @@ if (NOT BUILD_PATH)
set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build")
endif() endif()


if(DEFINED ENV{ASCEND_CUSTOM_PATH})
set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH})
else()
set(ASCEND_DIR /usr/local/Ascend)
endif()
set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})

option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE)


if (ENABLE_OPEN_SRC) if (ENABLE_OPEN_SRC)
@@ -41,7 +54,7 @@ if (ENABLE_OPEN_SRC)
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
endif() endif()
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
set(STATIC_ACL_LIB ${GE_LIB_PATH})
set(STATIC_ACL_LIB ${GE_LIB_PATH})
find_module(slog libslog.so ${GE_LIB_PATH}) find_module(slog libslog.so ${GE_LIB_PATH})
find_module(mmpa libmmpa.so ${GE_LIB_PATH}) find_module(mmpa libmmpa.so ${GE_LIB_PATH})
find_module(msprof libmsprof.so ${GE_LIB_PATH}) find_module(msprof libmsprof.so ${GE_LIB_PATH})
@@ -56,18 +69,6 @@ if (ENABLE_OPEN_SRC)
find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH})
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
else() else()
if(DEFINED ENV{ASCEND_CUSTOM_PATH})
set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH})
else()
set(ASCEND_DIR /usr/local/Ascend)
endif()
set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})
find_module(slog libslog.so ${ASCEND_ATC_DIR}) find_module(slog libslog.so ${ASCEND_ATC_DIR})
find_module(mmpa libmmpa.so ${ASCEND_ATC_DIR}) find_module(mmpa libmmpa.so ${ASCEND_ATC_DIR})
if(PLATFORM STREQUAL "train") if(PLATFORM STREQUAL "train")
@@ -127,6 +128,36 @@ if (ENABLE_OPEN_SRC)
add_subdirectory(parser) add_subdirectory(parser)
#add_subdirectory(metadef/graph) #add_subdirectory(metadef/graph)
#add_subdirectory(metadef/register) #add_subdirectory(metadef/register)
elseif (ENABLE_D OR ENABLE_ACL)
# compiling with MindSpore
include(cmake/external_libs/protobuf_static.cmake)
include(cmake/external_libs/protoc.cmake)
include(cmake/external_libs/securec.cmake)
include(cmake/external_libs/json.cmake)
include(cmake/FindModule.cmake)
include(cmake/intf_pub_linux.cmake)

# common modules
find_module(slog libslog.so ${ASCEND_DRIVER_COMMON_DIR})
find_module(mmpa libmmpa.so ${ASCEND_DRIVER_COMMON_DIR})

if (ENABLE_D)
# training
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR})
find_module(register libregister.so ${ASCEND_RUNTIME_DIR})
find_module(resource libresource.so ${ASCEND_RUNTIME_DIR})
elseif(ENABLE_ACL)
# inference
find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR})
find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(resource libresource.so ${ASCEND_ATC_DIR})
find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR})
endif ()

set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
add_subdirectory(metadef)
else() else()
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef) set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef)
set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser)


+ 3
- 0
cmake/external_libs/protobuf_static.cmake View File

@@ -48,5 +48,8 @@ set_target_properties(ascend_protobuf_static_lib PROPERTIES
add_library(ascend_protobuf_static INTERFACE) add_library(ascend_protobuf_static INTERFACE)
target_include_directories(ascend_protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include) target_include_directories(ascend_protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include)
target_link_libraries(ascend_protobuf_static INTERFACE ascend_protobuf_static_lib) target_link_libraries(ascend_protobuf_static INTERFACE ascend_protobuf_static_lib)
if (ENABLE_D OR ENABLE_ACL)
include_directories(${PROTOBUF_STATIC_PKG_DIR}/include)
endif ()


add_dependencies(ascend_protobuf_static protobuf_static_build) add_dependencies(ascend_protobuf_static protobuf_static_build)

+ 147
- 74
ge/CMakeLists.txt View File

@@ -1,10 +1,15 @@
add_subdirectory(common)
add_subdirectory(plugin/engine)
add_subdirectory(graph/build/memory)
add_subdirectory(ge_local_engine)
add_subdirectory(host_cpu_engine)
add_subdirectory(executor)
add_subdirectory(offline)
if (NOT ENABLE_D AND NOT ENABLE_ACL)
add_subdirectory(common)
add_subdirectory(plugin/engine)
add_subdirectory(graph/build/memory)
add_subdirectory(ge_local_engine)
add_subdirectory(host_cpu_engine)
add_subdirectory(executor)
add_subdirectory(offline)
else()
add_subdirectory(common)
add_subdirectory(ge_runtime)
endif ()


set(PROTO_LIST set(PROTO_LIST
"${METADEF_DIR}/proto/fusion_model.proto" "${METADEF_DIR}/proto/fusion_model.proto"
@@ -28,7 +33,6 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST})
protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST})


############ libge_runner.so ############
set(TRAIN_SRC_LIST set(TRAIN_SRC_LIST
"common/formats/format_transfers/datatype_transfer.cc" "common/formats/format_transfers/datatype_transfer.cc"
"common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc"
@@ -333,72 +337,6 @@ set(TRAIN_SRC_LIST
"ir_build/atc_ir_common.cc" "ir_build/atc_ir_common.cc"
) )


add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS})

target_compile_definitions(ge_runner PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
DAVINCI_SUPPORT_PROFILING
REUSE_MEMORY=1
FMK_SUPPORT_DUMP
DAVINCI_CLOUD
google=ascend_private
)

target_compile_options(ge_runner PRIVATE
-O2
)

target_include_directories(ge_runner PRIVATE
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/ge/analyzer
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${GE_CODE_DIR}/inc/framework/common
${METADEF_DIR}
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
${GE_CODE_DIR}/../inc/external
${GE_CODE_DIR}/../inc/cce
${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
#### blue zone
${ASCEND_DIR}/driver/include
${ASCEND_DIR}/fwkacllib/include
${GE_CODE_DIR}/third_party/fwkacllib/inc
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
)

target_link_libraries(ge_runner
$<BUILD_INTERFACE:intf_pub>
ge_memory
adump_server
msprofiler
-Wl,--no-as-needed
graph
ge_common
ascend_protobuf
register
c_sec
slog
mmpa
msprof
runtime
resource
error_manager
ascend_hal_stub
-Wl,--as-needed
json
-lrt
-ldl
)

############ libge_compiler.so ############
set(INFER_SRC_LIST set(INFER_SRC_LIST
"graph/manager/trans_var_data_utils.cc" "graph/manager/trans_var_data_utils.cc"
"omm/csa_interact.cc" "omm/csa_interact.cc"
@@ -662,6 +600,74 @@ set(INFER_SRC_LIST
"analyzer/analyzer.cc" "analyzer/analyzer.cc"
) )


if (NOT ENABLE_D AND NOT ENABLE_ACL)
############ libge_runner.so ############
add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS})

target_compile_definitions(ge_runner PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
DAVINCI_SUPPORT_PROFILING
REUSE_MEMORY=1
FMK_SUPPORT_DUMP
DAVINCI_CLOUD
google=ascend_private
)

target_compile_options(ge_runner PRIVATE
-O2
)

target_include_directories(ge_runner PRIVATE
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/ge/analyzer
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${GE_CODE_DIR}/inc/framework/common
${METADEF_DIR}
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
#### yellow zone ####
${GE_CODE_DIR}/../inc
${GE_CODE_DIR}/../inc/external
${GE_CODE_DIR}/../inc/cce
${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
#### blue zone
${ASCEND_DIR}/driver/include
${ASCEND_DIR}/fwkacllib/include
${GE_CODE_DIR}/third_party/fwkacllib/inc
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
)

target_link_libraries(ge_runner
$<BUILD_INTERFACE:intf_pub>
ge_memory
adump_server
msprofiler
-Wl,--no-as-needed
graph
ge_common
ascend_protobuf
register
c_sec
slog
mmpa
msprof
runtime
resource
error_manager
ascend_hal_stub
-Wl,--as-needed
json
-lrt
-ldl
)

############ libge_compiler.so ############
add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS})


target_compile_definitions(ge_compiler PRIVATE target_compile_definitions(ge_compiler PRIVATE
@@ -919,3 +925,70 @@ install(FILES
${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt OPTIONAL ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt OPTIONAL
DESTINATION ${INSTALL_LIBRARY_DIR} DESTINATION ${INSTALL_LIBRARY_DIR}
) )

elseif (ENABLE_ACL)

############ libge_compiler.so w/static protobuf ############
add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS})

target_compile_definitions(ge_compiler PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
REUSE_MEMORY=1
FMK_SUPPORT_DUMP
FMK_HOST_INFER
COMPILE_OMG_PACKAGE
google=ascend_private
)

target_compile_options(ge_compiler PRIVATE
-O2
)

target_include_directories(ge_compiler PRIVATE
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/ge/analyzer
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${GE_CODE_DIR}/inc/framework/common
${METADEF_DIR}
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
${ASCEND_DIR}/driver/include
${ASCEND_DIR}/fwkacllib/include
${GE_CODE_DIR}/third_party/fwkacllib/inc
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
)

target_link_libraries(ge_compiler
$<BUILD_INTERFACE:intf_pub>
ge_memory
-Wl,--no-as-needed
graph
ge_common
static_ascend_protobuf
register
c_sec
error_manager
slog
mmpa
runtime_compile
resource
-Wl,--as-needed
json
-lrt
-ldl
)

############ install libge_compiler for MindSpore ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(TARGETS ge_compiler OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)
endif()

+ 52
- 0
ge/common/CMakeLists.txt View File

@@ -63,6 +63,7 @@ set(SRC_LIST
"ge/tbe_plugin_manager.cc" "ge/tbe_plugin_manager.cc"
) )


if (NOT ENABLE_D AND NOT ENABLE_ACL)
############ libge_common.so ############ ############ libge_common.so ############
add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS})
target_compile_definitions(ge_common PRIVATE target_compile_definitions(ge_common PRIVATE
@@ -164,6 +165,57 @@ target_link_libraries(ge_common_static PRIVATE
-ldl -ldl
) )


else ()
############ libge_common.so w/static protobuf ############
add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS})
target_compile_definitions(ge_common PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
HOST_VISIBILITY
FMK_SUPPORT_DUMP
OS_CENTOS
google=ascend_private
)

target_compile_options(ge_common PRIVATE
-fvisibility=hidden
-O2
-Werror
)

target_include_directories(ge_common PRIVATE
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/ge/common
${GE_CODE_DIR}/ge/common/op
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/framework
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
${GE_CODE_DIR}/third_party/fwkacllib/inc
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
)

target_link_libraries(ge_common PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf_static
-Wl,--no-as-needed
graph
register
c_sec
error_manager
slog
mmpa
-Wl,--as-needed
json
-lrt
-ldl
)
endif ()

############ install ############ ############ install ############
set(INSTALL_BASE_DIR "") set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib) set(INSTALL_LIBRARY_DIR lib)


+ 17
- 8
ge/ge_runtime/CMakeLists.txt View File

@@ -27,14 +27,22 @@ target_compile_definitions(ge_runtime PRIVATE
) )


target_include_directories(ge_runtime PRIVATE target_include_directories(ge_runtime PRIVATE
${TOP_DIR}
${TOP_DIR}/inc
${TOP_DIR}/inc/graph
${TOP_DIR}/inc/external
${TOP_DIR}/inc/framework
${TOP_DIR}/inc/framework/common
${TOP_DIR}/inc/framework/ge_runtime
${TOP_DIR}/inc/cce
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}
${GE_CODE_DIR}/ge
${GE_CODE_DIR}/inc
${GE_CODE_DIR}/inc/graph
${GE_CODE_DIR}/inc/external
${GE_CODE_DIR}/inc/framework
${GE_CODE_DIR}/inc/framework/common
${GE_CODE_DIR}/inc/framework/ge_runtime
${GE_CODE_DIR}/inc/cce
${GE_CODE_DIR}/third_party/fwkacllib/inc
${METADEF_DIR}
${METADEF_DIR}/inc
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge ${CMAKE_BINARY_DIR}/proto/ge
) )
@@ -45,6 +53,7 @@ target_link_libraries(ge_runtime PRIVATE
slog slog
runtime runtime
c_sec c_sec
graph
-Wl,--as-needed -Wl,--as-needed
-lrt -lrt
-ldl -ldl


+ 8
- 3
ge/ge_runtime/model_context.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -27,8 +27,13 @@ class ModelContext {
ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle, ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle,
rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list, rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list,
const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list) const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list)
: device_id_(device_id), session_id_(session_id), priority_(priority), rt_model_handle_(rt_model_handle),
rt_model_stream_(rt_model_stream), stream_list_(stream_list), label_list_(label_list),
: device_id_(device_id),
session_id_(session_id),
priority_(priority),
rt_model_handle_(rt_model_handle),
rt_model_stream_(rt_model_stream),
stream_list_(stream_list),
label_list_(label_list),
event_list_(event_list) {} event_list_(event_list) {}
~ModelContext() {} ~ModelContext() {}




+ 2
- 1
ge/ge_runtime/model_runner.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@


namespace ge { namespace ge {
namespace model_runner { namespace model_runner {

using RuntimeModelPtr = std::shared_ptr<RuntimeModel>; using RuntimeModelPtr = std::shared_ptr<RuntimeModel>;
using DavinciModelPtr = std::shared_ptr<DavinciModel>; using DavinciModelPtr = std::shared_ptr<DavinciModel>;




+ 2
- 1
ge/ge_runtime/output.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -89,5 +89,6 @@ bool Output::SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &dat
bool support_mem_share) { bool support_mem_share) {
return true; return true;
} }

} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge

+ 3
- 3
ge/ge_runtime/output.h View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@


namespace ge { namespace ge {
namespace model_runner { namespace model_runner {

class Output { class Output {
public: public:
Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model); Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model);
@@ -32,8 +33,7 @@ class Output {


bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share);


bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i,
bool support_mem_share);
bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, bool support_mem_share);


// Copy assignment operator and copy constructor are deleted // Copy assignment operator and copy constructor are deleted
Output &operator=(const Output &output) = delete; Output &operator=(const Output &output) = delete;


+ 35
- 21
ge/ge_runtime/runtime_model.cc View File

@@ -74,8 +74,8 @@ bool RuntimeModel::InitStream(std::shared_ptr<DavinciModel> &davinci_model) {
for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) { for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) {
rtStream_t stream = nullptr; rtStream_t stream = nullptr;
uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end()) uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end())
? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY)
: (RT_STREAM_PERSISTENT);
? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY)
: (RT_STREAM_PERSISTENT);


rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag); rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
@@ -115,23 +115,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) {
return true; return true;
} }


bool RuntimeModel::InitLabel(uint32_t batch_num) {
GELOGI("batch number:%u.", batch_num);
for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) {
rtLabel_t rt_lLabel = nullptr;
rtError_t rt_ret = rtLabelCreate(&rt_lLabel);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret);
return false;
bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) {
GELOGI("batch number:%u.", davinci_model->GetBatchNum());
label_list_.resize(davinci_model->GetBatchNum());
for (auto &task_info : davinci_model->GetTaskInfoList()) {
if (task_info == nullptr) {
GELOGE(PARAM_INVALID, "task_info is null.");
continue;
}

if (task_info->type() != TaskInfoType::LABEL_SET) {
continue;
} }
auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info);


if (rt_lLabel == nullptr) {
GELOGE(RT_FAILED, "rtLabel is nullptr!");
if (label_set_task_info->stream_id() >= stream_list_.size()) {
GELOGE(PARAM_INVALID, "Invalid stream id.");
return false; return false;
} }


label_list_.emplace_back(rt_lLabel);
rtLabel_t rt_label = nullptr;
rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret);
return false;
}
label_list_[label_set_task_info->label_id()] = rt_label;
} }

return true; return true;
} }


@@ -163,7 +174,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) {
return false; return false;
} }


if (!InitLabel(davinci_model->GetBatchNum())) {
if (!InitLabel(davinci_model)) {
return false; return false;
} }


@@ -281,7 +292,6 @@ bool RuntimeModel::DistributeTask() {
GELOGE(FAILED, "DistributeTask failed"); GELOGE(FAILED, "DistributeTask failed");
return false; return false;
} }

return true; return true;
} }


@@ -293,10 +303,14 @@ bool RuntimeModel::Run() {
return false; return false;
} }


GELOGI("Run rtModelExecute success");
GELOGI("Run rtModelExecute success, ret = 0x%X", ret);


ret = rtStreamSynchronize(rt_model_stream_); ret = rtStreamSynchronize(rt_model_stream_);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
if (ret == RT_ERROR_END_OF_SEQUENCE) {
GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret);
return true;
}
GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret);
return false; return false;
} }
@@ -330,6 +344,9 @@ void RuntimeModel::RtStreamDestory() noexcept {


void RuntimeModel::RtLabelDestory() noexcept { void RuntimeModel::RtLabelDestory() noexcept {
for (size_t i = 0; i < label_list_.size(); i++) { for (size_t i = 0; i < label_list_.size(); i++) {
if (label_list_[i] == nullptr) {
continue;
}
if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) { if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i); GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i);
return; return;
@@ -471,11 +488,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
/// The logic of GetShapeSize is wrong: a scalar tensor's GetShapeSize is zero, /// The logic of GetShapeSize is wrong: a scalar tensor's GetShapeSize is zero,
/// and that of an unknown shape is zero too. /// and that of an unknown shape is zero too.
/// Unknown shapes will not appear here, so a zero size can be used to judge whether a tensor is a scalar. /// Unknown shapes will not appear here, so a zero size can be used to judge whether a tensor is a scalar.
int64_t elem_num = constant->weight_tensors[0].GetShapeSize();
if (elem_num == 0 && constant->weight_tensors[0].size == 0) {
elem_num = 1;
}

int64_t elem_num =
(constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize();
if (constant->weight_data.size() < sizeof(uint64_t)) { if (constant->weight_data.size() < sizeof(uint64_t)) {
GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)");
return false; return false;


+ 5
- 6
ge/ge_runtime/runtime_model.h View File

@@ -40,13 +40,11 @@ class RuntimeModel {
const std::vector<uint32_t> &GetTaskIdList() const; const std::vector<uint32_t> &GetTaskIdList() const;
const std::vector<uint32_t> &GetStreamIdList() const; const std::vector<uint32_t> &GetStreamIdList() const;
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
const rtModel_t GetModelHandle() const { return rt_model_handle_; }
rtModel_t GetModelHandle() const { return rt_model_handle_; }
bool Run(); bool Run();
bool CopyInputData(const InputData &input_data); bool CopyInputData(const InputData &input_data);
bool GetInputOutputDescInfo(bool zero_copy,
std::vector<InputOutputDescInfo> *input_desc,
std::vector<InputOutputDescInfo> *output_desc,
std::vector<uint32_t> *input_format,
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format,
std::vector<uint32_t> *output_format); std::vector<uint32_t> *output_format);


private: private:
@@ -55,7 +53,7 @@ class RuntimeModel {
bool LoadTask(); bool LoadTask();
bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); bool InitStream(std::shared_ptr<DavinciModel> &davinci_model);
bool InitEvent(uint32_t event_num); bool InitEvent(uint32_t event_num);
bool InitLabel(uint32_t batch_num);
bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model);
bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model);
bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model);
bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model);
@@ -87,6 +85,7 @@ class RuntimeModel {
std::vector<uint32_t> stream_id_list_{}; std::vector<uint32_t> stream_id_list_{};
std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_; std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_;
}; };

} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge




+ 47
- 5
ge/ge_runtime/task/aicpu_task.cc View File

@@ -26,6 +26,7 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai
task_info_(task_info), task_info_(task_info),
stream_(nullptr), stream_(nullptr),
args_(nullptr), args_(nullptr),
ext_info_(nullptr),
input_output_addr_(nullptr) { input_output_addr_(nullptr) {
if (task_info_ == nullptr) { if (task_info_ == nullptr) {
GELOGW("task_info_ is null!"); GELOGW("task_info_ is null!");
@@ -41,7 +42,10 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai
} }
} }


AicpuTask::~AicpuTask() { ReleaseRtMem(&args_); }
AicpuTask::~AicpuTask() {
ReleaseRtMem(&args_);
ReleaseRtMem(&ext_info_);
}


bool AicpuTask::Distribute() { bool AicpuTask::Distribute() {
GELOGI("InitAicpuTask start."); GELOGI("InitAicpuTask start.");
@@ -51,10 +55,37 @@ bool AicpuTask::Distribute() {
auto io_addrs_num = static_cast<uint32_t>(io_addrs.size()); auto io_addrs_num = static_cast<uint32_t>(io_addrs.size());
auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *)); auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *));
constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead); constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);
uint32_t node_def_addr_offset = io_addr_offset + io_addrs_size;
uint32_t args_size =
sizeof(aicpu::AicpuParamHead) + io_addrs_size + static_cast<uint32_t>(task_info_->node_def().size());
aicpu::AicpuParamHead aicpu_param_head = {args_size, io_addrs_num};
uint32_t node_def_len_offset = io_addr_offset + io_addrs_size;
uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size +
static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t);

aicpu::AicpuParamHead aicpu_param_head;
aicpu_param_head.length = args_size;
aicpu_param_head.ioAddrNum = io_addrs_num;
auto ext_info = task_info_->ext_info();
uint32_t ext_size = ext_info.size();
if (ext_info.empty()) {
aicpu_param_head.extInfoLength = 0;
aicpu_param_head.extInfoAddr = 0;
} else {
rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM);
if (flag != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag);
return false;
}

flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size,
RT_MEMCPY_HOST_TO_DEVICE);
if (flag != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag);
return false;
}

GELOGI("ext info size:", ext_size);
aicpu_param_head.extInfoLength = ext_size;
aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
}


// Malloc device memory for args // Malloc device memory for args
rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
@@ -80,6 +111,17 @@ bool AicpuTask::Distribute() {
return false; return false;
} }
} }

// Memcpy node def
auto size = task_info_->node_def().size();
rt_ret =
rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t),
reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
return false;
}

// Memcpy node def // Memcpy node def
rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset), rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset),
task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()), task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()),


+ 1
- 0
ge/ge_runtime/task/aicpu_task.h View File

@@ -41,6 +41,7 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
std::shared_ptr<AicpuTaskInfo> task_info_; std::shared_ptr<AicpuTaskInfo> task_info_;
void *stream_; void *stream_;
void *args_; void *args_;
void *ext_info_;
void *input_output_addr_; void *input_output_addr_;
}; };
} // namespace model_runner } // namespace model_runner


+ 5
- 10
ge/ge_runtime/task/cce_task.cc View File

@@ -103,9 +103,9 @@ bool CceTask::Distribute() {
// Modify flowtable addr in args // Modify flowtable addr in args
auto args = const_cast<uint8_t *>(task_info_->args().data()); auto args = const_cast<uint8_t *>(task_info_->args().data());
auto task_offset = reinterpret_cast<uint16_t *>(const_cast<uint8_t *>(task_info_->args_offset().data())); auto task_offset = reinterpret_cast<uint16_t *>(const_cast<uint8_t *>(task_info_->args_offset().data()));

if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) { if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) {
GELOGE(FAILED,
"(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu",
GELOGE(FAILED, "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu",
static_cast<uint32_t>(task_offset[0]), sizeof(uint64_t), task_info_->args().size()); static_cast<uint32_t>(task_offset[0]), sizeof(uint64_t), task_info_->args().size());
return false; return false;
} }
@@ -136,8 +136,7 @@ bool CceTask::Distribute() {
return false; return false;
} }


rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(),
task_info_->sm_desc().data(),
rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), task_info_->sm_desc().data(),
task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE); task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
@@ -146,12 +145,8 @@ bool CceTask::Distribute() {
} }


// Kernel launch // Kernel launch
rt_ret = rtKernelLaunch(stub_func_,
task_info_->block_dim(),
args_,
task_info_->args_size(),
static_cast<rtSmDesc_t *>(sm_desc_),
stream_);
rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, task_info_->args_size(),
static_cast<rtSmDesc_t *>(sm_desc_), stream_);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false; return false;


+ 0
- 0
ge/ge_runtime/task/cce_task.h View File


+ 1
- 1
ge/ge_runtime/task/event_record_task.h View File

@@ -33,7 +33,7 @@ class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> {
private: private:
std::shared_ptr<EventRecordTaskInfo> task_info_; std::shared_ptr<EventRecordTaskInfo> task_info_;
rtStream_t stream_; rtStream_t stream_;
rtEvent_t event_;
rtEvent_t event_;
}; };
} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge


+ 1
- 1
ge/ge_runtime/task/event_wait_task.cc View File

@@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.


+ 1
- 1
ge/ge_runtime/task/event_wait_task.h View File

@@ -33,7 +33,7 @@ class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> {
private: private:
std::shared_ptr<EventWaitTaskInfo> task_info_; std::shared_ptr<EventWaitTaskInfo> task_info_;
rtStream_t stream_; rtStream_t stream_;
rtEvent_t event_;
rtEvent_t event_;
}; };
} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge


+ 0
- 3
ge/ge_runtime/task/hccl_task.cc View File

@@ -115,7 +115,6 @@ bool HcclTask::Distribute() {
rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
(void)rtStreamDestroy(stream);
return false; return false;
} }


@@ -129,8 +128,6 @@ bool HcclTask::Distribute() {
ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL);
ge_task.stream = stream_; ge_task.stream = stream_;


GETaskKernelHcclInfo kernel_hccl_info;
ge_task.kernelHcclInfo.emplace_back(kernel_hccl_info);
ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type();
ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr();
ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr();


+ 0
- 0
ge/ge_runtime/task/hccl_task.h View File


+ 70
- 0
ge/ge_runtime/task/label_goto_task.cc View File

@@ -0,0 +1,70 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/label_goto_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Builds a label-goto task by resolving the stream and label handles that
// task_info refers to (by id) from the lists held in model_context.
// If task_info is null, or either id is outside its list, a warning is logged
// and the handles stay nullptr; Distribute() rejects the task in that state.
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
    : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      label_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  const auto &streams = model_context.stream_list();
  const auto &labels = model_context.label_list();
  const uint32_t stream_index = task_info_->stream_id();
  const uint32_t label_index = task_info_->label_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), stream_index);
  GELOGI("Label list size:%zu, label id:%u.", labels.size(), label_index);

  // Both ids must index into their respective lists before they are used.
  const bool ids_in_range = (stream_index < streams.size()) && (label_index < labels.size());
  if (!ids_in_range) {
    GELOGW("Stream/Label id invalid.");
    return;
  }

  stream_ = streams[stream_index];
  label_ = labels[label_index];
}

// The destructor performs no explicit cleanup.
LabelGotoTask::~LabelGotoTask() = default;

bool LabelGotoTask::Distribute() {
GELOGI("LabelGotoTask Distribute start.");
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}
if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
return false;
}
rtError_t rt_ret = rtLabelGotoEx(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

GELOGI("DistributeTask end.");
return true;
}

REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);

} // namespace model_runner
} // namespace ge

+ 41
- 0
ge/ge_runtime/task/label_goto_task.h View File

@@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Runtime task that performs a label goto (rtLabelGotoEx) on a stream.
// The stream and label handles are looked up by id from the ModelContext at
// construction time; Distribute() submits the goto to the runtime.
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
public:
// Resolves the stream/label handles referenced by task_info from model_context.
LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info);

~LabelGotoTask() override;

// Returns true when the goto was submitted successfully, false otherwise.
bool Distribute() override;

private:
// Task description supplied at construction (may be null).
std::shared_ptr<LabelGotoTaskInfo> task_info_;
// Stream handle taken from the model context; nullptr if it could not be resolved.
void *stream_;
// Label handle taken from the model context; nullptr if it could not be resolved.
void *label_;
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

+ 70
- 0
ge/ge_runtime/task/label_set_task.cc View File

@@ -0,0 +1,70 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/label_set_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Builds a label-set task: looks up, by id, the stream and label handles named
// in task_info from the lists provided by model_context.
// A null task_info or an out-of-range id only logs a warning and leaves the
// handles as nullptr; Distribute() then fails for this task.
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info)
    : TaskRepeater<LabelSetTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      label_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  const auto &available_streams = model_context.stream_list();
  const auto &available_labels = model_context.label_list();
  const uint32_t wanted_stream = task_info_->stream_id();
  const uint32_t wanted_label = task_info_->label_id();
  GELOGI("Stream list size:%zu, stream id:%u.", available_streams.size(), wanted_stream);
  GELOGI("Label list size:%zu, label id:%u.", available_labels.size(), wanted_label);

  // Reject ids that would index past the end of either list.
  if (wanted_stream >= available_streams.size() || wanted_label >= available_labels.size()) {
    GELOGW("Stream/Label id invalid.");
    return;
  }

  stream_ = available_streams[wanted_stream];
  label_ = available_labels[wanted_label];
}

// The destructor performs no explicit cleanup.
LabelSetTask::~LabelSetTask() = default;

bool LabelSetTask::Distribute() {
GELOGI("LabelSetTask Distribute start.");
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}
if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
return false;
}
rtError_t rt_ret = rtLabelSet(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

GELOGI("DistributeTask end.");
return true;
}

REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo);

} // namespace model_runner
} // namespace ge

+ 41
- 0
ge/ge_runtime/task/label_set_task.h View File

@@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Runtime task that sets a label (rtLabelSet) on a stream.
// The stream and label handles are looked up by id from the ModelContext at
// construction time; Distribute() submits the label-set to the runtime.
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {
public:
// Resolves the stream/label handles referenced by task_info from model_context.
LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);

~LabelSetTask() override;

// Returns true when the label was set successfully, false otherwise.
bool Distribute() override;

private:
// Task description supplied at construction (may be null).
std::shared_ptr<LabelSetTaskInfo> task_info_;
// Stream handle taken from the model context; nullptr if it could not be resolved.
void *stream_;
// Label handle taken from the model context; nullptr if it could not be resolved.
void *label_;
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_

+ 131
- 0
ge/ge_runtime/task/label_switch_task.cc View File

@@ -0,0 +1,131 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ge_runtime/task/label_switch_task.h"
#include "ge_runtime/task/task_factory.h"

namespace ge {
namespace model_runner {
// Builds a label-switch task: keeps a copy of the model's full label list and
// resolves the stream handle referenced by task_info.
// A null task_info or an out-of-range stream id only logs a warning; stream_
// then stays nullptr and CheckParamValid() rejects the task.
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
                                 const std::shared_ptr<LabelSwitchTaskInfo> &task_info)
    : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      all_label_resource_(),
      label_info_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  // The label list is captured before the stream id is validated.
  all_label_resource_ = model_context.label_list();

  const auto &streams = model_context.stream_list();
  const uint32_t stream_index = task_info_->stream_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), stream_index);

  const bool stream_in_range = stream_index < streams.size();
  if (!stream_in_range) {
    GELOGW("Stream id invalid.");
    return;
  }
  stream_ = streams[stream_index];
}

// Releases the buffer allocated by Distribute() (label_info_), if any.
// A failing rtFree is only logged: a destructor has no way to report it.
LabelSwitchTask::~LabelSwitchTask() {
  if (label_info_ != nullptr) {
    rtError_t rt_ret = rtFree(label_info_);
    if (rt_ret != RT_ERROR_NONE) {
      // Name the buffer that is actually being freed; the previous message
      // referred to an unrelated "fwkOpBuf".
      GELOGE(RT_FAILED, "rtFree label_info_ failed! ret: 0x%X.", rt_ret);
    }
    label_info_ = nullptr;
  }
}

// Distributes the label-switch task to the runtime.
//
// Steps:
//   1. Validate the task parameters (CheckParamValid).
//   2. Translate every label index listed in the task info into the matching
//      handle from all_label_resource_.
//   3. Allocate label_info_ with rtMalloc(RT_MEMORY_HBM) and fill it from the
//      handle list with rtLabelListCpy.
//   4. Register the index-based switch on stream_ with rtLabelSwitchByIndex.
//
// Returns true on success. Returns false (after logging) on invalid
// parameters, an out-of-range label index, or any failing runtime call.
// label_info_ stays owned by this object and is freed by the destructor.
bool LabelSwitchTask::Distribute() {
  GELOGI("LabelSwitchTask Distribute start.");
  if (!CheckParamValid()) {
    return false;
  }

  const std::vector<uint32_t> &label_index_list = task_info_->label_list();
  std::vector<void *> label_list(task_info_->label_size(), nullptr);

  // CheckParamValid() guarantees label_size() == label_index_list.size(), so
  // indexing label_index_list with i is in range.
  for (size_t i = 0; i < task_info_->label_size(); ++i) {
    uint32_t label_index = label_index_list[i];
    if (label_index >= all_label_resource_.size()) {
      GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index,
             all_label_resource_.size());
      return false;
    }
    label_list[i] = all_label_resource_[label_index];
    // label_index is a uint32_t and must be printed with %u; %zu would read a
    // size_t from the varargs and is undefined behavior.
    GELOGI("Case %zu: label id %u.", i, label_index);
  }

  // CheckParamValid() has already rejected label counts for which this
  // product would not fit into 32 bits.
  uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size();
  rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  GELOGI("DistributeTask end.");
  return true;
}

// Verifies the preconditions of Distribute(). Checks, in order:
//   - stream_ has been resolved,
//   - the task's label list is not empty,
//   - the declared label_size matches the label list length,
//   - label_size * sizeof(rtLabelDevInfo) stays below UINT32_MAX,
//   - label_info_ has not been allocated yet.
// Logs the first violated condition and returns false; returns true when all
// checks pass.
bool LabelSwitchTask::CheckParamValid() {
  if (stream_ == nullptr) {
    GELOGE(PARAM_INVALID, "stream is null!");
    return false;
  }

  const auto &case_labels = task_info_->label_list();
  if (case_labels.empty()) {
    GELOGE(PARAM_INVALID, "label_list is empty.");
    return false;
  }

  const auto declared_count = task_info_->label_size();
  if (declared_count != case_labels.size()) {
    GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", case_labels.size(), declared_count);
    return false;
  }

  // Guard the 32-bit byte-size computation performed by Distribute().
  const auto max_label_count = UINT32_MAX / sizeof(rtLabelDevInfo);
  if (declared_count >= max_label_count) {
    GELOGE(PARAM_INVALID, "label_size %u will overflow.", declared_count);
    return false;
  }

  if (label_info_ != nullptr) {
    GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
    return false;
  }

  return true;
}

REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);

} // namespace model_runner
} // namespace ge

+ 44
- 0
ge/ge_runtime/task/label_switch_task.h View File

@@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_

#include <memory>
#include "ge_runtime/task/task.h"

namespace ge {
namespace model_runner {
// Runtime task that switches between labels by index (rtLabelSwitchByIndex).
// Distribute() maps the label indices from the task info to label handles,
// copies them into a runtime-allocated buffer and registers the switch on the
// task's stream. The buffer is released in the destructor.
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {
public:
// Copies the model's label list and resolves the stream referenced by task_info.
LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);

// Frees label_info_ if Distribute() allocated it.
~LabelSwitchTask() override;

// Returns true when the switch was registered successfully, false otherwise.
bool Distribute() override;

private:
// Validates stream_, the label list/size of the task info and label_info_
// before Distribute() does any work.
bool CheckParamValid();

// Task description supplied at construction (may be null).
std::shared_ptr<LabelSwitchTaskInfo> task_info_;
// Stream handle taken from the model context; nullptr if it could not be resolved.
void *stream_;
// Copy of the model context's label list; indexed by the task's label indices.
std::vector<void *> all_label_resource_;
// Buffer allocated by Distribute() via rtMalloc; nullptr until then.
void *label_info_;
};
} // namespace model_runner
} // namespace ge

#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_

+ 0
- 0
ge/ge_runtime/task/memcpy_async_task.h View File


+ 0
- 0
ge/ge_runtime/task/profiler_task.h View File


+ 0
- 0
ge/ge_runtime/task/stream_active_task.h View File


+ 1
- 0
ge/ge_runtime/task/stream_switch_task.h View File

@@ -37,6 +37,7 @@ class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> {
void *stream_; void *stream_;
std::vector<rtStream_t> stream_list_; std::vector<rtStream_t> stream_list_;
}; };

} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge
#endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ #endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_

+ 1
- 1
ge/ge_runtime/task/task.h View File

@@ -42,7 +42,7 @@ class Task {


template <class T> template <class T>
class TaskRepeater : public Task { class TaskRepeater : public Task {
static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!"); /*lint !e30*/
static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!");


public: public:
TaskRepeater(const ModelContext &model_context, std::shared_ptr<T> task_info) {} TaskRepeater(const ModelContext &model_context, std::shared_ptr<T> task_info) {}


+ 1
- 0
ge/ge_runtime/task/task_factory.h View File

@@ -81,6 +81,7 @@ class TaskFactory {
std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \ std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \
return std::make_shared<task_clazz>(model_context, concrete_task_info); \ return std::make_shared<task_clazz>(model_context, concrete_task_info); \
}); });

} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge
#endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ #endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_

+ 0
- 0
ge/ge_runtime/task/tbe_task.cc View File


+ 0
- 0
ge/ge_runtime/task/tbe_task.h View File


+ 7
- 7
inc/framework/ge_runtime/davinci_model.h View File

@@ -27,10 +27,10 @@ namespace ge {
namespace model_runner { namespace model_runner {
class DavinciModel { class DavinciModel {
public: public:
DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, /*lint !e151*/
DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list,
const std::vector<std::shared_ptr<OpInfo>> &data_info_list, const std::vector<std::shared_ptr<OpInfo>> &data_info_list,
const std::vector<std::shared_ptr<OpInfo>> &output_info_list, /*lint !e151*/
const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, /*lint !e1049*/
const std::vector<std::shared_ptr<OpInfo>> &output_info_list,
const std::vector<std::shared_ptr<OpInfo>> &constant_info_list,
const std::vector<model_runner::OpInfoPtr> &variable_info_list, const std::vector<model_runner::OpInfoPtr> &variable_info_list,
const std::vector<uint32_t> &wait_active_stream_list, const std::vector<uint32_t> &wait_active_stream_list,
const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0,
@@ -68,12 +68,12 @@ class DavinciModel {
uint32_t GetBatchNum() const { return batch_num_; } uint32_t GetBatchNum() const { return batch_num_; }
uint32_t GetEventNum() const { return event_num_; } uint32_t GetEventNum() const { return event_num_; }


const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } /*lint !e1413*/
const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } /*lint !e1413*/
const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; }
const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; }


int32_t GetPriority() const { return priority_; } int32_t GetPriority() const { return priority_; }


const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } /*lint !e151*/
const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; }
const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; } const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; }
const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; } const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; }
const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return output_info_list_; } const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return output_info_list_; }
@@ -81,7 +81,7 @@ class DavinciModel {


private: private:
std::vector<std::shared_ptr<TaskInfo>> task_info_list_; std::vector<std::shared_ptr<TaskInfo>> task_info_list_;
std::vector<std::shared_ptr<OpInfo>> data_info_list_; /*lint !e151*/
std::vector<std::shared_ptr<OpInfo>> data_info_list_;
std::vector<std::shared_ptr<OpInfo>> output_info_list_; std::vector<std::shared_ptr<OpInfo>> output_info_list_;
std::vector<std::shared_ptr<OpInfo>> constant_info_list_; std::vector<std::shared_ptr<OpInfo>> constant_info_list_;
std::vector<model_runner::OpInfoPtr> variable_info_list_; std::vector<model_runner::OpInfoPtr> variable_info_list_;


+ 2
- 5
inc/framework/ge_runtime/model_runner.h View File

@@ -52,11 +52,8 @@ class ModelRunner {


bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);


bool GetInputOutputDescInfo(uint32_t model_id,
bool zero_copy,
std::vector<InputOutputDescInfo> *input_desc,
std::vector<InputOutputDescInfo> *output_desc,
std::vector<uint32_t> *input_format,
bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format,
std::vector<uint32_t> *output_format); std::vector<uint32_t> *output_format);


private: private:


+ 7
- 4
inc/framework/ge_runtime/task_info.h View File

@@ -161,12 +161,13 @@ class TbeTaskInfo : public TaskInfo {
class AicpuTaskInfo : public TaskInfo { class AicpuTaskInfo : public TaskInfo {
public: public:
AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
const std::string &node_def, const std::vector<void *> &input_data_addrs,
const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, bool dump_flag) const std::vector<void *> &output_data_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
so_name_(so_name), so_name_(so_name),
kernel_name_(kernel_name), kernel_name_(kernel_name),
node_def_(node_def), node_def_(node_def),
ext_info_(ext_info),
input_data_addrs_(input_data_addrs), input_data_addrs_(input_data_addrs),
output_data_addrs_(output_data_addrs) {} output_data_addrs_(output_data_addrs) {}
~AicpuTaskInfo() override {} ~AicpuTaskInfo() override {}
@@ -176,11 +177,13 @@ class AicpuTaskInfo : public TaskInfo {
const std::string &node_def() const { return node_def_; } const std::string &node_def() const { return node_def_; }
const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
const std::string &ext_info() const { return ext_info_; }


private: private:
std::string so_name_; std::string so_name_;
std::string kernel_name_; std::string kernel_name_;
std::string node_def_; std::string node_def_;
std::string ext_info_;
std::vector<void *> input_data_addrs_; std::vector<void *> input_data_addrs_;
std::vector<void *> output_data_addrs_; std::vector<void *> output_data_addrs_;
}; };
@@ -293,19 +296,19 @@ class HcclTaskInfo : public TaskInfo {
hcom_distribute_task_(hcom_distribute_task) {} hcom_distribute_task_(hcom_distribute_task) {}
~HcclTaskInfo() override {} ~HcclTaskInfo() override {}


const std::string &hccl_type() const { return hccl_type_; } /*lint !e1413*/
const std::string &hccl_type() const { return hccl_type_; }
void *input_data_addr() const { return input_data_addr_; } void *input_data_addr() const { return input_data_addr_; }
void *output_data_addr() const { return output_data_addr_; } void *output_data_addr() const { return output_data_addr_; }
void *workspace_addr() const { return workspace_addr_; } void *workspace_addr() const { return workspace_addr_; }
int64_t workspace_size() const { return workspace_size_; } int64_t workspace_size() const { return workspace_size_; }
int64_t hccl_stream_num() const { return hccl_stream_num_; } int64_t hccl_stream_num() const { return hccl_stream_num_; }
const std::vector<uint8_t> &private_def() const { return private_def_; } /*lint !e1413*/
const std::vector<uint8_t> &private_def() const { return private_def_; }
void *ops_kernel_store() const { return ops_kernel_store_; } void *ops_kernel_store() const { return ops_kernel_store_; }
int32_t count() const { return count_; } int32_t count() const { return count_; }
int64_t root_id() const { return root_id_; } int64_t root_id() const { return root_id_; }
int64_t op_type() const { return op_type_; } int64_t op_type() const { return op_type_; }
int64_t data_type() const { return data_type_; } int64_t data_type() const { return data_type_; }
const std::string group() const { return group_; }
const std::string &group() const { return group_; }
std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {


Loading…
Cancel
Save