@@ -8,6 +8,19 @@ if (NOT BUILD_PATH) | |||
set(BUILD_PATH "${CMAKE_SOURCE_DIR}/build") | |||
endif() | |||
if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | |||
set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) | |||
else() | |||
set(ASCEND_DIR /usr/local/Ascend) | |||
endif() | |||
set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | |||
set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | |||
set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) | |||
set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) | |||
set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | |||
set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) | |||
set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) | |||
option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) | |||
if (ENABLE_OPEN_SRC) | |||
@@ -41,7 +54,7 @@ if (ENABLE_OPEN_SRC) | |||
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | |||
endif() | |||
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | |||
set(STATIC_ACL_LIB ${GE_LIB_PATH}) | |||
set(STATIC_ACL_LIB ${GE_LIB_PATH}) | |||
find_module(slog libslog.so ${GE_LIB_PATH}) | |||
find_module(mmpa libmmpa.so ${GE_LIB_PATH}) | |||
find_module(msprof libmsprof.so ${GE_LIB_PATH}) | |||
@@ -56,18 +69,6 @@ if (ENABLE_OPEN_SRC) | |||
find_module(msprofiler libmsprofiler.a ${GE_LIB_PATH}) | |||
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | |||
else() | |||
if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | |||
set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH}) | |||
else() | |||
set(ASCEND_DIR /usr/local/Ascend) | |||
endif() | |||
set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64) | |||
set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common) | |||
set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share) | |||
set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) | |||
set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64) | |||
set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64) | |||
set(STATIC_ACL_LIB ${ASCEND_ACL_DIR}) | |||
find_module(slog libslog.so ${ASCEND_ATC_DIR}) | |||
find_module(mmpa libmmpa.so ${ASCEND_ATC_DIR}) | |||
if(PLATFORM STREQUAL "train") | |||
@@ -127,6 +128,36 @@ if (ENABLE_OPEN_SRC) | |||
add_subdirectory(parser) | |||
#add_subdirectory(metadef/graph) | |||
#add_subdirectory(metadef/register) | |||
elseif (ENABLE_D OR ENABLE_ACL) | |||
# compiling with MindSpore | |||
include(cmake/external_libs/protobuf_static.cmake) | |||
include(cmake/external_libs/protoc.cmake) | |||
include(cmake/external_libs/securec.cmake) | |||
include(cmake/external_libs/json.cmake) | |||
include(cmake/FindModule.cmake) | |||
include(cmake/intf_pub_linux.cmake) | |||
# common modules | |||
find_module(slog libslog.so ${ASCEND_DRIVER_COMMON_DIR}) | |||
find_module(mmpa libmmpa.so ${ASCEND_DRIVER_COMMON_DIR}) | |||
if (ENABLE_D) | |||
# training | |||
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | |||
find_module(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) | |||
find_module(register libregister.so ${ASCEND_RUNTIME_DIR}) | |||
find_module(resource libresource.so ${ASCEND_RUNTIME_DIR}) | |||
elseif(ENABLE_ACL) | |||
# inference | |||
find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) | |||
find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | |||
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) | |||
find_module(resource libresource.so ${ASCEND_ATC_DIR}) | |||
find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | |||
endif () | |||
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | |||
add_subdirectory(metadef) | |||
else() | |||
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/../metadef) | |||
set(PARSER_DIR ${CMAKE_CURRENT_LIST_DIR}/../parser) | |||
@@ -48,5 +48,8 @@ set_target_properties(ascend_protobuf_static_lib PROPERTIES | |||
add_library(ascend_protobuf_static INTERFACE) | |||
target_include_directories(ascend_protobuf_static INTERFACE ${PROTOBUF_STATIC_PKG_DIR}/include) | |||
target_link_libraries(ascend_protobuf_static INTERFACE ascend_protobuf_static_lib) | |||
if (ENABLE_D OR ENABLE_ACL) | |||
include_directories(${PROTOBUF_STATIC_PKG_DIR}/include) | |||
endif () | |||
add_dependencies(ascend_protobuf_static protobuf_static_build) |
@@ -1,10 +1,15 @@ | |||
add_subdirectory(common) | |||
add_subdirectory(plugin/engine) | |||
add_subdirectory(graph/build/memory) | |||
add_subdirectory(ge_local_engine) | |||
add_subdirectory(host_cpu_engine) | |||
add_subdirectory(executor) | |||
add_subdirectory(offline) | |||
if (NOT ENABLE_D AND NOT ENABLE_ACL) | |||
add_subdirectory(common) | |||
add_subdirectory(plugin/engine) | |||
add_subdirectory(graph/build/memory) | |||
add_subdirectory(ge_local_engine) | |||
add_subdirectory(host_cpu_engine) | |||
add_subdirectory(executor) | |||
add_subdirectory(offline) | |||
else() | |||
add_subdirectory(common) | |||
add_subdirectory(ge_runtime) | |||
endif () | |||
set(PROTO_LIST | |||
"${METADEF_DIR}/proto/fusion_model.proto" | |||
@@ -28,7 +33,6 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | |||
protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) | |||
############ libge_runner.so ############ | |||
set(TRAIN_SRC_LIST | |||
"common/formats/format_transfers/datatype_transfer.cc" | |||
"common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" | |||
@@ -333,72 +337,6 @@ set(TRAIN_SRC_LIST | |||
"ir_build/atc_ir_common.cc" | |||
) | |||
add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) | |||
target_compile_definitions(ge_runner PRIVATE | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
DAVINCI_SUPPORT_PROFILING | |||
REUSE_MEMORY=1 | |||
FMK_SUPPORT_DUMP | |||
DAVINCI_CLOUD | |||
google=ascend_private | |||
) | |||
target_compile_options(ge_runner PRIVATE | |||
-O2 | |||
) | |||
target_include_directories(ge_runner PRIVATE | |||
${GE_CODE_DIR}/ge | |||
${GE_CODE_DIR}/ge/analyzer | |||
${GE_CODE_DIR}/inc | |||
${GE_CODE_DIR}/inc/external | |||
${GE_CODE_DIR}/inc/framework | |||
${GE_CODE_DIR}/inc/framework/common | |||
${METADEF_DIR} | |||
${METADEF_DIR}/inc | |||
${METADEF_DIR}/inc/external/graph | |||
${METADEF_DIR}/inc/external | |||
${METADEF_DIR}/inc/graph | |||
${CMAKE_BINARY_DIR} | |||
${CMAKE_BINARY_DIR}/proto/ge | |||
#### yellow zone #### | |||
${GE_CODE_DIR}/../inc | |||
${GE_CODE_DIR}/../inc/external | |||
${GE_CODE_DIR}/../inc/cce | |||
${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | |||
#### blue zone | |||
${ASCEND_DIR}/driver/include | |||
${ASCEND_DIR}/fwkacllib/include | |||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||
) | |||
target_link_libraries(ge_runner | |||
$<BUILD_INTERFACE:intf_pub> | |||
ge_memory | |||
adump_server | |||
msprofiler | |||
-Wl,--no-as-needed | |||
graph | |||
ge_common | |||
ascend_protobuf | |||
register | |||
c_sec | |||
slog | |||
mmpa | |||
msprof | |||
runtime | |||
resource | |||
error_manager | |||
ascend_hal_stub | |||
-Wl,--as-needed | |||
json | |||
-lrt | |||
-ldl | |||
) | |||
############ libge_compiler.so ############ | |||
set(INFER_SRC_LIST | |||
"graph/manager/trans_var_data_utils.cc" | |||
"omm/csa_interact.cc" | |||
@@ -662,6 +600,74 @@ set(INFER_SRC_LIST | |||
"analyzer/analyzer.cc" | |||
) | |||
if (NOT ENABLE_D AND NOT ENABLE_ACL) | |||
############ libge_runner.so ############ | |||
add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS}) | |||
target_compile_definitions(ge_runner PRIVATE | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
DAVINCI_SUPPORT_PROFILING | |||
REUSE_MEMORY=1 | |||
FMK_SUPPORT_DUMP | |||
DAVINCI_CLOUD | |||
google=ascend_private | |||
) | |||
target_compile_options(ge_runner PRIVATE | |||
-O2 | |||
) | |||
target_include_directories(ge_runner PRIVATE | |||
${GE_CODE_DIR}/ge | |||
${GE_CODE_DIR}/ge/analyzer | |||
${GE_CODE_DIR}/inc | |||
${GE_CODE_DIR}/inc/external | |||
${GE_CODE_DIR}/inc/framework | |||
${GE_CODE_DIR}/inc/framework/common | |||
${METADEF_DIR} | |||
${METADEF_DIR}/inc | |||
${METADEF_DIR}/inc/external/graph | |||
${METADEF_DIR}/inc/external | |||
${METADEF_DIR}/inc/graph | |||
${CMAKE_BINARY_DIR} | |||
${CMAKE_BINARY_DIR}/proto/ge | |||
#### yellow zone #### | |||
${GE_CODE_DIR}/../inc | |||
${GE_CODE_DIR}/../inc/external | |||
${GE_CODE_DIR}/../inc/cce | |||
${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | |||
#### blue zone | |||
${ASCEND_DIR}/driver/include | |||
${ASCEND_DIR}/fwkacllib/include | |||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||
) | |||
target_link_libraries(ge_runner | |||
$<BUILD_INTERFACE:intf_pub> | |||
ge_memory | |||
adump_server | |||
msprofiler | |||
-Wl,--no-as-needed | |||
graph | |||
ge_common | |||
ascend_protobuf | |||
register | |||
c_sec | |||
slog | |||
mmpa | |||
msprof | |||
runtime | |||
resource | |||
error_manager | |||
ascend_hal_stub | |||
-Wl,--as-needed | |||
json | |||
-lrt | |||
-ldl | |||
) | |||
############ libge_compiler.so ############ | |||
add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) | |||
target_compile_definitions(ge_compiler PRIVATE | |||
@@ -919,3 +925,70 @@ install(FILES | |||
${CMAKE_CURRENT_BINARY_DIR}/optimizer_priority.pbtxt OPTIONAL | |||
DESTINATION ${INSTALL_LIBRARY_DIR} | |||
) | |||
elseif (ENABLE_ACL) | |||
############ libge_compiler.so w/static protobuf ############ | |||
add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS}) | |||
target_compile_definitions(ge_compiler PRIVATE | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
REUSE_MEMORY=1 | |||
FMK_SUPPORT_DUMP | |||
FMK_HOST_INFER | |||
COMPILE_OMG_PACKAGE | |||
google=ascend_private | |||
) | |||
target_compile_options(ge_compiler PRIVATE | |||
-O2 | |||
) | |||
target_include_directories(ge_compiler PRIVATE | |||
${GE_CODE_DIR}/ge | |||
${GE_CODE_DIR}/ge/analyzer | |||
${GE_CODE_DIR}/inc | |||
${GE_CODE_DIR}/inc/external | |||
${GE_CODE_DIR}/inc/framework | |||
${GE_CODE_DIR}/inc/framework/common | |||
${METADEF_DIR} | |||
${METADEF_DIR}/inc | |||
${METADEF_DIR}/inc/external/graph | |||
${METADEF_DIR}/inc/external | |||
${METADEF_DIR}/inc/graph | |||
${CMAKE_BINARY_DIR} | |||
${CMAKE_BINARY_DIR}/proto/ge | |||
${ASCEND_DIR}/driver/include | |||
${ASCEND_DIR}/fwkacllib/include | |||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||
) | |||
target_link_libraries(ge_compiler | |||
$<BUILD_INTERFACE:intf_pub> | |||
ge_memory | |||
-Wl,--no-as-needed | |||
graph | |||
ge_common | |||
static_ascend_protobuf | |||
register | |||
c_sec | |||
error_manager | |||
slog | |||
mmpa | |||
runtime_compile | |||
resource | |||
-Wl,--as-needed | |||
json | |||
-lrt | |||
-ldl | |||
) | |||
############ install libge_compiler for MindSpore############ | |||
set(INSTALL_BASE_DIR "") | |||
set(INSTALL_LIBRARY_DIR lib) | |||
install(TARGETS ge_compiler OPTIONAL | |||
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} | |||
) | |||
endif() |
@@ -63,6 +63,7 @@ set(SRC_LIST | |||
"ge/tbe_plugin_manager.cc" | |||
) | |||
if (NOT ENABLE_D AND NOT ENABLE_ACL) | |||
############ libge_common.so ############ | |||
add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) | |||
target_compile_definitions(ge_common PRIVATE | |||
@@ -164,6 +165,57 @@ target_link_libraries(ge_common_static PRIVATE | |||
-ldl | |||
) | |||
else () | |||
############ libge_common.so w/static protobuf ############ | |||
add_library(ge_common SHARED ${SRC_LIST} ${PROTO_HDRS}) | |||
target_compile_definitions(ge_common PRIVATE | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
HOST_VISIBILITY | |||
FMK_SUPPORT_DUMP | |||
OS_CENTOS | |||
google=ascend_private | |||
) | |||
target_compile_options(ge_common PRIVATE | |||
-fvisibility=hidden | |||
-O2 | |||
-Werror | |||
) | |||
target_include_directories(ge_common PRIVATE | |||
${GE_CODE_DIR}/ge | |||
${GE_CODE_DIR}/ge/common | |||
${GE_CODE_DIR}/ge/common/op | |||
${GE_CODE_DIR}/inc/external | |||
${GE_CODE_DIR}/inc | |||
${GE_CODE_DIR}/inc/framework | |||
${METADEF_DIR}/inc | |||
${METADEF_DIR}/inc/external | |||
${METADEF_DIR}/inc/external/graph | |||
${METADEF_DIR}/inc/graph | |||
${CMAKE_BINARY_DIR} | |||
${CMAKE_BINARY_DIR}/proto/ge | |||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | |||
) | |||
target_link_libraries(ge_common PRIVATE | |||
$<BUILD_INTERFACE:intf_pub> | |||
ascend_protobuf_static | |||
-Wl,--no-as-needed | |||
graph | |||
register | |||
c_sec | |||
error_manager | |||
slog | |||
mmpa | |||
-Wl,--as-needed | |||
json | |||
-lrt | |||
-ldl | |||
) | |||
endif () | |||
############ install ############ | |||
set(INSTALL_BASE_DIR "") | |||
set(INSTALL_LIBRARY_DIR lib) | |||
@@ -27,14 +27,22 @@ target_compile_definitions(ge_runtime PRIVATE | |||
) | |||
target_include_directories(ge_runtime PRIVATE | |||
${TOP_DIR} | |||
${TOP_DIR}/inc | |||
${TOP_DIR}/inc/graph | |||
${TOP_DIR}/inc/external | |||
${TOP_DIR}/inc/framework | |||
${TOP_DIR}/inc/framework/common | |||
${TOP_DIR}/inc/framework/ge_runtime | |||
${TOP_DIR}/inc/cce | |||
${CMAKE_CURRENT_LIST_DIR} | |||
${GE_CODE_DIR} | |||
${GE_CODE_DIR}/ge | |||
${GE_CODE_DIR}/inc | |||
${GE_CODE_DIR}/inc/graph | |||
${GE_CODE_DIR}/inc/external | |||
${GE_CODE_DIR}/inc/framework | |||
${GE_CODE_DIR}/inc/framework/common | |||
${GE_CODE_DIR}/inc/framework/ge_runtime | |||
${GE_CODE_DIR}/inc/cce | |||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||
${METADEF_DIR} | |||
${METADEF_DIR}/inc | |||
${METADEF_DIR}/inc/external/graph | |||
${METADEF_DIR}/inc/external | |||
${METADEF_DIR}/inc/graph | |||
${CMAKE_BINARY_DIR} | |||
${CMAKE_BINARY_DIR}/proto/ge | |||
) | |||
@@ -45,6 +53,7 @@ target_link_libraries(ge_runtime PRIVATE | |||
slog | |||
runtime | |||
c_sec | |||
graph | |||
-Wl,--as-needed | |||
-lrt | |||
-ldl | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -27,8 +27,13 @@ class ModelContext { | |||
ModelContext(uint32_t device_id, uint64_t session_id, int32_t priority, rtModel_t rt_model_handle, | |||
rtStream_t rt_model_stream, const std::vector<rtStream_t> &stream_list, | |||
const std::vector<rtLabel_t> &label_list, const std::vector<rtEvent_t> &event_list) | |||
: device_id_(device_id), session_id_(session_id), priority_(priority), rt_model_handle_(rt_model_handle), | |||
rt_model_stream_(rt_model_stream), stream_list_(stream_list), label_list_(label_list), | |||
: device_id_(device_id), | |||
session_id_(session_id), | |||
priority_(priority), | |||
rt_model_handle_(rt_model_handle), | |||
rt_model_stream_(rt_model_stream), | |||
stream_list_(stream_list), | |||
label_list_(label_list), | |||
event_list_(event_list) {} | |||
~ModelContext() {} | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -24,6 +24,7 @@ | |||
namespace ge { | |||
namespace model_runner { | |||
using RuntimeModelPtr = std::shared_ptr<RuntimeModel>; | |||
using DavinciModelPtr = std::shared_ptr<DavinciModel>; | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -89,5 +89,6 @@ bool Output::SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &dat | |||
bool support_mem_share) { | |||
return true; | |||
} | |||
} // namespace model_runner | |||
} // namespace ge |
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -24,6 +24,7 @@ | |||
namespace ge { | |||
namespace model_runner { | |||
class Output { | |||
public: | |||
Output(const OpInfoPtr &op_info, const std::shared_ptr<DavinciModel> &model); | |||
@@ -32,8 +33,7 @@ class Output { | |||
bool CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); | |||
bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, | |||
bool support_mem_share); | |||
bool SetDataBuf(DataBuffer &data_buf, uint32_t data_begin, uint32_t &data_count, size_t i, bool support_mem_share); | |||
// Copy assignment operator and copy constructor are deleted | |||
Output &operator=(const Output &output) = delete; | |||
@@ -74,8 +74,8 @@ bool RuntimeModel::InitStream(std::shared_ptr<DavinciModel> &davinci_model) { | |||
for (uint32_t i = 0; i < davinci_model->GetStreamNum(); ++i) { | |||
rtStream_t stream = nullptr; | |||
uint32_t flag = (force_copy_streams.find(i) != force_copy_streams.end()) | |||
? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) | |||
: (RT_STREAM_PERSISTENT); | |||
? (RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY) | |||
: (RT_STREAM_PERSISTENT); | |||
rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->GetPriority(), flag); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
@@ -115,23 +115,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { | |||
return true; | |||
} | |||
bool RuntimeModel::InitLabel(uint32_t batch_num) { | |||
GELOGI("batch number:%u.", batch_num); | |||
for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { | |||
rtLabel_t rt_lLabel = nullptr; | |||
rtError_t rt_ret = rtLabelCreate(&rt_lLabel); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||
return false; | |||
bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||
GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||
label_list_.resize(davinci_model->GetBatchNum()); | |||
for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||
if (task_info == nullptr) { | |||
GELOGE(PARAM_INVALID, "task_info is null."); | |||
continue; | |||
} | |||
if (task_info->type() != TaskInfoType::LABEL_SET) { | |||
continue; | |||
} | |||
auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||
if (rt_lLabel == nullptr) { | |||
GELOGE(RT_FAILED, "rtLabel is nullptr!"); | |||
if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||
GELOGE(PARAM_INVALID, "Invalid stream id."); | |||
return false; | |||
} | |||
label_list_.emplace_back(rt_lLabel); | |||
rtLabel_t rt_label = nullptr; | |||
rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
label_list_[label_set_task_info->label_id()] = rt_label; | |||
} | |||
return true; | |||
} | |||
@@ -163,7 +174,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||
return false; | |||
} | |||
if (!InitLabel(davinci_model->GetBatchNum())) { | |||
if (!InitLabel(davinci_model)) { | |||
return false; | |||
} | |||
@@ -281,7 +292,6 @@ bool RuntimeModel::DistributeTask() { | |||
GELOGE(FAILED, "DistributeTask failed"); | |||
return false; | |||
} | |||
return true; | |||
} | |||
@@ -293,10 +303,14 @@ bool RuntimeModel::Run() { | |||
return false; | |||
} | |||
GELOGI("Run rtModelExecute success"); | |||
GELOGI("Run rtModelExecute success, ret = 0x%X", ret); | |||
ret = rtStreamSynchronize(rt_model_stream_); | |||
if (ret != RT_ERROR_NONE) { | |||
if (ret == RT_ERROR_END_OF_SEQUENCE) { | |||
GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret); | |||
return true; | |||
} | |||
GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); | |||
return false; | |||
} | |||
@@ -330,6 +344,9 @@ void RuntimeModel::RtStreamDestory() noexcept { | |||
void RuntimeModel::RtLabelDestory() noexcept { | |||
for (size_t i = 0; i < label_list_.size(); i++) { | |||
if (label_list_[i] == nullptr) { | |||
continue; | |||
} | |||
if (rtLabelDestroy(label_list_[i]) != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Destroy label failed! Index: %zu.", i); | |||
return; | |||
@@ -471,11 +488,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||
/// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | |||
/// and that of unknown shape is zero too. | |||
/// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | |||
int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); | |||
if (elem_num == 0 && constant->weight_tensors[0].size == 0) { | |||
elem_num = 1; | |||
} | |||
int64_t elem_num = | |||
(constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); | |||
if (constant->weight_data.size() < sizeof(uint64_t)) { | |||
GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | |||
return false; | |||
@@ -40,13 +40,11 @@ class RuntimeModel { | |||
const std::vector<uint32_t> &GetTaskIdList() const; | |||
const std::vector<uint32_t> &GetStreamIdList() const; | |||
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } | |||
const rtModel_t GetModelHandle() const { return rt_model_handle_; } | |||
rtModel_t GetModelHandle() const { return rt_model_handle_; } | |||
bool Run(); | |||
bool CopyInputData(const InputData &input_data); | |||
bool GetInputOutputDescInfo(bool zero_copy, | |||
std::vector<InputOutputDescInfo> *input_desc, | |||
std::vector<InputOutputDescInfo> *output_desc, | |||
std::vector<uint32_t> *input_format, | |||
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||
std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||
std::vector<uint32_t> *output_format); | |||
private: | |||
@@ -55,7 +53,7 @@ class RuntimeModel { | |||
bool LoadTask(); | |||
bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | |||
bool InitEvent(uint32_t event_num); | |||
bool InitLabel(uint32_t batch_num); | |||
bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model); | |||
bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
@@ -87,6 +85,7 @@ class RuntimeModel { | |||
std::vector<uint32_t> stream_id_list_{}; | |||
std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
@@ -26,6 +26,7 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai | |||
task_info_(task_info), | |||
stream_(nullptr), | |||
args_(nullptr), | |||
ext_info_(nullptr), | |||
input_output_addr_(nullptr) { | |||
if (task_info_ == nullptr) { | |||
GELOGW("task_info_ is null!"); | |||
@@ -41,7 +42,10 @@ AicpuTask::AicpuTask(const ModelContext &model_context, const std::shared_ptr<Ai | |||
} | |||
} | |||
AicpuTask::~AicpuTask() { ReleaseRtMem(&args_); } | |||
AicpuTask::~AicpuTask() { | |||
ReleaseRtMem(&args_); | |||
ReleaseRtMem(&ext_info_); | |||
} | |||
bool AicpuTask::Distribute() { | |||
GELOGI("InitAicpuTask start."); | |||
@@ -51,10 +55,37 @@ bool AicpuTask::Distribute() { | |||
auto io_addrs_num = static_cast<uint32_t>(io_addrs.size()); | |||
auto io_addrs_size = static_cast<uint32_t>(io_addrs_num * sizeof(void *)); | |||
constexpr uint32_t io_addr_offset = sizeof(aicpu::AicpuParamHead); | |||
uint32_t node_def_addr_offset = io_addr_offset + io_addrs_size; | |||
uint32_t args_size = | |||
sizeof(aicpu::AicpuParamHead) + io_addrs_size + static_cast<uint32_t>(task_info_->node_def().size()); | |||
aicpu::AicpuParamHead aicpu_param_head = {args_size, io_addrs_num}; | |||
uint32_t node_def_len_offset = io_addr_offset + io_addrs_size; | |||
uint32_t node_def_addr_offset = node_def_len_offset + sizeof(uint32_t); | |||
uint32_t args_size = sizeof(aicpu::AicpuParamHead) + io_addrs_size + | |||
static_cast<uint32_t>(task_info_->node_def().size()) + sizeof(uint32_t); | |||
aicpu::AicpuParamHead aicpu_param_head; | |||
aicpu_param_head.length = args_size; | |||
aicpu_param_head.ioAddrNum = io_addrs_num; | |||
auto ext_info = task_info_->ext_info(); | |||
uint32_t ext_size = ext_info.size(); | |||
if (ext_info.empty()) { | |||
aicpu_param_head.extInfoLength = 0; | |||
aicpu_param_head.extInfoAddr = 0; | |||
} else { | |||
rtError_t flag = rtMalloc(&ext_info_, ext_size, RT_MEMORY_HBM); | |||
if (flag != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X.", flag); | |||
return false; | |||
} | |||
flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size, | |||
RT_MEMCPY_HOST_TO_DEVICE); | |||
if (flag != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag); | |||
return false; | |||
} | |||
GELOGI("ext info size:", ext_size); | |||
aicpu_param_head.extInfoLength = ext_size; | |||
aicpu_param_head.extInfoAddr = reinterpret_cast<uintptr_t>(ext_info_); | |||
} | |||
// Malloc device memory for args | |||
rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); | |||
@@ -80,6 +111,17 @@ bool AicpuTask::Distribute() { | |||
return false; | |||
} | |||
} | |||
// Memcpy node def | |||
auto size = task_info_->node_def().size(); | |||
rt_ret = | |||
rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_len_offset), sizeof(uint32_t), | |||
reinterpret_cast<const void *>(&size), sizeof(uint32_t), RT_MEMCPY_HOST_TO_DEVICE); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X.", rt_ret); | |||
return false; | |||
} | |||
// Memcpy node def | |||
rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + node_def_addr_offset), | |||
task_info_->node_def().size(), reinterpret_cast<const void *>(task_info_->node_def().data()), | |||
@@ -41,6 +41,7 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> { | |||
std::shared_ptr<AicpuTaskInfo> task_info_; | |||
void *stream_; | |||
void *args_; | |||
void *ext_info_; | |||
void *input_output_addr_; | |||
}; | |||
} // namespace model_runner | |||
@@ -103,9 +103,9 @@ bool CceTask::Distribute() { | |||
// Modify flowtable addr in args | |||
auto args = const_cast<uint8_t *>(task_info_->args().data()); | |||
auto task_offset = reinterpret_cast<uint16_t *>(const_cast<uint8_t *>(task_info_->args_offset().data())); | |||
if (task_info_->args().size() < (task_offset[0] + sizeof(uint64_t))) { | |||
GELOGE(FAILED, | |||
"(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu", | |||
GELOGE(FAILED, "(context.args_offset().data()))[0]:%u + sizeof(uint64_t):%zu > kernelDef.args().size():%zu", | |||
static_cast<uint32_t>(task_offset[0]), sizeof(uint64_t), task_info_->args().size()); | |||
return false; | |||
} | |||
@@ -136,8 +136,7 @@ bool CceTask::Distribute() { | |||
return false; | |||
} | |||
rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), | |||
task_info_->sm_desc().data(), | |||
rt_ret = rtMemcpy(sm_desc_, task_info_->sm_desc().size(), task_info_->sm_desc().data(), | |||
task_info_->sm_desc().size(), RT_MEMCPY_HOST_TO_DEVICE); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
@@ -146,12 +145,8 @@ bool CceTask::Distribute() { | |||
} | |||
// Kernel launch | |||
rt_ret = rtKernelLaunch(stub_func_, | |||
task_info_->block_dim(), | |||
args_, | |||
task_info_->args_size(), | |||
static_cast<rtSmDesc_t *>(sm_desc_), | |||
stream_); | |||
rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, task_info_->args_size(), | |||
static_cast<rtSmDesc_t *>(sm_desc_), stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
@@ -33,7 +33,7 @@ class EventRecordTask : public TaskRepeater<EventRecordTaskInfo> { | |||
private: | |||
std::shared_ptr<EventRecordTaskInfo> task_info_; | |||
rtStream_t stream_; | |||
rtEvent_t event_; | |||
rtEvent_t event_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -33,7 +33,7 @@ class EventWaitTask : public TaskRepeater<EventWaitTaskInfo> { | |||
private: | |||
std::shared_ptr<EventWaitTaskInfo> task_info_; | |||
rtStream_t stream_; | |||
rtEvent_t event_; | |||
rtEvent_t event_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
@@ -115,7 +115,6 @@ bool HcclTask::Distribute() { | |||
rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
(void)rtStreamDestroy(stream); | |||
return false; | |||
} | |||
@@ -129,8 +128,6 @@ bool HcclTask::Distribute() { | |||
ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | |||
ge_task.stream = stream_; | |||
GETaskKernelHcclInfo kernel_hccl_info; | |||
ge_task.kernelHcclInfo.emplace_back(kernel_hccl_info); | |||
ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); | |||
ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); | |||
ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); | |||
@@ -0,0 +1,70 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_runtime/task/label_goto_task.h" | |||
#include "ge_runtime/task/task_factory.h" | |||
namespace ge { | |||
namespace model_runner { | |||
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||
task_info_(task_info), | |||
stream_(nullptr), | |||
label_(nullptr) { | |||
if (task_info_ == nullptr) { | |||
GELOGW("task_info_ is null!"); | |||
return; | |||
} | |||
auto stream_list = model_context.stream_list(); | |||
auto label_list = model_context.label_list(); | |||
uint32_t stream_id = task_info->stream_id(); | |||
uint32_t label_id = task_info->label_id(); | |||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
GELOGW("Stream/Label id invalid."); | |||
return; | |||
} | |||
stream_ = stream_list[stream_id]; | |||
label_ = label_list[label_id]; | |||
} | |||
LabelGotoTask::~LabelGotoTask() {} | |||
bool LabelGotoTask::Distribute() { | |||
GELOGI("LabelGotoTask Distribute start."); | |||
if (stream_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "stream is null!"); | |||
return false; | |||
} | |||
if (label_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "label is null!"); | |||
return false; | |||
} | |||
rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
GELOGI("DistributeTask end."); | |||
return true; | |||
} | |||
REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||
} // namespace model_runner | |||
} // namespace ge |
@@ -0,0 +1,41 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
#include <memory> | |||
#include "ge_runtime/task/task.h" | |||
namespace ge { | |||
namespace model_runner { | |||
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||
public: | |||
LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info); | |||
~LabelGotoTask() override; | |||
bool Distribute() override; | |||
private: | |||
std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||
void *stream_; | |||
void *label_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ |
@@ -0,0 +1,70 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_runtime/task/label_set_task.h" | |||
#include "ge_runtime/task/task_factory.h" | |||
namespace ge { | |||
namespace model_runner { | |||
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||
: TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||
task_info_(task_info), | |||
stream_(nullptr), | |||
label_(nullptr) { | |||
if (task_info_ == nullptr) { | |||
GELOGW("task_info_ is null!"); | |||
return; | |||
} | |||
auto stream_list = model_context.stream_list(); | |||
auto label_list = model_context.label_list(); | |||
uint32_t stream_id = task_info->stream_id(); | |||
uint32_t label_id = task_info->label_id(); | |||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
GELOGW("Stream/Label id invalid."); | |||
return; | |||
} | |||
stream_ = stream_list[stream_id]; | |||
label_ = label_list[label_id]; | |||
} | |||
LabelSetTask::~LabelSetTask() {} | |||
bool LabelSetTask::Distribute() { | |||
GELOGI("LabelSetTask Distribute start."); | |||
if (stream_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "stream is null!"); | |||
return false; | |||
} | |||
if (label_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "label is null!"); | |||
return false; | |||
} | |||
rtError_t rt_ret = rtLabelSet(label_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
GELOGI("DistributeTask end."); | |||
return true; | |||
} | |||
REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); | |||
} // namespace model_runner | |||
} // namespace ge |
@@ -0,0 +1,41 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
#include <memory> | |||
#include "ge_runtime/task/task.h" | |||
namespace ge { | |||
namespace model_runner { | |||
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> { | |||
public: | |||
LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info); | |||
~LabelSetTask() override; | |||
bool Distribute() override; | |||
private: | |||
std::shared_ptr<LabelSetTaskInfo> task_info_; | |||
void *stream_; | |||
void *label_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ |
@@ -0,0 +1,131 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_runtime/task/label_switch_task.h" | |||
#include "ge_runtime/task/task_factory.h" | |||
namespace ge { | |||
namespace model_runner { | |||
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||
: TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||
task_info_(task_info), | |||
stream_(nullptr), | |||
all_label_resource_(), | |||
label_info_(nullptr) { | |||
if (task_info_ == nullptr) { | |||
GELOGW("task_info_ is null!"); | |||
return; | |||
} | |||
all_label_resource_ = model_context.label_list(); | |||
auto stream_list = model_context.stream_list(); | |||
uint32_t stream_id = task_info->stream_id(); | |||
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
if (stream_id >= stream_list.size()) { | |||
GELOGW("Stream id invalid."); | |||
return; | |||
} | |||
stream_ = stream_list[stream_id]; | |||
} | |||
LabelSwitchTask::~LabelSwitchTask() { | |||
if (label_info_ != nullptr) { | |||
rtError_t rt_ret = rtFree(label_info_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||
} | |||
label_info_ = nullptr; | |||
} | |||
} | |||
bool LabelSwitchTask::Distribute() { | |||
GELOGI("LabelSwitchTask Distribute start."); | |||
if (!CheckParamValid()) { | |||
return false; | |||
} | |||
const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||
std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||
for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||
uint32_t label_index = label_index_list[i]; | |||
if (label_index >= all_label_resource_.size()) { | |||
GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||
all_label_resource_.size()); | |||
return false; | |||
} | |||
label_list[i] = all_label_resource_[label_index]; | |||
GELOGI("Case %zu: label id %zu.", i, label_index); | |||
} | |||
uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||
rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||
if (rt_ret != RT_ERROR_NONE) { | |||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
return false; | |||
} | |||
GELOGI("DistributeTask end."); | |||
return true; | |||
} | |||
bool LabelSwitchTask::CheckParamValid() { | |||
if (stream_ == nullptr) { | |||
GELOGE(PARAM_INVALID, "stream is null!"); | |||
return false; | |||
} | |||
if (task_info_->label_list().empty()) { | |||
GELOGE(PARAM_INVALID, "label_list is empty."); | |||
return false; | |||
} | |||
if (task_info_->label_size() != task_info_->label_list().size()) { | |||
GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), | |||
task_info_->label_size()); | |||
return false; | |||
} | |||
if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||
GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); | |||
return false; | |||
} | |||
if (label_info_ != nullptr) { | |||
GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
return false; | |||
} | |||
return true; | |||
} | |||
REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||
} // namespace model_runner | |||
} // namespace ge |
@@ -0,0 +1,44 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
#include <memory> | |||
#include "ge_runtime/task/task.h" | |||
namespace ge { | |||
namespace model_runner { | |||
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||
public: | |||
LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info); | |||
~LabelSwitchTask() override; | |||
bool Distribute() override; | |||
private: | |||
bool CheckParamValid(); | |||
std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||
void *stream_; | |||
std::vector<void *> all_label_resource_; | |||
void *label_info_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ |
@@ -37,6 +37,7 @@ class StreamSwitchTask : public TaskRepeater<StreamSwitchTaskInfo> { | |||
void *stream_; | |||
std::vector<rtStream_t> stream_list_; | |||
}; | |||
} // namespace model_runner | |||
} // namespace ge | |||
#endif // GE_GE_RUNTIME_TASK_STREAM_SWITCH_TASK_H_ |
@@ -42,7 +42,7 @@ class Task { | |||
template <class T> | |||
class TaskRepeater : public Task { | |||
static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!"); /*lint !e30*/ | |||
static_assert(std::is_base_of<TaskInfo, T>(), "Wrong TaskInfo Type!"); | |||
public: | |||
TaskRepeater(const ModelContext &model_context, std::shared_ptr<T> task_info) {} | |||
@@ -81,6 +81,7 @@ class TaskFactory { | |||
std::shared_ptr<task_info_clazz> concrete_task_info = std::static_pointer_cast<task_info_clazz>(task_info); \ | |||
return std::make_shared<task_clazz>(model_context, concrete_task_info); \ | |||
}); | |||
} // namespace model_runner | |||
} // namespace ge | |||
#endif // GE_GE_RUNTIME_TASK_TASK_FACTORY_H_ |
@@ -27,10 +27,10 @@ namespace ge { | |||
namespace model_runner { | |||
class DavinciModel { | |||
public: | |||
DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, /*lint !e151*/ | |||
DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, | |||
const std::vector<std::shared_ptr<OpInfo>> &data_info_list, | |||
const std::vector<std::shared_ptr<OpInfo>> &output_info_list, /*lint !e151*/ | |||
const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, /*lint !e1049*/ | |||
const std::vector<std::shared_ptr<OpInfo>> &output_info_list, | |||
const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, | |||
const std::vector<model_runner::OpInfoPtr> &variable_info_list, | |||
const std::vector<uint32_t> &wait_active_stream_list, | |||
const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, | |||
@@ -68,12 +68,12 @@ class DavinciModel { | |||
uint32_t GetBatchNum() const { return batch_num_; } | |||
uint32_t GetEventNum() const { return event_num_; } | |||
const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } /*lint !e1413*/ | |||
const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } /*lint !e1413*/ | |||
const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } | |||
const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } | |||
int32_t GetPriority() const { return priority_; } | |||
const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } /*lint !e151*/ | |||
const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } | |||
const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; } | |||
const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; } | |||
const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return output_info_list_; } | |||
@@ -81,7 +81,7 @@ class DavinciModel { | |||
private: | |||
std::vector<std::shared_ptr<TaskInfo>> task_info_list_; | |||
std::vector<std::shared_ptr<OpInfo>> data_info_list_; /*lint !e151*/ | |||
std::vector<std::shared_ptr<OpInfo>> data_info_list_; | |||
std::vector<std::shared_ptr<OpInfo>> output_info_list_; | |||
std::vector<std::shared_ptr<OpInfo>> constant_info_list_; | |||
std::vector<model_runner::OpInfoPtr> variable_info_list_; | |||
@@ -52,11 +52,8 @@ class ModelRunner { | |||
bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | |||
bool GetInputOutputDescInfo(uint32_t model_id, | |||
bool zero_copy, | |||
std::vector<InputOutputDescInfo> *input_desc, | |||
std::vector<InputOutputDescInfo> *output_desc, | |||
std::vector<uint32_t> *input_format, | |||
bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||
std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||
std::vector<uint32_t> *output_format); | |||
private: | |||
@@ -161,12 +161,13 @@ class TbeTaskInfo : public TaskInfo { | |||
class AicpuTaskInfo : public TaskInfo { | |||
public: | |||
AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | |||
const std::string &node_def, const std::vector<void *> &input_data_addrs, | |||
const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs, | |||
const std::vector<void *> &output_data_addrs, bool dump_flag) | |||
: TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | |||
so_name_(so_name), | |||
kernel_name_(kernel_name), | |||
node_def_(node_def), | |||
ext_info_(ext_info), | |||
input_data_addrs_(input_data_addrs), | |||
output_data_addrs_(output_data_addrs) {} | |||
~AicpuTaskInfo() override {} | |||
@@ -176,11 +177,13 @@ class AicpuTaskInfo : public TaskInfo { | |||
const std::string &node_def() const { return node_def_; } | |||
const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | |||
const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | |||
const std::string &ext_info() const { return ext_info_; } | |||
private: | |||
std::string so_name_; | |||
std::string kernel_name_; | |||
std::string node_def_; | |||
std::string ext_info_; | |||
std::vector<void *> input_data_addrs_; | |||
std::vector<void *> output_data_addrs_; | |||
}; | |||
@@ -293,19 +296,19 @@ class HcclTaskInfo : public TaskInfo { | |||
hcom_distribute_task_(hcom_distribute_task) {} | |||
~HcclTaskInfo() override {} | |||
const std::string &hccl_type() const { return hccl_type_; } /*lint !e1413*/ | |||
const std::string &hccl_type() const { return hccl_type_; } | |||
void *input_data_addr() const { return input_data_addr_; } | |||
void *output_data_addr() const { return output_data_addr_; } | |||
void *workspace_addr() const { return workspace_addr_; } | |||
int64_t workspace_size() const { return workspace_size_; } | |||
int64_t hccl_stream_num() const { return hccl_stream_num_; } | |||
const std::vector<uint8_t> &private_def() const { return private_def_; } /*lint !e1413*/ | |||
const std::vector<uint8_t> &private_def() const { return private_def_; } | |||
void *ops_kernel_store() const { return ops_kernel_store_; } | |||
int32_t count() const { return count_; } | |||
int64_t root_id() const { return root_id_; } | |||
int64_t op_type() const { return op_type_; } | |||
int64_t data_type() const { return data_type_; } | |||
const std::string group() const { return group_; } | |||
const std::string &group() const { return group_; } | |||
std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | |||
std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | |||
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | |||