@@ -8,6 +8,19 @@ if (NOT BUILD_PATH)
 set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build")
 endif()
+if(DEFINED ENV{ASCEND_CUSTOM_PATH})
+set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH})
+else()
+set(ASCEND_DIR /usr/local/Ascend)
+endif()
+set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
+set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
+set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
+set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
+set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
+set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
+set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})
 option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE)
 if (ENABLE_OPEN_SRC)
@@ -41,7 +54,7 @@ if (ENABLE_OPEN_SRC)
 message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated")
 endif()
 set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH})
-set(STATIC_ACL_LIB ${GE_LIB_PATH})
+set(STATIC_ACL_LIB ${GE_LIB_PATH})
 find_module(slog libslog.so ${GE_LIB_PATH})
 find_module(mmpa libmmpa.so ${GE_LIB_PATH})
 find_module(msprof libmsprof.so ${GE_LIB_PATH})
@@ -56,18 +69,6 @@ if (ENABLE_OPEN_SRC)
 find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH})
 #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
 else()
-if(DEFINED ENV{ASCEND_CUSTOM_PATH})
-set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH})
-else()
-set(ASCEND_DIR /usr/local/Ascend)
-endif()
-set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
-set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
-set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
-set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
-set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
-set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
-set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})
 find_module(slog libslog.so ${ASCEND_ATC_DIR})
 find_module(mmpa libmmpa.so ${ASCEND_ATC_DIR})
 if(PLATFORM STREQUAL "train")
@@ -127,6 +128,36 @@ if (ENABLE_OPEN_SRC)
 add_subdirectory(parser)
 #add_subdirectory(metadef/graph)
 #add_subdirectory(metadef/register)
+elseif (ENABLE_D OR ENABLE_ACL)
+    # compiling with MindSpore
+    include(cmake/external_libs/protobuf_static.cmake)
+    include(cmake/external_libs/protoc.cmake)
+    include(cmake/external_libs/securec.cmake)
+    include(cmake/external_libs/json.cmake)
+    include(cmake/FindModule.cmake)
+    include(cmake/intf_pub_linux.cmake)
+    # common modules
+    find_module(slog libslog.so ${ASCEND_DRIVER_COMMON_DIR})
+    find_module(mmpa libmmpa.so ${ASCEND_DRIVER_COMMON_DIR})
+    if (ENABLE_D)
+        # training
+        find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
+        find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR})
+        find_module(register libregister.so ${ASCEND_RUNTIME_DIR})
+        find_module(resource libresource.so ${ASCEND_RUNTIME_DIR})
+    elseif(ENABLE_ACL)
+        # inference
+        find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR})
+        find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
+        find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
+        find_module(resource libresource.so ${ASCEND_ATC_DIR})
+        find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR})
+    endif ()
+    set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef)
+    add_subdirectory(metadef)
 else()
 set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef)
 set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser)
@@ -48,5 +48,8 @@ set_target_properties(ascend_protobuf_static_lib PROPERTIES
 add_library(ascend_protobuf_static INTERFACE)
 target_include_directories(ascend_protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include)
 target_link_libraries(ascend_protobuf_static INTERFACE ascend_protobuf_static_lib)
+if (ENABLE_D OR ENABLE_ACL)
+    include_directories(${PROTOBUF_STATIC_PKG_DIR}/include)
+endif ()
 add_dependencies(ascend_protobuf_static protobuf_static_build)
@@ -1,10 +1,15 @@
-add_subdirectory(common)
-add_subdirectory(plugin/engine)
-add_subdirectory(graph/build/memory)
-add_subdirectory(ge_local_engine)
-add_subdirectory(host_cpu_engine)
-add_subdirectory(executor)
-add_subdirectory(offline)
+if (NOT ENABLE_D AND NOT ENABLE_ACL)
+    add_subdirectory(common)
+    add_subdirectory(plugin/engine)
+    add_subdirectory(graph/build/memory)
+    add_subdirectory(ge_local_engine)
+    add_subdirectory(host_cpu_engine)
+    add_subdirectory(executor)
+    add_subdirectory(offline)
+else()
+    add_subdirectory(common)
+    add_subdirectory(ge_runtime)
+endif ()
 set(PROTO_LIST
     "${METADEF_DIR}/proto/fusion_model.proto"
@@ -28,7 +33,6 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
 protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST})
 protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST})
-############ libge_runner.so ############
 set(TRAIN_SRC_LIST
     "common/formats/format_transfers/datatype_transfer.cc"
     "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc"
@@ -333,72 +337,6 @@ set(TRAIN_SRC_LIST
    "ir_build/atc_ir_common.cc"
 )
-add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS})
-target_compile_definitions(ge_runner PRIVATE
-    PROTOBUF_INLINE_NOT_IN_HEADERS=0
-    DAVINCI_SUPPORT_PROFILING
-    REUSE_MEMORY=1
-    FMK_SUPPORT_DUMP
-    DAVINCI_CLOUD
-    google=ascend_private
-)
-target_compile_options(ge_runner PRIVATE
-    -O2
-)
-target_include_directories(ge_runner PRIVATE
-    ${GE_CODE_DIR}/ge
-    ${GE_CODE_DIR}/ge/analyzer
-    ${GE_CODE_DIR}/inc
-    ${GE_CODE_DIR}/inc/external
-    ${GE_CODE_DIR}/inc/framework
-    ${GE_CODE_DIR}/inc/framework/common
-    ${METADEF_DIR}
-    ${METADEF_DIR}/inc
-    ${METADEF_DIR}/inc/external/graph
-    ${METADEF_DIR}/inc/external
-    ${METADEF_DIR}/inc/graph
-    ${CMAKE_BINARY_DIR}
-    ${CMAKE_BINARY_DIR}/proto/ge
-    #### yellow zone ####
-    ${GE_CODE_DIR}/../inc
-    ${GE_CODE_DIR}/../inc/external
-    ${GE_CODE_DIR}/../inc/cce
-    ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
-    #### blue zone
-    ${ASCEND_DIR}/driver/include
-    ${ASCEND_DIR}/fwkacllib/include
-    ${GE_CODE_DIR}/third_party/fwkacllib/inc
-    ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
-)
-target_link_libraries(ge_runner
-    $<BUILD_INTERFACE:intf_pub>
-    ge_memory
-    adump_server
-    msprofiler
-    -Wl,--no-as-needed
-    graph
-    ge_common
-    ascend_protobuf
-    register
-    c_sec
-    slog
-    mmpa
-    msprof
-    runtime
-    resource
-    error_manager
-    ascend_hal_stub
-    -Wl,--as-needed
-    json
-    -lrt
-    -ldl
-)
-############ libge_compiler.so ############
 set(INFER_SRC_LIST
    "graph/manager/trans_var_data_utils.cc"
    "omm/csa_interact.cc"
@@ -662,6 +600,74 @@ set(INFER_SRC_LIST
    "analyzer/analyzer.cc"
 )
+if (NOT ENABLE_D AND NOT ENABLE_ACL)
+############ libge_runner.so ############
+add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS})
+target_compile_definitions(ge_runner PRIVATE
+    PROTOBUF_INLINE_NOT_IN_HEADERS=0
+    DAVINCI_SUPPORT_PROFILING
+    REUSE_MEMORY=1
+    FMK_SUPPORT_DUMP
+    DAVINCI_CLOUD
+    google=ascend_private
+)
+target_compile_options(ge_runner PRIVATE
+    -O2
+)
+target_include_directories(ge_runner PRIVATE
+    ${GE_CODE_DIR}/ge
+    ${GE_CODE_DIR}/ge/analyzer
+    ${GE_CODE_DIR}/inc
+    ${GE_CODE_DIR}/inc/external
+    ${GE_CODE_DIR}/inc/framework
+    ${GE_CODE_DIR}/inc/framework/common
+    ${METADEF_DIR}
+    ${METADEF_DIR}/inc
+    ${METADEF_DIR}/inc/external/graph
+    ${METADEF_DIR}/inc/external
+    ${METADEF_DIR}/inc/graph
+    ${CMAKE_BINARY_DIR}
+    ${CMAKE_BINARY_DIR}/proto/ge
+    #### yellow zone ####
+    ${GE_CODE_DIR}/../inc
+    ${GE_CODE_DIR}/../inc/external
+    ${GE_CODE_DIR}/../inc/cce
+    ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external
+    #### blue zone
+    ${ASCEND_DIR}/driver/include
+    ${ASCEND_DIR}/fwkacllib/include
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
+)
+target_link_libraries(ge_runner
+    $<BUILD_INTERFACE:intf_pub>
+    ge_memory
+    adump_server
+    msprofiler
+    -Wl,--no-as-needed
+    graph
+    ge_common
+    ascend_protobuf
+    register
+    c_sec
+    slog
+    mmpa
+    msprof
+    runtime
+    resource
+    error_manager
+    ascend_hal_stub
+    -Wl,--as-needed
+    json
+    -lrt
+    -ldl
+)
+############ libge_compiler.so ############
 add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS})
 target_compile_definitions(ge_compiler PRIVATE
@@ -919,3 +925,70 @@ install(FILES
    ${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt OPTIONAL
    DESTINATION ${INSTALL_LIBRARY_DIR}
 )
+elseif (ENABLE_ACL)
+############ libge_compiler.so w/static protobuf ############
+add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS})
+target_compile_definitions(ge_compiler PRIVATE
+    PROTOBUF_INLINE_NOT_IN_HEADERS=0
+    REUSE_MEMORY=1
+    FMK_SUPPORT_DUMP
+    FMK_HOST_INFER
+    COMPILE_OMG_PACKAGE
+    google=ascend_private
+)
+target_compile_options(ge_compiler PRIVATE
+    -O2
+)
+target_include_directories(ge_compiler PRIVATE
+    ${GE_CODE_DIR}/ge
+    ${GE_CODE_DIR}/ge/analyzer
+    ${GE_CODE_DIR}/inc
+    ${GE_CODE_DIR}/inc/external
+    ${GE_CODE_DIR}/inc/framework
+    ${GE_CODE_DIR}/inc/framework/common
+    ${METADEF_DIR}
+    ${METADEF_DIR}/inc
+    ${METADEF_DIR}/inc/external/graph
+    ${METADEF_DIR}/inc/external
+    ${METADEF_DIR}/inc/graph
+    ${CMAKE_BINARY_DIR}
+    ${CMAKE_BINARY_DIR}/proto/ge
+    ${ASCEND_DIR}/driver/include
+    ${ASCEND_DIR}/fwkacllib/include
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
+)
+target_link_libraries(ge_compiler
+    $<BUILD_INTERFACE:intf_pub>
+    ge_memory
+    -Wl,--no-as-needed
+    graph
+    ge_common
+    static_ascend_protobuf
+    register
+    c_sec
+    error_manager
+    slog
+    mmpa
+    runtime_compile
+    resource
+    -Wl,--as-needed
+    json
+    -lrt
+    -ldl
+)
+############ install libge_compiler for MindSpore############
+set(INSTALL_BASE_DIR "")
+set(INSTALL_LIBRARY_DIR lib)
+install(TARGETS ge_compiler OPTIONAL
+    LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
+)
+endif()
@@ -63,6 +63,7 @@ set(SRC_LIST
    "ge/tbe_plugin_manager.cc"
 )
+if (NOT ENABLE_D AND NOT ENABLE_ACL)
 ############ libge_common.so ############
 add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS})
 target_compile_definitions(ge_common PRIVATE
@@ -164,6 +165,57 @@ target_link_libraries(ge_common_static PRIVATE
    -ldl
 )
+else ()
+############ libge_common.so w/static protobuf ############
+add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS})
+target_compile_definitions(ge_common PRIVATE
+    PROTOBUF_INLINE_NOT_IN_HEADERS=0
+    HOST_VISIBILITY
+    FMK_SUPPORT_DUMP
+    OS_CENTOS
+    google=ascend_private
+)
+target_compile_options(ge_common PRIVATE
+    -fvisibility=hidden
+    -O2
+    -Werror
+)
+target_include_directories(ge_common PRIVATE
+    ${GE_CODE_DIR}/ge
+    ${GE_CODE_DIR}/ge/common
+    ${GE_CODE_DIR}/ge/common/op
+    ${GE_CODE_DIR}/inc/external
+    ${GE_CODE_DIR}/inc
+    ${GE_CODE_DIR}/inc/framework
+    ${METADEF_DIR}/inc
+    ${METADEF_DIR}/inc/external
+    ${METADEF_DIR}/inc/external/graph
+    ${METADEF_DIR}/inc/graph
+    ${CMAKE_BINARY_DIR}
+    ${CMAKE_BINARY_DIR}/proto/ge
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
+)
+target_link_libraries(ge_common PRIVATE
+    $<BUILD_INTERFACE:intf_pub>
+    ascend_protobuf_static
+    -Wl,--no-as-needed
+    graph
+    register
+    c_sec
+    error_manager
+    slog
+    mmpa
+    -Wl,--as-needed
+    json
+    -lrt
+    -ldl
+)
+endif ()
 ############ install ############
 set(INSTALL_BASE_DIR "")
 set(INSTALL_LIBRARY_DIR lib)
@@ -27,14 +27,22 @@ target_compile_definitions(ge_runtime PRIVATE
 )
 target_include_directories(ge_runtime PRIVATE
-    ${TOP_DIR}
-    ${TOP_DIR}/inc
-    ${TOP_DIR}/inc/graph
-    ${TOP_DIR}/inc/external
-    ${TOP_DIR}/inc/framework
-    ${TOP_DIR}/inc/framework/common
-    ${TOP_DIR}/inc/framework/ge_runtime
-    ${TOP_DIR}/inc/cce
+    ${CMAKE_CURRENT_LIST_DIR}
+    ${GE_CODE_DIR}
+    ${GE_CODE_DIR}/ge
+    ${GE_CODE_DIR}/inc
+    ${GE_CODE_DIR}/inc/graph
+    ${GE_CODE_DIR}/inc/external
+    ${GE_CODE_DIR}/inc/framework
+    ${GE_CODE_DIR}/inc/framework/common
+    ${GE_CODE_DIR}/inc/framework/ge_runtime
+    ${GE_CODE_DIR}/inc/cce
+    ${GE_CODE_DIR}/third_party/fwkacllib/inc
+    ${METADEF_DIR}
+    ${METADEF_DIR}/inc
+    ${METADEF_DIR}/inc/external/graph
+    ${METADEF_DIR}/inc/external
+    ${METADEF_DIR}/inc/graph
    ${CMAKE_BINARY_DIR}
    ${CMAKE_BINARY_DIR}/proto/ge
 )
@@ -45,6 +53,7 @@ target_link_libraries(ge_runtime PRIVATE
    slog
    runtime
    c_sec
+    graph
    -Wl,--as-needed
    -lrt
    -ldl
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,8 +27,13 @@ class ModelContext {
   ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle,
                rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list,
                const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list)
-      : device_id_(device_id), session_id_(session_id), priority_(priority), rt_model_handle_(rt_model_handle),
-        rt_model_stream_(rt_model_stream), stream_list_(stream_list), label_list_(label_list),
+      : device_id_(device_id),
+        session_id_(session_id),
+        priority_(priority),
+        rt_model_handle_(rt_model_handle),
+        rt_model_stream_(rt_model_stream),
+        stream_list_(stream_list),
+        label_list_(label_list),
         event_list_(event_list) {}
   ~ModelContext() {}
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@
 namespace ge {
 namespace model_runner {
+using RuntimeModelPtr = std::shared_ptr<RuntimeModel>;
 using DavinciModelPtr = std::shared_ptr<DavinciModel>;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -89,5 +89,6 @@ bool Output::SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &dat
                         bool support_mem_share) {
   return true;
 }
 } // namespace model_runner
 } // namespace ge
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@
 namespace ge {
 namespace model_runner {
 class Output {
  public:
   Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model);
@@ -32,8 +33,7 @@ class Output {
   bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share);
-  bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i,
-                  bool support_mem_share);
+  bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, bool support_mem_share);
   // Copy assignment operator and copy constructor are deleted
   Output &operator=(const Output &output) = delete;
@@ -74,8 +74,8 @@ bool RuntimeModel::InitStream(std::shared_ptr<DavinciModel> &davinci_model) {
   for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) {
     rtStream_t stream = nullptr;
     uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end())
-                        ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY)
-                        : (RT_STREAM_PERSISTENT);
+                      ? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY)
+                      : (RT_STREAM_PERSISTENT);
     rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag);
     if (rt_ret != RT_ERROR_NONE) {
@@ -115,23 +115,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) {
   return true;
 }
-bool RuntimeModel::InitLabel(uint32_t batch_num) {
-  GELOGI("batch number:%u.", batch_num);
-  for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) {
-    rtLabel_t rt_lLabel = nullptr;
-    rtError_t rt_ret = rtLabelCreate(&rt_lLabel);
-    if (rt_ret != RT_ERROR_NONE) {
-      GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret);
-      return false;
+bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) {
+  GELOGI("batch number:%u.", davinci_model->GetBatchNum());
+  label_list_.resize(davinci_model->GetBatchNum());
+  for (auto &task_info : davinci_model->GetTaskInfoList()) {
+    if (task_info == nullptr) {
+      GELOGE(PARAM_INVALID, "task_info is null.");
+      continue;
+    }
+    if (task_info->type() != TaskInfoType::LABEL_SET) {
+      continue;
     }
+    auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info);
-    if (rt_lLabel == nullptr) {
-      GELOGE(RT_FAILED, "rtLabel is nullptr!");
+    if (label_set_task_info->stream_id() >= stream_list_.size()) {
+      GELOGE(PARAM_INVALID, "Invalid stream id.");
       return false;
     }
-    label_list_.emplace_back(rt_lLabel);
+    rtLabel_t rt_label = nullptr;
+    rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]);
+    if (rt_ret != RT_ERROR_NONE) {
+      GELOGE(RT_FAILED, "Call rt api rtLabelCreateEx failed, ret: 0x%X", rt_ret);
+      return false;
+    }
+    label_list_[label_set_task_info->label_id()] = rt_label;
   }
   return true;
 }
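
The rewritten InitLabel pre-sizes label_list_ to GetBatchNum() and fills slots by each LABEL_SET task's label_id(), which is also why RtLabelDestory further down now skips null entries. A bounds check along the following lines is hypothetical and not part of this patch, but it illustrates (and would additionally guard) the new indexing scheme:

// Hypothetical guard, not in the patch: reject label ids outside the resized list.
uint32_t label_id = label_set_task_info->label_id();
if (label_id >= label_list_.size()) {
  GELOGE(PARAM_INVALID, "Invalid label id %u, label list size %zu.", label_id, label_list_.size());
  return false;
}
label_list_[label_id] = rt_label;
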
@@ -163,7 +174,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) {
     return false;
   }
-  if (!InitLabel(davinci_model->GetBatchNum())) {
+  if (!InitLabel(davinci_model)) {
     return false;
   }
@@ -281,7 +292,6 @@ bool RuntimeModel::DistributeTask() {
       GELOGE(FAILED, "DistributeTask failed");
       return false;
     }
   return true;
 }
@@ -293,10 +303,14 @@ bool RuntimeModel::Run() {
     return false;
   }
-  GELOGI("Run rtModelExecute success");
+  GELOGI("Run rtModelExecute success, ret = 0x%X", ret);
   ret = rtStreamSynchronize(rt_model_stream_);
   if (ret != RT_ERROR_NONE) {
+    if (ret == RT_ERROR_END_OF_SEQUENCE) {
+      GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret);
+      return true;
+    }
     GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret);
     return false;
   }
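
The added branch treats RT_ERROR_END_OF_SEQUENCE from rtStreamSynchronize as a normal end of the run rather than a failure. A minimal sketch of the same pattern; the wrapper name is made up for illustration:

// Sketch only; SyncModelStream is a hypothetical helper mirroring the logic above.
bool SyncModelStream(rtStream_t stream) {
  rtError_t ret = rtStreamSynchronize(stream);
  if (ret == RT_ERROR_END_OF_SEQUENCE) {
    // Treated as a successful run, matching the new branch in RuntimeModel::Run().
    return true;
  }
  if (ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret);
    return false;
  }
  return true;
}
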
@@ -330,6 +344,9 @@ void RuntimeModel::RtStreamDestory() noexcept {
 void RuntimeModel::RtLabelDestory() noexcept {
   for (size_t i = 0; i < label_list_.size(); i++) {
+    if (label_list_[i] == nullptr) {
+      continue;
+    }
     if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) {
       GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i);
       return;
@@ -471,11 +488,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
     /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero
     /// and that of unknown shape is zero too.
     /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not.
-    int64_t elem_num = constant->weight_tensors[0].GetShapeSize();
-    if (elem_num == 0 && constant->weight_tensors[0].size == 0) {
-      elem_num = 1;
-    }
+    int64_t elem_num =
+        (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize();
     if (constant->weight_data.size() < sizeof(uint64_t)) {
       GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)");
       return false;
@@ -40,13 +40,11 @@ class RuntimeModel {
   const std::vector<uint32_t> &GetTaskIdList() const;
   const std::vector<uint32_t> &GetStreamIdList() const;
   const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
-  const rtModel_t GetModelHandle() const { return rt_model_handle_; }
+  rtModel_t GetModelHandle() const { return rt_model_handle_; }
   bool Run();
   bool CopyInputData(const InputData &input_data);
-  bool GetInputOutputDescInfo(bool zero_copy,
-                              std::vector<InputOutputDescInfo> *input_desc,
-                              std::vector<InputOutputDescInfo> *output_desc,
-                              std::vector<uint32_t> *input_format,
+  bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
+                              std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format,
                               std::vector<uint32_t> *output_format);
  private:
@@ -55,7 +53,7 @@ class RuntimeModel {
   bool LoadTask();
   bool InitStream(std::shared_ptr<DavinciModel> &davinci_model);
   bool InitEvent(uint32_t event_num);
-  bool InitLabel(uint32_t batch_num);
+  bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model);
   bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model);
   bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model);
   bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model);
@@ -87,6 +85,7 @@ class RuntimeModel {
   std::vector<uint32_t> stream_id_list_{};
   std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_;
 };
 } // namespace model_runner
 } // namespace ge
@@ -26,6 +26,7 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai
       task_info_(task_info),
       stream_(nullptr),
       args_(nullptr),
+      ext_info_(nullptr),
       input_output_addr_(nullptr) {
   if (task_info_ == nullptr) {
     GELOGW("task_info_ is null!");
@@ -41,7 +42,10 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai
   }
 }
-AicpuTask::~AicpuTask() { ReleaseRtMem(&args_); }
+AicpuTask::~AicpuTask() {
+  ReleaseRtMem(&args_);
+  ReleaseRtMem(&ext_info_);
+}
 bool AicpuTask::Distribute() {
   GELOGI("InitAicpuTask start.");
@@ -51,10 +55,37 @@ bool AicpuTask::Distribute() {
   auto io_addrs_num = static_cast<uint32_t>(io_addrs.size());
   auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *));
   constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead);
-  uint32_t node_def_addr_offset = io_addr_offset + io_addrs_size;
-  uint32_t args_size =
-      sizeof(aicpu::AicpuParamHead) + io_addrs_size + static_cast<uint32_t>(task_info_->node_def().size());
-  aicpu::AicpuParamHead aicpu_param_head = {args_size, io_addrs_num};
+  uint32_t node_def_len_offset = io_addr_offset + io_addrs_size;
+  uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t);
+  uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size +
+                       static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t);
+  aicpu::AicpuParamHead aicpu_param_head;
+  aicpu_param_head.length = args_size;
+  aicpu_param_head.ioAddrNum = io_addrs_num;
+  auto ext_info = task_info_->ext_info();
+  uint32_t ext_size = ext_info.size();
+  if (ext_info.empty()) {
+    aicpu_param_head.extInfoLength = 0;
+    aicpu_param_head.extInfoAddr = 0;
+  } else {
+    rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM);
+    if (flag != RT_ERROR_NONE) {
+      GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag);
+      return false;
+    }
+    flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size,
+                    RT_MEMCPY_HOST_TO_DEVICE);
+    if (flag != RT_ERROR_NONE) {
+      GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag);
+      return false;
+    }
+    GELOGI("ext info size: %u", ext_size);
+    aicpu_param_head.extInfoLength = ext_size;
+    aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_);
+  }
   // Malloc device memory for args
   rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
@@ -80,6 +111,17 @@ bool AicpuTask::Distribute() {
       return false;
     }
   }
+  // Memcpy node def
+  auto size = task_info_->node_def().size();
+  rt_ret =
+      rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t),
+               reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE);
+  if (rt_ret != RT_ERROR_NONE) {
+    GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret);
+    return false;
+  }
   // Memcpy node def
   rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset),
                     task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()),
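
With the extra rtMemcpy above, the device-side args buffer now carries the node_def length in front of the node_def bytes. A sketch of the resulting layout and offset arithmetic, assembled from the values in this hunk; ArgsLayout and ComputeArgsLayout are illustrative names, not part of the patch:

#include <cstddef>
#include <cstdint>

// Buffer layout written by AicpuTask::Distribute() after this change:
//   [aicpu::AicpuParamHead][io_addrs_num * void*][uint32_t node_def_len][node_def bytes]
struct ArgsLayout {
  uint32_t io_addr_offset;        // where the io address table starts
  uint32_t node_def_len_offset;   // where the 4-byte node_def length is stored
  uint32_t node_def_addr_offset;  // where the serialized node_def starts
  uint32_t args_size;             // total size passed to rtMalloc
};

inline ArgsLayout ComputeArgsLayout(uint32_t head_size, uint32_t io_addrs_num, size_t node_def_size) {
  ArgsLayout layout{};
  const uint32_t io_addrs_size = io_addrs_num * static_cast<uint32_t>(sizeof(void *));
  layout.io_addr_offset = head_size;
  layout.node_def_len_offset = layout.io_addr_offset + io_addrs_size;
  layout.node_def_addr_offset = layout.node_def_len_offset + static_cast<uint32_t>(sizeof(uint32_t));
  layout.args_size = layout.node_def_addr_offset + static_cast<uint32_t>(node_def_size);
  return layout;
}
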
@@ -41,6 +41,7 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
   std::shared_ptr<AicpuTaskInfo> task_info_;
   void *stream_;
   void *args_;
+  void *ext_info_;
   void *input_output_addr_;
 };
 } // namespace model_runner
@@ -103,9 +103,9 @@ bool CceTask::Distribute() {
   // Modify flowtable addr in args
   auto args = const_cast<uint8_t *>(task_info_->args().data());
   auto task_offset = reinterpret_cast<uint16_t *>(const_cast<uint8_t *>(task_info_->args_offset().data()));
   if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) {
-    GELOGE(FAILED,
-           "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu",
+    GELOGE(FAILED, "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu",
            static_cast<uint32_t>(task_offset[0]), sizeof(uint64_t), task_info_->args().size());
     return false;
   }
@@ -136,8 +136,7 @@ bool CceTask::Distribute() {
     return false;
   }
-  rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(),
-                    task_info_->sm_desc().data(),
+  rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), task_info_->sm_desc().data(),
                     task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE);
   if (rt_ret != RT_ERROR_NONE) {
     GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
@@ -146,12 +145,8 @@ bool CceTask::Distribute() {
   }
   // Kernel launch
-  rt_ret = rtKernelLaunch(stub_func_,
-                          task_info_->block_dim(),
-                          args_,
-                          task_info_->args_size(),
-                          static_cast<rtSmDesc_t *>(sm_desc_),
-                          stream_);
+  rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, task_info_->args_size(),
+                          static_cast<rtSmDesc_t *>(sm_desc_), stream_);
   if (rt_ret != RT_ERROR_NONE) {
     GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
     return false;
@@ -33,7 +33,7 @@ class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> {
  private:
   std::shared_ptr<EventRecordTaskInfo> task_info_;
   rtStream_t stream_;
-  rtEvent_t event_;
+  rtEvent_t event_;
 };
 } // namespace model_runner
 } // namespace ge
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> {
  private:
   std::shared_ptr<EventWaitTaskInfo> task_info_;
   rtStream_t stream_;
-  rtEvent_t event_;
+  rtEvent_t event_;
 };
 } // namespace model_runner
 } // namespace ge
@@ -115,7 +115,6 @@ bool HcclTask::Distribute() {
   rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM);
   if (rt_ret != RT_ERROR_NONE) {
     GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
-    (void)rtStreamDestroy(stream);
     return false;
   }
@@ -129,8 +128,6 @@ bool HcclTask::Distribute() {
   ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL);
   ge_task.stream = stream_;
-  GETaskKernelHcclInfo kernel_hccl_info;
-  ge_task.kernelHcclInfo.emplace_back(kernel_hccl_info);
   ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type();
   ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr();
   ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr();
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ge_runtime/task/label_goto_task.h"
+#include "ge_runtime/task/task_factory.h"
+namespace ge {
+namespace model_runner {
+LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
+    : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
+      task_info_(task_info),
+      stream_(nullptr),
+      label_(nullptr) {
+  if (task_info_ == nullptr) {
+    GELOGW("task_info_ is null!");
+    return;
+  }
+  auto stream_list = model_context.stream_list();
+  auto label_list = model_context.label_list();
+  uint32_t stream_id = task_info->stream_id();
+  uint32_t label_id = task_info->label_id();
+  GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
+  GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id);
+  if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
+    GELOGW("Stream/Label id invalid.");
+    return;
+  }
+  stream_ = stream_list[stream_id];
+  label_ = label_list[label_id];
+}
+LabelGotoTask::~LabelGotoTask() {}
+bool LabelGotoTask::Distribute() {
+  GELOGI("LabelGotoTask Distribute start.");
+  if (stream_ == nullptr) {
+    GELOGE(PARAM_INVALID, "stream is null!");
+    return false;
+  }
+  if (label_ == nullptr) {
+    GELOGE(PARAM_INVALID, "label is null!");
+    return false;
+  }
+  rtError_t rt_ret = rtLabelGotoEx(label_, stream_);
+  if (rt_ret != RT_ERROR_NONE) {
+    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
+    return false;
+  }
+  GELOGI("DistributeTask end.");
+  return true;
+}
+REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);
+} // namespace model_runner
+} // namespace ge
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
+#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
+#include <memory>
+#include "ge_runtime/task/task.h"
+namespace ge {
+namespace model_runner {
+class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
+ public:
+  LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info);
+  ~LabelGotoTask() override;
+  bool Distribute() override;
+ private:
+  std::shared_ptr<LabelGotoTaskInfo> task_info_;
+  void *stream_;
+  void *label_;
+};
+} // namespace model_runner
+} // namespace ge
+#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ge_runtime/task/label_set_task.h"
+#include "ge_runtime/task/task_factory.h"
+namespace ge {
+namespace model_runner {
+LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info)
+    : TaskRepeater<LabelSetTaskInfo>(model_context, task_info),
+      task_info_(task_info),
+      stream_(nullptr),
+      label_(nullptr) {
+  if (task_info_ == nullptr) {
+    GELOGW("task_info_ is null!");
+    return;
+  }
+  auto stream_list = model_context.stream_list();
+  auto label_list = model_context.label_list();
+  uint32_t stream_id = task_info->stream_id();
+  uint32_t label_id = task_info->label_id();
+  GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
+  GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id);
+  if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
+    GELOGW("Stream/Label id invalid.");
+    return;
+  }
+  stream_ = stream_list[stream_id];
+  label_ = label_list[label_id];
+}
+LabelSetTask::~LabelSetTask() {}
+bool LabelSetTask::Distribute() {
+  GELOGI("LabelSetTask Distribute start.");
+  if (stream_ == nullptr) {
+    GELOGE(PARAM_INVALID, "stream is null!");
+    return false;
+  }
+  if (label_ == nullptr) {
+    GELOGE(PARAM_INVALID, "label is null!");
+    return false;
+  }
+  rtError_t rt_ret = rtLabelSet(label_, stream_);
+  if (rt_ret != RT_ERROR_NONE) {
+    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
+    return false;
+  }
+  GELOGI("DistributeTask end.");
+  return true;
+}
+REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo);
+} // namespace model_runner
+} // namespace ge
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
+#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
+#include <memory>
+#include "ge_runtime/task/task.h"
+namespace ge {
+namespace model_runner {
+class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {
+ public:
+  LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);
+  ~LabelSetTask() override;
+  bool Distribute() override;
+ private:
+  std::shared_ptr<LabelSetTaskInfo> task_info_;
+  void *stream_;
+  void *label_;
+};
+} // namespace model_runner
+} // namespace ge
+#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
@@ -0,0 +1,131 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "ge_runtime/task/label_switch_task.h"
+#include "ge_runtime/task/task_factory.h"
+namespace ge {
+namespace model_runner {
+LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
+                                 const std::shared_ptr<LabelSwitchTaskInfo> &task_info)
+    : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
+      task_info_(task_info),
+      stream_(nullptr),
+      all_label_resource_(),
+      label_info_(nullptr) {
+  if (task_info_ == nullptr) {
+    GELOGW("task_info_ is null!");
+    return;
+  }
+  all_label_resource_ = model_context.label_list();
+  auto stream_list = model_context.stream_list();
+  uint32_t stream_id = task_info->stream_id();
+  GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
+  if (stream_id >= stream_list.size()) {
+    GELOGW("Stream id invalid.");
+    return;
+  }
+  stream_ = stream_list[stream_id];
+}
+LabelSwitchTask::~LabelSwitchTask() {
+  if (label_info_ != nullptr) {
+    rtError_t rt_ret = rtFree(label_info_);
+    if (rt_ret != RT_ERROR_NONE) {
+      GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret);
+    }
+    label_info_ = nullptr;
+  }
+}
+bool LabelSwitchTask::Distribute() {
+  GELOGI("LabelSwitchTask Distribute start.");
+  if (!CheckParamValid()) {
+    return false;
+  }
+  const std::vector<uint32_t> &label_index_list = task_info_->label_list();
+  std::vector<void *> label_list(task_info_->label_size(), nullptr);
+  for (size_t i = 0; i < task_info_->label_size(); ++i) {
+    uint32_t label_index = label_index_list[i];
+    if (label_index >= all_label_resource_.size()) {
+      GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index,
+             all_label_resource_.size());
+      return false;
+    }
+    label_list[i] = all_label_resource_[label_index];
GELOGI("Case %zu: label id %zu.", i, label_index); | |||||
+  }
+  uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size();
+  rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
+  if (rt_ret != RT_ERROR_NONE) {
+    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
+    return false;
+  }
+  rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
+  if (rt_ret != RT_ERROR_NONE) {
+    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
+    return false;
+  }
+  rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_);
+  if (rt_ret != RT_ERROR_NONE) {
+    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
+    return false;
+  }
+  GELOGI("DistributeTask end.");
+  return true;
+}
+bool LabelSwitchTask::CheckParamValid() {
+  if (stream_ == nullptr) {
+    GELOGE(PARAM_INVALID, "stream is null!");
+    return false;
+  }
+  if (task_info_->label_list().empty()) {
+    GELOGE(PARAM_INVALID, "label_list is empty.");
+    return false;
+  }
+  if (task_info_->label_size() != task_info_->label_list().size()) {
+    GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(),
+           task_info_->label_size());
+    return false;
+  }
+  if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) {
+    GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size());
+    return false;
+  }
+  if (label_info_ != nullptr) {
+    GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
+    return false;
+  }
+  return true;
+}
+REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);
+} // namespace model_runner
+} // namespace ge
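
The label_size() guard in CheckParamValid() exists so that the uint32_t multiplication sizeof(rtLabelDevInfo) * label_size() used for label_info_size cannot wrap around. The same condition in isolation; the helper name is illustrative, not part of the patch:

#include <cstdint>
#include <limits>

// Conservative check: true only when count * elem_size is guaranteed to fit in uint32_t.
// Mirrors the rejection of label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo) above.
constexpr bool FitsInUint32(uint32_t count, uint32_t elem_size) {
  return elem_size == 0 || count < std::numeric_limits<uint32_t>::max() / elem_size;
}

static_assert(FitsInUint32(1024, sizeof(uint64_t)), "1024 entries of 8 bytes fit in uint32_t");
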
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
+#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
+#include <memory>
+#include "ge_runtime/task/task.h"
+namespace ge {
+namespace model_runner {
+class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {
+ public:
+  LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);
+  ~LabelSwitchTask() override;
+  bool Distribute() override;
+ private:
+  bool CheckParamValid();
+  std::shared_ptr<LabelSwitchTaskInfo> task_info_;
+  void *stream_;
+  std::vector<void *> all_label_resource_;
+  void *label_info_;
+};
+} // namespace model_runner
+} // namespace ge
+#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
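
Taken together, the three new task types drive the runtime's label API: RuntimeModel::InitLabel creates one label per LABEL_SET task with rtLabelCreateEx, LabelSetTask marks the label's position on its stream, LabelGotoTask jumps to a label unconditionally, and LabelSwitchTask branches through a device-side label table. A rough end-to-end sketch assembled only from the calls used in this patch; names other than the rt* APIs are made up, error handling is omitted, and the real sequencing is handled by RuntimeModel and the task classes:

#include <vector>

void SketchLabelFlow(rtStream_t stream, void *branch_index_device_addr) {
  // Labels are created up front and bound to the stream (as in RuntimeModel::InitLabel).
  rtLabel_t label_a = nullptr;
  rtLabel_t label_b = nullptr;
  (void)rtLabelCreateEx(&label_a, stream);
  (void)rtLabelCreateEx(&label_b, stream);

  // LabelSwitchTask: copy the label table to device memory, then branch by index.
  std::vector<void *> labels = {label_a, label_b};
  uint32_t table_size = static_cast<uint32_t>(sizeof(rtLabelDevInfo) * labels.size());
  void *label_info = nullptr;
  (void)rtMalloc(&label_info, table_size, RT_MEMORY_HBM);
  (void)rtLabelListCpy(labels.data(), labels.size(), label_info, table_size);
  (void)rtLabelSwitchByIndex(branch_index_device_addr, labels.size(), label_info, stream);

  // LabelGotoTask: unconditional jump to a label.
  (void)rtLabelGotoEx(label_b, stream);

  // LabelSetTask: mark where each label lands on the stream.
  (void)rtLabelSet(label_a, stream);
  (void)rtLabelSet(label_b, stream);
}
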
@@ -37,6 +37,7 @@ class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> {
   void *stream_;
   std::vector<rtStream_t> stream_list_;
 };
 } // namespace model_runner
 } // namespace ge
 #endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_
@@ -42,7 +42,7 @@ class Task {
 template <class T>
 class TaskRepeater : public Task {
-  static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!"); /*lint !e30*/
+  static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!");
  public:
  TaskRepeater(const ModelContext &model_context, std::shared_ptr<T> task_info) {}
@@ -81,6 +81,7 @@ class TaskFactory {
     std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \
     return std::make_shared<task_clazz>(model_context, concrete_task_info); \
   });
 } // namespace model_runner
 } // namespace ge
 #endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_
@@ -27,10 +27,10 @@ namespace ge { | |||||
namespace model_runner { | namespace model_runner { | ||||
class DavinciModel { | class DavinciModel { | ||||
public: | public: | ||||
DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, /*lint !e151*/ | |||||
DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, | |||||
const std::vector<std::shared_ptr<OpInfo>> &data_info_list, | const std::vector<std::shared_ptr<OpInfo>> &data_info_list, | ||||
const std::vector<std::shared_ptr<OpInfo>> &output_info_list, /*lint !e151*/ | |||||
const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, /*lint !e1049*/ | |||||
const std::vector<std::shared_ptr<OpInfo>> &output_info_list, | |||||
const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, | |||||
const std::vector<model_runner::OpInfoPtr> &variable_info_list, | const std::vector<model_runner::OpInfoPtr> &variable_info_list, | ||||
const std::vector<uint32_t> &wait_active_stream_list, | const std::vector<uint32_t> &wait_active_stream_list, | ||||
const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, | const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, | ||||
@@ -68,12 +68,12 @@ class DavinciModel { | |||||
uint32_t GetBatchNum() const { return batch_num_; } | uint32_t GetBatchNum() const { return batch_num_; } | ||||
uint32_t GetEventNum() const { return event_num_; } | uint32_t GetEventNum() const { return event_num_; } | ||||
const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } /*lint !e1413*/ | |||||
const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } /*lint !e1413*/ | |||||
const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } | |||||
const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } | |||||
int32_t GetPriority() const { return priority_; } | int32_t GetPriority() const { return priority_; } | ||||
const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } /*lint !e151*/ | |||||
const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } | |||||
const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; } | const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; } | ||||
const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; } | const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; } | ||||
const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return constant_info_list_; } | const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return constant_info_list_; } | ||||
@@ -81,7 +81,7 @@ class DavinciModel { | |||||
private: | private: | ||||
std::vector<std::shared_ptr<TaskInfo>> task_info_list_; | std::vector<std::shared_ptr<TaskInfo>> task_info_list_; | ||||
std::vector<std::shared_ptr<OpInfo>> data_info_list_; /*lint !e151*/ | |||||
std::vector<std::shared_ptr<OpInfo>> data_info_list_; | |||||
std::vector<std::shared_ptr<OpInfo>> output_info_list_; | std::vector<std::shared_ptr<OpInfo>> output_info_list_; | ||||
std::vector<std::shared_ptr<OpInfo>> constant_info_list_; | std::vector<std::shared_ptr<OpInfo>> constant_info_list_; | ||||
std::vector<model_runner::OpInfoPtr> variable_info_list_; | std::vector<model_runner::OpInfoPtr> variable_info_list_; | ||||
@@ -52,11 +52,8 @@ class ModelRunner { | |||||
bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | ||||
bool GetInputOutputDescInfo(uint32_t model_id, | |||||
bool zero_copy, | |||||
std::vector<InputOutputDescInfo> *input_desc, | |||||
std::vector<InputOutputDescInfo> *output_desc, | |||||
std::vector<uint32_t> *input_format, | |||||
bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||||
std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||||
std::vector<uint32_t> *output_format); | std::vector<uint32_t> *output_format); | ||||
private: | private: | ||||
@@ -161,12 +161,13 @@ class TbeTaskInfo : public TaskInfo { | |||||
class AicpuTaskInfo : public TaskInfo { | class AicpuTaskInfo : public TaskInfo { | ||||
public: | public: | ||||
AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | ||||
const std::string &node_def, const std::vector<void *> &input_data_addrs, | |||||
const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs, | |||||
const std::vector<void *> &output_data_addrs, bool dump_flag) | const std::vector<void *> &output_data_addrs, bool dump_flag) | ||||
: TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | ||||
so_name_(so_name), | so_name_(so_name), | ||||
kernel_name_(kernel_name), | kernel_name_(kernel_name), | ||||
node_def_(node_def), | node_def_(node_def), | ||||
ext_info_(ext_info), | |||||
input_data_addrs_(input_data_addrs), | input_data_addrs_(input_data_addrs), | ||||
output_data_addrs_(output_data_addrs) {} | output_data_addrs_(output_data_addrs) {} | ||||
~AicpuTaskInfo() override {} | ~AicpuTaskInfo() override {} | ||||
@@ -176,11 +177,13 @@ class AicpuTaskInfo : public TaskInfo { | |||||
const std::string &node_def() const { return node_def_; } | const std::string &node_def() const { return node_def_; } | ||||
const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | ||||
const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | ||||
const std::string &ext_info() const { return ext_info_; } | |||||
private: | private: | ||||
std::string so_name_; | std::string so_name_; | ||||
std::string kernel_name_; | std::string kernel_name_; | ||||
std::string node_def_; | std::string node_def_; | ||||
std::string ext_info_; | |||||
std::vector<void *> input_data_addrs_; | std::vector<void *> input_data_addrs_; | ||||
std::vector<void *> output_data_addrs_; | std::vector<void *> output_data_addrs_; | ||||
}; | }; | ||||
@@ -293,19 +296,19 @@ class HcclTaskInfo : public TaskInfo { | |||||
hcom_distribute_task_(hcom_distribute_task) {} | hcom_distribute_task_(hcom_distribute_task) {} | ||||
~HcclTaskInfo() override {} | ~HcclTaskInfo() override {} | ||||
const std::string &hccl_type() const { return hccl_type_; } /*lint !e1413*/ | |||||
const std::string &hccl_type() const { return hccl_type_; } | |||||
void *input_data_addr() const { return input_data_addr_; } | void *input_data_addr() const { return input_data_addr_; } | ||||
void *output_data_addr() const { return output_data_addr_; } | void *output_data_addr() const { return output_data_addr_; } | ||||
void *workspace_addr() const { return workspace_addr_; } | void *workspace_addr() const { return workspace_addr_; } | ||||
int64_t workspace_size() const { return workspace_size_; } | int64_t workspace_size() const { return workspace_size_; } | ||||
int64_t hccl_stream_num() const { return hccl_stream_num_; } | int64_t hccl_stream_num() const { return hccl_stream_num_; } | ||||
const std::vector<uint8_t> &private_def() const { return private_def_; } /*lint !e1413*/ | |||||
const std::vector<uint8_t> &private_def() const { return private_def_; } | |||||
void *ops_kernel_store() const { return ops_kernel_store_; } | void *ops_kernel_store() const { return ops_kernel_store_; } | ||||
int32_t count() const { return count_; } | int32_t count() const { return count_; } | ||||
int64_t root_id() const { return root_id_; } | int64_t root_id() const { return root_id_; } | ||||
int64_t op_type() const { return op_type_; } | int64_t op_type() const { return op_type_; } | ||||
int64_t data_type() const { return data_type_; } | int64_t data_type() const { return data_type_; } | ||||
const std::string group() const { return group_; } | |||||
const std::string &group() const { return group_; } | |||||
std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | ||||
std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | ||||
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | ||||