@@ -88,11 +88,9 @@ else () | |||
find_module(hccl libhccl.so ${GE_LIB_PATH}) | |||
find_module(adump_server libadump_server.a ${GE_LIB_PATH}) | |||
find_module(runtime libruntime.so ${GE_LIB_PATH}) | |||
find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH}) | |||
find_module(resource libresource.so ${GE_LIB_PATH}) | |||
find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH}) | |||
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH}) | |||
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | |||
else() | |||
find_module(slog libalog.so ${ASCEND_ATC_DIR}) | |||
find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR}) | |||
@@ -108,7 +106,6 @@ else () | |||
elseif(PLATFORM STREQUAL "inference") | |||
find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR}) | |||
find_module(runtime libruntime.so ${ASCEND_ACL_DIR}) | |||
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) | |||
find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) | |||
if(PRODUCT STREQUAL "flr3") | |||
elseif(PRODUCT STREQUAL "flr1") | |||
@@ -120,10 +117,9 @@ else () | |||
endif() | |||
elseif(PLATFORM STREQUAL "all") | |||
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | |||
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | |||
find_module(runtime libruntime.so ${ASCEND_ATC_DIR}) | |||
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR}) | |||
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}) | |||
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR}) | |||
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_ATC_DIR}/stub) | |||
find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR}) | |||
else() | |||
message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") | |||
@@ -10,12 +10,17 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||
endif() | |||
if (ENABLE_GITEE) | |||
set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") | |||
set(MD5 "") | |||
if (GE_PB_PKG) | |||
set(REQ_URL "${GE_PB_PKG}/libs/gflags/v2.2.2.tar.gz") | |||
set(MD5 "1a865b93bacfa963201af3f75b7bd64c") | |||
else() | |||
set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") | |||
set(MD5 "") | |||
if (ENABLE_GITEE) | |||
set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") | |||
set(MD5 "") | |||
else() | |||
set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") | |||
set(MD5 "1a865b93bacfa963201af3f75b7bd64c") | |||
endif () | |||
endif () | |||
set (gflags_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=ascend_private") | |||
@@ -112,6 +112,8 @@ set(EXECUTOR_SRC_LIST | |||
"common/dump/dump_op.cc" | |||
"common/dump/exception_dumper.cc" | |||
"common/dump/opdebug_register.cc" | |||
"common/ge/op_tiling_manager.cc" | |||
"common/ge/plugin_manager.cc" | |||
"common/profiling/ge_profiling.cc" | |||
"common/profiling/profiling_manager.cc" | |||
"executor/ge_executor.cc" | |||
@@ -259,6 +261,8 @@ set(EXECUTOR_SRC_LIST | |||
set(COMPILER_SRC_LIST | |||
"analyzer/analyzer.cc" | |||
"common/dump/dump_op.cc" | |||
"common/ge/op_tiling_manager.cc" | |||
"common/ge/plugin_manager.cc" | |||
"common/helper/model_cache_helper.cc" | |||
"common/profiling/profiling_manager.cc" | |||
"engine_manager/dnnengine_manager.cc" | |||
@@ -619,7 +623,6 @@ target_compile_definitions(ge_compiler PRIVATE | |||
REUSE_MEMORY=1 | |||
FMK_SUPPORT_DUMP | |||
FMK_HOST_INFER | |||
COMPILE_OMG_PACKAGE | |||
google=ascend_private | |||
FUNC_VISIBILITY | |||
$<$<STREQUAL:${ENABLE_OPEN_SRC},True>:ONLY_COMPILE_OPEN_SRC> | |||
@@ -681,8 +684,7 @@ target_link_libraries(ge_compiler PRIVATE | |||
c_sec | |||
error_manager | |||
slog | |||
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime>> | |||
$<$<BOOL:${ENABLE_OPEN_SRC}>:$<BUILD_INTERFACE:runtime_compile>> | |||
runtime | |||
opt_feature | |||
-Wl,--as-needed | |||
json | |||
@@ -350,7 +350,7 @@ Status FftsPlusTaskInfo::InitAicAivCtx(const domi::FftsPlusAicAivCtxDef &ctx_def | |||
i_cache_prefetch_cnt_2)); | |||
ctx->tailTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_task_start_pc) & 0XFFFFFFFF); | |||
ctx->tailTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_task_start_pc) >> 32) & 0X0000FFFF); | |||
uint32_t i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); | |||
uint32_t i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); | |||
ctx->icachePrefetchCnt = static_cast<uint16_t>(i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 | |||
if (ctx_def.src_slot_size() != kSrcSlotNum) { | |||
@@ -526,8 +526,7 @@ Status FftsPlusTaskInfo::InitMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &c | |||
ctx->tailAicTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_aic_task_start_pc) & 0XFFFFFFFF); | |||
ctx->tailAicTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_aic_task_start_pc) >> 32) & | |||
0X0000FFFF); | |||
uint32_t aic_i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); | |||
// TODO | |||
uint32_t aic_i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2); | |||
ctx->icachePrefetchCnt = static_cast<uint16_t>(aic_i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 | |||
uint32_t i_cache_prefetch_cnt_3; | |||
@@ -545,9 +544,10 @@ Status FftsPlusTaskInfo::InitMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &c | |||
ctx->tailAivTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_aiv_task_start_pc) & 0XFFFFFFFF); | |||
ctx->tailAivTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_aiv_task_start_pc) >> 32) & | |||
0X0000FFFF); | |||
uint32_t aiv_i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_3, i_cache_prefetch_cnt_4); | |||
uint32_t aiv_i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_3, i_cache_prefetch_cnt_4); | |||
// TODO | |||
ctx->icachePrefetchCnt = static_cast<uint16_t>(aiv_i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111 | |||
ctx->icachePrefetchCnt = static_cast<uint16_t>( | |||
std::min(aic_i_cache_prefetch_cnt, aiv_i_cache_prefetch_cnt) & 0X0000001F); // 5 bits, 0001,1111 | |||
if (ctx_def.src_slot_size() != kSrcSlotNum) { | |||
REPORT_INNER_ERROR("E19999", "Size of src_slot in FftsPlusMixAicAivCtxDef should be %d, but %d exactly", | |||
@@ -913,11 +913,11 @@ void FftsPlusTaskInfo::SetAdditionalDatatoCtx(const domi::FftsPlusTaskDef &task_ | |||
Status FftsPlusTaskInfo::UpdateMixAicAivCtxParam(const domi::FftsPlusMixAicAivCtxDef &ctx_def, size_t ctx_idx) { | |||
if (ctx_additional_data_.count(ctx_idx) == 0) { | |||
GELOGD("ctx idx:%d not in ctx additional data"); | |||
GELOGD("ctx idx:%zu not in ctx additional data"); | |||
return SUCCESS; | |||
} | |||
if (ctx_additional_data_[ctx_idx].count(kModeInArgsFirstField) == 0) { | |||
GELOGD("ctx idx:%d need not to save mode in args first field"); | |||
GELOGD("ctx idx:%zu need not to save mode in args first field"); | |||
return SUCCESS; | |||
} | |||
if (rtApp_addr_ == 0) { | |||
@@ -20,6 +20,7 @@ | |||
#include "graph/manager/graph_mem_manager.h" | |||
#include "graph/manager/trans_var_data_utils.h" | |||
#include "graph/utils/type_utils.h" | |||
#include "graph/ge_context.h" | |||
using std::map; | |||
using std::string; | |||
@@ -767,25 +768,52 @@ Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &grap | |||
return var_resource_->GetChangedGraphId(var_name, graph_id); | |||
} | |||
// Queries the total HBM size reported by the runtime for the device selected in the GE context.
//
// The device is bound with rtSetDevice for the duration of the query and is released again with
// rtDeviceReset on every exit path that follows a successful bind.
//
// total_mem_size [out]: receives the "total" value written by rtMemGetInfoEx.
// Returns SUCCESS when all runtime calls succeed, RT_FAILED otherwise.
Status VarManager::GetTotalMemorySize(size_t &total_mem_size) {
  // Read the device id once so that set, reset and every log line refer to the same device.
  const auto device_id = GetContext().DeviceId();

  rtError_t rt_ret = rtSetDevice(device_id);
  if (rt_ret != RT_ERROR_NONE) {
    REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret);
    GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret);
    return RT_FAILED;
  }

  // The free size is required by the runtime signature but is not used by the caller.
  size_t free_mem = 0;
  rt_ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free_mem, &total_mem_size);
  if (rt_ret != RT_ERROR_NONE) {
    REPORT_CALL_ERROR("E19999", "Call rtMemGetInfoEx failed, ret:0x%X", rt_ret);
    GELOGE(RT_FAILED, "[Call][RtMemGetInfoEx] failed, ret:0x%X", rt_ret);
    // Best effort: release the device bound above so that a failed query does not leave it bound.
    const rtError_t reset_ret = rtDeviceReset(device_id);
    if (reset_ret != RT_ERROR_NONE) {
      GELOGW("[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", device_id, reset_ret);
    }
    return RT_FAILED;
  }

  rt_ret = rtDeviceReset(device_id);
  if (rt_ret != RT_ERROR_NONE) {
    REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", device_id, rt_ret);
    GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", device_id, rt_ret);
    return RT_FAILED;
  }
  return SUCCESS;
}
Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | |||
auto it = options.find(GRAPH_MEMORY_MAX_SIZE); | |||
if (it == options.end()) { | |||
graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize; | |||
} else { | |||
string graph_memory_manager_malloc_max_size = it->second; | |||
size_t total_mem_size = 0; | |||
GE_CHK_STATUS_RET_NOLOG(VarManager::GetTotalMemorySize(total_mem_size)); | |||
GEEVENT("Total memory size is %zu", total_mem_size); | |||
graph_mem_max_size_ = floor(total_mem_size * kGraphMemoryManagerMallocRatio); | |||
var_mem_max_size_ = floor(total_mem_size * kVarMemoryManagerMallocRatio); | |||
auto it1 = options.find(GRAPH_MEMORY_MAX_SIZE); | |||
if (it1 != options.end()) { | |||
string graph_memory_manager_malloc_max_size = it1->second; | |||
ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_); | |||
if (ret != SUCCESS) { | |||
GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); | |||
return ge::GE_GRAPH_OPTIONS_INVALID; | |||
} | |||
GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); | |||
} | |||
it = options.find(VARIABLE_MEMORY_MAX_SIZE); | |||
if (it == options.end()) { | |||
var_mem_max_size_ = kMemoryVarManagerMallocSize; | |||
} else { | |||
string memory_var_manager_malloc_size = it->second; | |||
auto it2 = options.find(VARIABLE_MEMORY_MAX_SIZE); | |||
if (it2 != options.end()) { | |||
string memory_var_manager_malloc_size = it2->second; | |||
ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_); | |||
if (ret != SUCCESS) { | |||
GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); | |||
@@ -793,6 +821,8 @@ Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | |||
} | |||
} | |||
GEEVENT("The graph_mem_max_size is %zu and the var_mem_max_size is %zu", graph_mem_max_size_, var_mem_max_size_); | |||
var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer; | |||
if (var_mem_logic_base_ > kMaxMemorySize) { | |||
REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid", | |||
@@ -43,6 +43,8 @@ const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL; | |||
const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY"; | |||
const uint64_t kSessionMemAlignSize = 512; | |||
const size_t kSessionMemAlignUnit = 2; | |||
const double kGraphMemoryManagerMallocRatio = 26.0 / 32.0; | |||
const double kVarMemoryManagerMallocRatio = 5.0 / 32.0; | |||
enum MemStatus { | |||
NORMAL = 0, | |||
@@ -301,6 +303,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||
mutable std::recursive_mutex mutex_; | |||
Status ParseMemoryMallocSize(std::string &memory_size, size_t &my_size); | |||
Status GetTotalMemorySize(size_t &total_mem_size); | |||
}; | |||
class VarManagerPool { | |||
@@ -60,7 +60,6 @@ const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | |||
const char *const kForceInfershape = "_force_infershape_when_running"; | |||
const std::set<std::string> kExecutionDependentTypes{ IF, STATELESSIF, CASE, STREAMSWITCH }; | |||
const std::set<std::string> kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP }; | |||
const std::set<std::string> kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||
Status SetOutputNameAttr(ComputeGraph &graph) { | |||
@@ -519,170 +518,6 @@ Status HybridModelBuilder::UpdateAnchorStatus(const NodePtr &node) { | |||
return SUCCESS; | |||
} | |||
// Removes the data edge that runs from out_data_anchor to in_data_anchor.
//
// The unlink itself is out_data_anchor->Unlink(in_data_anchor), wrapped in GE_CHK_GRAPH_STATUS_RET;
// presumably that macro logs the given message and returns early on a non-success graph status
// (TODO confirm against the macro definition). On the normal path a debug line naming both owner
// nodes and anchor indices is emitted and SUCCESS is returned.
//
// NOTE(review): neither anchor nor either owner node is null-checked here; callers are expected to
// have validated them before calling.
Status HybridModelBuilder::DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, | |||
const InDataAnchorPtr &in_data_anchor) { | |||
GE_CHK_GRAPH_STATUS_RET(out_data_anchor->Unlink(in_data_anchor), | |||
"[Invoke][Unlink] failed to unlink %s:%d from %s:%d", | |||
out_data_anchor->GetOwnerNode()->GetName().c_str(), out_data_anchor->GetIdx(), | |||
in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetIdx()); | |||
GELOGD("Succeeded in unlinking %s:%d from %s:%d", | |||
out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
out_data_anchor->GetIdx(), | |||
in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
in_data_anchor->GetIdx()); | |||
return SUCCESS; | |||
} | |||
// Creates a data edge from out_data_anchor to in_data_anchor.
//
// The link itself is out_data_anchor->LinkTo(in_data_anchor), wrapped in GE_CHK_GRAPH_STATUS_RET;
// presumably that macro logs the given message and returns early on a non-success graph status
// (TODO confirm against the macro definition). On the normal path a debug line naming both owner
// nodes and anchor indices is emitted and SUCCESS is returned.
//
// NOTE(review): neither anchor nor either owner node is null-checked here; callers are expected to
// have validated them before calling.
Status HybridModelBuilder::DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor) { | |||
GE_CHK_GRAPH_STATUS_RET(out_data_anchor->LinkTo(in_data_anchor), "[Invoke][LinkTo]Failed to link %s:%d to %s:%d", | |||
out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
out_data_anchor->GetIdx(), | |||
in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
in_data_anchor->GetIdx()); | |||
GELOGD("Succeeded in linking %s:%d to %s:%d", | |||
out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
out_data_anchor->GetIdx(), | |||
in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
in_data_anchor->GetIdx()); | |||
return SUCCESS; | |||
} | |||
// Rewires the inputs of a subgraph so that the consumers of its Data nodes are fed directly by the
// producers connected to the subgraph's parent (wrapping) node.
//
// Steps:
//   1. For every Data node, read ATTR_NAME_PARENT_NODE_INDEX, find the matching input anchor on
//      the parent node, detach it, and re-link each consumer of the Data node to the producer that
//      used to feed the parent node.
//   2. Collect "root" nodes: non-Data nodes without data inputs, and consumers whose data inputs
//      are all DATA nodes.
//   3. Copy the parent node's incoming control edges onto those root nodes (except for root node
//      types listed in kMergeInputSkipTypes), then drop the parent node's incoming control edges.
//
// Returns SUCCESS on success; FAILED when a Data node lacks the parent-index attribute; otherwise
// whatever the null-check / status macros return.
Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) {
  const auto &wrapped_node = graph.GetParentNode();
  // A graph without a parent node has nothing to merge into; fail instead of dereferencing null.
  GE_CHECK_NOTNULL(wrapped_node);

  std::set<NodePtr> root_nodes;
  for (const auto &node : graph.GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    if (node->GetType() != DATA_TYPE) {
      if (node->GetInDataNodes().empty()) {
        root_nodes.emplace(node);
      }
      continue;
    }

    auto data_op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(data_op_desc);

    uint32_t parent_index = 0;
    if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
      GELOGE(FAILED, "[Invoke][GetInt] failed, node:[%s] attr:[%s]",
             data_op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
      REPORT_CALL_ERROR("E19999", "GetInt failed, node:[%s] attr:[%s]",
                        data_op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
      return FAILED;
    }

    auto wrapped_node_in_anchor = wrapped_node->GetInDataAnchor(parent_index);
    GE_CHECK_NOTNULL(wrapped_node_in_anchor);
    auto src_out_anchor = wrapped_node_in_anchor->GetPeerOutAnchor();
    if (src_out_anchor == nullptr || src_out_anchor->GetOwnerNode() == nullptr) {
      // This parent input has no producer, so there is nothing to splice in for this Data node.
      continue;
    }
    wrapped_node_in_anchor->UnlinkAll();

    // link src to outputs of DataNode
    for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
      GE_CHECK_NOTNULL(out_data_anchor);
      for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
        auto dst_node = peer_in_data_anchor->GetOwnerNode();
        GE_CHECK_NOTNULL(dst_node);
        // Evaluated before the unlink below, while dst_node still sees the Data node as an input.
        const auto in_nodes = dst_node->GetInDataNodes();
        if (std::all_of(in_nodes.begin(), in_nodes.end(), [](const NodePtr &n) { return n->GetType() == DATA; })) {
          root_nodes.emplace(dst_node);
        }
        GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor));
        GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor));
      }
    }
  }

  // transfer in control edges to all root nodes
  for (auto &root_node : root_nodes) {
    // The skip-type test depends only on the root node, so decide once per root node.
    if (kMergeInputSkipTypes.count(root_node->GetType()) != 0) {
      continue;
    }
    auto in_nodes = root_node->GetInAllNodes();
    std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end());
    for (auto &in_control_node : wrapped_node->GetInControlNodes()) {
      if (in_node_set.count(in_control_node) != 0) {
        // Already an input of the root node; no extra control edge is needed.
        continue;
      }
      GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str());
      GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor());
      (void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor());
    }
  }

  GE_CHECK_NOTNULL(wrapped_node->GetInControlAnchor());
  wrapped_node->GetInControlAnchor()->UnlinkAll();
  return SUCCESS;
}
// Rewires the outputs of a subgraph so that the consumers of its parent node are fed directly by
// the producers connected to the subgraph's NetOutput node.
//
// Steps:
//   1. If the graph has no NETOUTPUT node, return SUCCESS without changing anything.
//   2. Drop the NetOutput node's incoming control edges and the parent node's outgoing ones.
//   3. For every NetOutput input, detach its producer and, using ATTR_NAME_PARENT_NODE_INDEX on
//      the input desc, re-link every consumer of the matching parent output to that producer.
//      Inputs without the attribute are only detached (a warning is logged).
//   4. Add a control edge from every former NetOutput input node to every former parent output
//      node that it is not already connected to.
//
// Returns SUCCESS on success; INTERNAL_ERROR when an input desc is missing; otherwise whatever the
// null-check / status macros return.
Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) {
  const auto &parent_node = graph.GetParentNode();
  const NodePtr &net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT);
  if (net_output_node == nullptr) {
    GELOGD("Graph has no netoutput no need to merge");
    return SUCCESS;
  }
  // Checked only after the early return above, so a graph without NetOutput still succeeds.
  GE_CHECK_NOTNULL(parent_node);
  const auto &net_output_desc = net_output_node->GetOpDesc();
  GE_CHECK_NOTNULL(net_output_desc);

  // Snapshot both neighbourhoods before any edge is removed; they drive the control-edge step.
  auto all_in_nodes = net_output_node->GetInAllNodes();
  auto all_out_nodes = parent_node->GetOutAllNodes();
  GE_CHECK_NOTNULL(net_output_node->GetInControlAnchor());
  GE_CHECK_NOTNULL(parent_node->GetOutControlAnchor());
  net_output_node->GetInControlAnchor()->UnlinkAll();
  parent_node->GetOutControlAnchor()->UnlinkAll();

  for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) {
    GE_CHECK_NOTNULL(in_data_anchor);
    auto src_out_anchor = in_data_anchor->GetPeerOutAnchor();
    GE_CHECK_NOTNULL(src_out_anchor);
    GE_CHECK_NOTNULL(src_out_anchor->GetOwnerNode());
    GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(src_out_anchor, in_data_anchor));

    auto index = in_data_anchor->GetIdx();
    auto input_desc = net_output_desc->MutableInputDesc(index);
    if (input_desc == nullptr) {
      GELOGE(INTERNAL_ERROR, "[Invoke][MutableInputDesc][%s] Failed to get input desc[%d]",
             net_output_desc->GetName().c_str(), index);
      REPORT_CALL_ERROR("E19999", "[%s] Failed to get input desc[%d].", net_output_desc->GetName().c_str(), index);
      return INTERNAL_ERROR;
    }

    uint32_t parent_index = 0;
    if (!AttrUtils::GetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
      GELOGW("SubGraph: %s NetOutput input tensor %d, attr %s not found.",
             graph.GetName().c_str(), index, ATTR_NAME_PARENT_NODE_INDEX.c_str());
      continue;
    }

    const OutDataAnchorPtr &parent_out_anchor = parent_node->GetOutDataAnchor(parent_index);
    GE_CHECK_NOTNULL(parent_out_anchor);
    for (InDataAnchorPtr &dst_in_anchor : parent_out_anchor->GetPeerInDataAnchors()) {
      if (dst_in_anchor == nullptr) {
        continue;
      }
      GE_CHECK_NOTNULL(dst_in_anchor->GetOwnerNode());
      GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(parent_out_anchor, dst_in_anchor));
      GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, dst_in_anchor));
    }
  }

  // transfer out control edges
  std::set<NodePtr> in_node_set(all_in_nodes.begin(), all_in_nodes.end());
  std::set<NodePtr> out_node_set(all_out_nodes.begin(), all_out_nodes.end());
  for (auto &src_node : in_node_set) {
    GELOGD("[%s] process in node.", src_node->GetName().c_str());
    auto out_nodes = src_node->GetOutAllNodes();
    std::set<NodePtr> node_set(out_nodes.begin(), out_nodes.end());
    for (auto &dst_node : out_node_set) {
      if (node_set.count(dst_node) != 0) {
        // src_node already reaches dst_node; no control edge is needed.
        continue;
      }
      GE_CHECK_NOTNULL(src_node->GetOutControlAnchor());
      // The link result is deliberately ignored, as in the original best-effort behaviour.
      (void) src_node->GetOutControlAnchor()->LinkTo(dst_node->GetInControlAnchor());
      GELOGD("[%s] Restore control edge to [%s]", src_node->GetName().c_str(), dst_node->GetName().c_str());
    }
  }

  return SUCCESS;
}
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { | |||
merged_graph = MakeShared<ComputeGraph>("MergedGraph"); | |||
merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag()); | |||
@@ -716,9 +551,21 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeG | |||
} | |||
} | |||
} | |||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), | |||
const auto &filter = [](const ComputeGraphPtr &graph) { | |||
const auto &parent_node = graph->GetParentNode(); | |||
if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { | |||
return false; | |||
} | |||
if ((parent_node->GetType() != PARTITIONEDCALL) || | |||
(parent_node->GetOpDesc()->GetSubgraphInstanceNames().size() != 1)) { | |||
return false; | |||
} | |||
return graph->GetGraphUnknownFlag(); | |||
}; | |||
GE_CHK_GRAPH_STATUS_RET(GraphUtils::UnfoldSubgraph(subgraph, filter), | |||
"[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph.", | |||
subgraph->GetName().c_str()); | |||
subgraph->GetName().c_str()) | |||
} | |||
// invoke before adding subgraphs. in case modify node id in known-shaped subgraphs. | |||
@@ -744,56 +591,6 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeG | |||
return SUCCESS; | |||
} | |||
// Inlines sub_graph into parent_graph and removes it from root_graph.
//
// The subgraph's Data / NetOutput boundaries are first merged away (MergeInputNodes,
// MergeNetOutputNode). Every remaining node is then moved into parent_graph, except that a
// PARTITIONEDCALL node whose subgraph (index kSubgraphIndex) has the unknown flag set is unfolded
// recursively instead of being moved. Subgraphs of moved nodes get parent_graph as their parent.
//
// root_graph:   graph that owns the subgraph registration; sub_graph is removed from it by name.
// parent_graph: graph that receives the nodes.
// sub_graph:    subgraph to dissolve.
// Returns SUCCESS on success, otherwise whatever the null-check / status macros return.
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph,
                                          ComputeGraphPtr &parent_graph,
                                          ComputeGraph &sub_graph) {
  // Validate both destination graphs up front, before any edge of the subgraph is rewired.
  GE_CHECK_NOTNULL(root_graph);
  GE_CHECK_NOTNULL(parent_graph);
  auto parent_node = sub_graph.GetParentNode();
  GE_CHECK_NOTNULL(parent_node);

  GE_CHK_STATUS_RET(MergeInputNodes(sub_graph),
                    "[Invoke][MergeInputNodes][%s] Failed to merge data nodes for subgraph",
                    sub_graph.GetName().c_str());
  GE_CHK_STATUS_RET(MergeNetOutputNode(sub_graph),
                    "[Invoke][MergeNetOutputNode][%s] Failed to merge net output nodes for subgraph",
                    sub_graph.GetName().c_str());
  GELOGD("[%s] Done merging subgraph inputs and outputs successfully", sub_graph.GetName().c_str());

  for (auto &sub_node : sub_graph.GetDirectNode()) {
    GE_CHECK_NOTNULL(sub_node);
    auto sub_op_type = sub_node->GetType();
    if (sub_op_type == DATA_TYPE || sub_op_type == NETOUTPUT) {
      // Boundary nodes were merged away above and are not carried over.
      continue;
    }
    if (sub_op_type == PARTITIONEDCALL) {
      auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, kSubgraphIndex);
      GE_CHECK_NOTNULL(sub_sub_graph);
      if (sub_sub_graph->GetGraphUnknownFlag()) {
        GE_CHK_STATUS_RET(UnfoldSubgraph(root_graph, parent_graph, *sub_sub_graph),
                          "[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph",
                          sub_sub_graph->GetName().c_str());
        continue;
      }
    }

    const auto sub_op_desc = sub_node->GetOpDesc();
    GE_CHECK_NOTNULL(sub_op_desc);
    // Re-parent every subgraph of the moved node; the loop is a no-op when there are none.
    const size_t subgraph_num = sub_op_desc->GetSubgraphInstanceNames().size();
    for (size_t i = 0; i < subgraph_num; ++i) {
      auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i);
      GE_CHECK_NOTNULL(sub_sub_graph);
      sub_sub_graph->SetParentGraph(parent_graph);
    }

    parent_graph->AddNode(sub_node);
    GELOGD("[%s::%s] added to parent graph: [%s].",
           sub_graph.GetName().c_str(),
           sub_node->GetName().c_str(),
           parent_graph->GetName().c_str());
    sub_node->SetOwnerComputeGraph(parent_graph);
  }

  GELOGD("[%s] Done merging subgraph. remove it from root graph", sub_graph.GetName().c_str());
  root_graph->RemoveSubgraph(sub_graph.GetName());
  return SUCCESS;
}
Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, | |||
const NodeItem &node_item, | |||
bool is_root_graph) { | |||
@@ -39,16 +39,11 @@ class HybridModelBuilder { | |||
private: | |||
static Status UpdateAnchorStatus(const NodePtr &node); | |||
static Status DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, const InDataAnchorPtr &in_data_anchor); | |||
static Status DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor); | |||
static NodePtr GetPeerNode(const InDataAnchorPtr &in_data_anchor); | |||
static Status GetParentNodeOutputIndex(const OpDesc &op_desc, int index, uint32_t &out_index); | |||
static Status GetPeerNodeAcrossSubGraphs(const NodePtr &data_node, NodePtr &peer_node, int &peer_out_index); | |||
static Status HandleDtString(const GeTensor &tensor, void *var_addr); | |||
static Status MergeInputNodes(ComputeGraph &compute_graph); | |||
static Status MergeNetOutputNode(ComputeGraph &compute_graph); | |||
static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); | |||
static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); | |||
static Status BuildInputMapping(GraphItem &graph_item, | |||
std::vector<NodeItem *> &data_nodes, | |||
bool is_root_graph); | |||
@@ -22,7 +22,6 @@ target_compile_options(atc_atc.bin PRIVATE | |||
target_compile_definitions(atc_atc.bin PRIVATE | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
COMPILE_OMG_PACKAGE | |||
google=ascend_private | |||
LOG_CPP | |||
FUNC_VISIBILITY | |||
@@ -48,6 +47,7 @@ target_include_directories(atc_atc.bin PRIVATE | |||
target_link_options(atc_atc.bin PRIVATE | |||
-Wl,-Bsymbolic | |||
-Wl,-rpath-link,${ASCEND_ATC_DIR}/stub | |||
) | |||
target_link_libraries(atc_atc.bin PRIVATE | |||
@@ -62,8 +62,7 @@ target_link_libraries(atc_atc.bin PRIVATE | |||
parser_common | |||
gflags | |||
json | |||
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime>> | |||
$<$<BOOL:${ENABLE_OPEN_SRC}>:$<BUILD_INTERFACE:runtime_compile>> | |||
runtime | |||
slog | |||
static_mmpa | |||
-lrt | |||
@@ -92,7 +91,6 @@ target_compile_options(fwk_atc.bin PRIVATE | |||
target_compile_definitions(fwk_atc.bin PRIVATE | |||
PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
COMPILE_OMG_PACKAGE | |||
google=ascend_private | |||
LOG_CPP | |||
FUNC_VISIBILITY | |||
@@ -193,6 +193,12 @@ rtError_t rtMemGetInfo(size_t *free, size_t *total) { | |||
return RT_ERROR_NONE; | |||
} | |||
// Test stub for the runtime memory query: ignores memInfoType and always reports
// 512 * 1024 * 1024 free out of 1024 * 1024 * 1024 total, returning RT_ERROR_NONE.
rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) {
  (void)memInfoType;
  const size_t kStubUnit = 1024UL * 1024UL;
  *free = 512UL * kStubUnit;
  *total = 1024UL * kStubUnit;
  return RT_ERROR_NONE;
}
rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag) { return RT_ERROR_NONE; } | |||
rtError_t rtMemFreeManaged(void *ptr) { return RT_ERROR_NONE; } | |||
@@ -692,6 +692,7 @@ set(MULTI_PARTS_TEST_FILES | |||
"graph/manager/run_graph_unittest.cc" | |||
"graph/partition/dynamic_shape_partition_unittest.cc" | |||
"graph/manager/graph_manager_unittest.cc" | |||
"graph/manager/graph_var_manager_unittest.cc" | |||
"graph/optimize/mem_rw_conflict_optimize_unittest.cc" | |||
"graph/optimize/graph_optimize_unittest.cc" | |||
"session/omg_omg_unittest.cc" | |||
@@ -79,13 +79,12 @@ public: | |||
additionaldata1->add_context_id(5); | |||
} | |||
void InitAicAivCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusAicAivCtxDef *ctxdef = fftsplusctxdef->mutable_aic_aiv_ctx(); | |||
void InitAicAivCtx(domi::FftsPlusAicAivCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_aten(1); | |||
ctxdef->set_pred_cnt_init(1); | |||
ctxdef->set_pred_cnt(1); | |||
for (int i = 0; i < RT_CTX_SUCCESSOR_NUM; ++i) { | |||
for (int i = 1; i < RT_CTX_SUCCESSOR_NUM; ++i) { | |||
ctxdef->add_successor_list(1); // 16 bits, len = 26 | |||
} | |||
ctxdef->set_stat(1); | |||
@@ -113,8 +112,7 @@ public: | |||
} | |||
} | |||
void InitMixAicAivCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusMixAicAivCtxDef *ctxdef = fftsplusctxdef->mutable_mix_aic_aiv_ctx(); | |||
void InitMixAicAivCtx(domi::FftsPlusMixAicAivCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_aten(1); | |||
ctxdef->set_pred_cnt_init(1); | |||
@@ -153,8 +151,7 @@ public: | |||
} | |||
} | |||
void InitSdmaCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusSdmaCtxDef *ctxdef = fftsplusctxdef->mutable_sdma_ctx(); | |||
void InitSdmaCtx(domi::FftsPlusSdmaCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_aten(1); | |||
ctxdef->set_pred_cnt_init(1); | |||
@@ -184,8 +181,7 @@ public: | |||
ctxdef->set_tail_data_len(1); | |||
} | |||
void InitNotifyCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusNotifyCtxDef *ctxdef = fftsplusctxdef->mutable_notify_ctx(); | |||
void InitNotifyCtx(domi::FftsPlusNotifyCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_aten(1); | |||
ctxdef->set_pred_cnt_init(1); | |||
@@ -201,8 +197,7 @@ public: | |||
ctxdef->set_notify_id_base(1); | |||
} | |||
void InitWriteValueCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusWriteValueCtxDef *ctxdef = fftsplusctxdef->mutable_write_value_ctx(); | |||
void InitWriteValueCtx(domi::FftsPlusWriteValueCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_aten(1); | |||
ctxdef->set_pred_cnt_init(1); | |||
@@ -227,8 +222,7 @@ public: | |||
ctxdef->add_write_value(1); | |||
} | |||
void InitAicpuCtxCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusAicpuCtxDef *ctxdef = fftsplusctxdef->mutable_aicpu_ctx(); | |||
void InitAicpuCtxCtx(domi::FftsPlusAicpuCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_aten(1); | |||
ctxdef->set_pred_cnt_init(1); | |||
@@ -260,8 +254,7 @@ public: | |||
ctxdef->set_task_param_offset(32); | |||
} | |||
void InitDataCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusDataCtxDef *ctxdef = fftsplusctxdef->mutable_data_ctx(); | |||
void InitDataCtx(domi::FftsPlusDataCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_aten(1); | |||
ctxdef->set_cnt_init(1); | |||
@@ -293,8 +286,7 @@ public: | |||
ctxdef->set_tail_stride_inner(1); | |||
} | |||
void InitAtStartCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusAtStartCtxDef *ctxdef = fftsplusctxdef->mutable_at_start_ctx(); | |||
void InitAtStartCtx(domi::FftsPlusAtStartCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_aten(1); | |||
ctxdef->set_pred_cnt_init(1); | |||
@@ -309,8 +301,7 @@ public: | |||
ctxdef->set_thread_window_size(1); | |||
} | |||
void InitAtEndCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusAtEndCtxDef *ctxdef = fftsplusctxdef->mutable_at_end_ctx(); | |||
void InitAtEndCtx(domi::FftsPlusAtEndCtxDef *ctxdef) { | |||
ctxdef->set_at_start_slot_num(12); | |||
ctxdef->set_out_label_slot_num(12); | |||
ctxdef->set_aten(1); | |||
@@ -325,8 +316,7 @@ public: | |||
ctxdef->set_thread_id(1); | |||
} | |||
void InitLabelCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusLabelCtxDef *ctxdef = fftsplusctxdef->mutable_label_ctx(); | |||
void InitLabelCtx(domi::FftsPlusLabelCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_pred_cnt_init(1); | |||
ctxdef->set_pred_cnt(1); | |||
@@ -335,8 +325,7 @@ public: | |||
} | |||
} | |||
void InitCaseSwitchCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusCaseSwitchCtxDef *ctxdef = fftsplusctxdef->mutable_case_switch_ctx(); | |||
void InitCaseSwitchCtx(domi::FftsPlusCaseSwitchCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_aten(32); | |||
ctxdef->set_start_label_id(32); | |||
@@ -366,8 +355,7 @@ public: | |||
ctxdef->set_load_addr1_offset(32); | |||
} | |||
void InitCaseDefaultCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusCaseDefaultCtxDef *ctxdef = fftsplusctxdef->mutable_case_default_ctx(); | |||
void InitCaseDefaultCtx(domi::FftsPlusCaseDefaultCtxDef *ctxdef) { | |||
ctxdef->set_successor_num(26); | |||
ctxdef->set_aten(32); | |||
ctxdef->set_start_label_id(1); | |||
@@ -379,8 +367,7 @@ public: | |||
} | |||
} | |||
void InitCondSwitchCtx(domi::FftsPlusCtxDef *fftsplusctxdef) { | |||
domi::FftsPlusCondSwitchCtxDef *ctxdef = fftsplusctxdef->mutable_cond_switch_ctx(); | |||
void InitCondSwitchCtx(domi::FftsPlusCondSwitchCtxDef *ctxdef) { | |||
ctxdef->set_true_successor_num(12); | |||
ctxdef->set_false_successor_num(14); | |||
ctxdef->set_aten(32); | |||
@@ -444,35 +431,38 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_software_ctx) { | |||
InitTaskSQEInfo(ffts_plus_task_def); | |||
InitTaskAdditionalDataInfo(ffts_plus_task_def); | |||
domi::FftsPlusCtxDef *startctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
startctx->set_op_index(0); | |||
startctx->set_hardware_ctx_type(0); | |||
startctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_START)); | |||
InitAtStartCtx(startctx); | |||
domi::FftsPlusCtxDef *fftsplusstartctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
fftsplusstartctx->set_op_index(0); | |||
fftsplusstartctx->set_hardware_ctx_type(0); | |||
fftsplusstartctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_START)); | |||
domi::FftsPlusAtStartCtxDef *startctxdef = fftsplusstartctx->mutable_at_start_ctx(); | |||
InitAtStartCtx(startctxdef); | |||
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | |||
startctx->at_start_ctx().add_successor_list(1); | |||
startctxdef->add_successor_list(1); | |||
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); | |||
domi::FftsPlusCtxDef *endctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
endctx->set_op_index(0); | |||
endctx->set_hardware_ctx_type(0); | |||
endctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_END)); | |||
InitAtEndCtx(endctx); | |||
domi::FftsPlusCtxDef *fftsplusendctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
fftsplusendctx->set_op_index(0); | |||
fftsplusendctx->set_hardware_ctx_type(0); | |||
fftsplusendctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_END)); | |||
domi::FftsPlusAtEndCtxDef *endctxdef = fftsplusendctx->mutable_at_end_ctx(); | |||
InitAtEndCtx(endctxdef); | |||
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | |||
endctx->at_end_ctx().add_succ_at_start_slot(1); | |||
endctxdef->add_succ_at_start_slot(1); | |||
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | |||
endctx->at_end_ctx().add_succ_out_label_slot(1); | |||
endctxdef->add_succ_out_label_slot(1); | |||
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); | |||
domi::FftsPlusCtxDef *labelctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
labelctx->set_op_index(0); | |||
labelctx->set_hardware_ctx_type(0); | |||
labelctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_LABEL)); | |||
InitLabelCtx(labelctx); | |||
// Append a context entry typed as a software LABEL context. | |||
domi::FftsPlusCtxDef *fftspluslabelctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
fftspluslabelctx->set_op_index(0); | |||
fftspluslabelctx->set_hardware_ctx_type(0); | |||
fftspluslabelctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_LABEL)); | |||
// The label sub-message must come from the entry created just above, so that | |||
// the fields set by InitLabelCtx belong to the context whose type was just set. | |||
domi::FftsPlusLabelCtxDef *labelctxdef = fftspluslabelctx->mutable_label_ctx(); | |||
InitLabelCtx(labelctxdef); | |||
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED); | |||
labelctx->label_ctx().add_successor_list(1); | |||
labelctxdef->add_successor_list(1); | |||
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS); | |||
} | |||
@@ -501,102 +491,111 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_hardware_ctx) { | |||
aicaivctx->set_op_index(0); | |||
aicaivctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_AIV)); | |||
aicaivctx->set_software_ctx_type(0); | |||
InitAicAivCtx(aicaivctx); | |||
domi::FftsPlusAicAivCtxDef *aicaivdef = aicaivctx->mutable_aic_aiv_ctx(); | |||
InitAicAivCtx(aicaivdef); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
aicaivctx->aic_aiv_ctx().add_successor_list(1); | |||
aicaivdef->add_successor_list(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
aicaivctx->aic_aiv_ctx().add_kernel_name("aivtest"); | |||
aicaivdef->add_kernel_name("aivtest"); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
aicaivctx->aic_aiv_ctx().add_src_slot(1); | |||
aicaivdef->add_src_slot(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | |||
domi::FftsPlusCtxDef *mixaicaivctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
mixaicaivctx->set_op_index(0); | |||
mixaicaivctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_MIX_AIC)); | |||
mixaicaivctx->set_software_ctx_type(0); | |||
InitMixAicAivCtx(mixaicaivctx); | |||
domi::FftsPlusMixAicAivCtxDef *mixctxdef = mixaicaivctx->mutable_mix_aic_aiv_ctx(); | |||
InitMixAicAivCtx(mixctxdef); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
mixaicaivctx->mix_aic_aiv_ctx().add_successor_list(1); | |||
mixctxdef->add_successor_list(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
mixaicaivctx->mix_aic_aiv_ctx().add_kernel_name("mixaiv"); | |||
mixctxdef->add_kernel_name("mixaiv"); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
mixaicaivctx->mix_aic_aiv_ctx().add_src_slot(1); | |||
mixctxdef->add_src_slot(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | |||
domi::FftsPlusCtxDef *notifyctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
notifyctx->set_op_index(0); | |||
notifyctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_NOTIFY_WAIT)); | |||
notifyctx->set_software_ctx_type(0); | |||
InitNotifyCtx(notifyctx); | |||
domi::FftsPlusNotifyCtxDef *notifydef = notifyctx->mutable_notify_ctx(); | |||
InitNotifyCtx(notifydef); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
notifyctx->notify_ctx().add_successor_list(1); | |||
notifydef->add_successor_list(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | |||
domi::FftsPlusCtxDef *sdmactx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
sdmactx->set_op_index(0); | |||
sdmactx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_SDMA)); | |||
sdmactx->set_software_ctx_type(0); | |||
InitSdmaCtx(sdmactx); | |||
domi::FftsPlusSdmaCtxDef *smdadef = sdmactx->mutable_sdma_ctx(); | |||
InitSdmaCtx(smdadef); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
sdmactx->sdma_ctx().add_successor_list(1); | |||
smdadef->add_successor_list(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | |||
domi::FftsPlusCtxDef *writevalctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
writevalctx->set_op_index(0); | |||
writevalctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_WRITE_VALUE)); | |||
writevalctx->set_software_ctx_type(0); | |||
InitWriteValueCtx(writevalctx); | |||
domi::FftsPlusWriteValueCtxDef *writedef = writevalctx->mutable_write_value_ctx(); | |||
InitWriteValueCtx(writedef); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
writevalctx->write_value_ctx().add_successor_list(1); | |||
writedef->add_successor_list(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | |||
domi::FftsPlusCtxDef *aicpuctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
aicpuctx->set_op_index(0); | |||
aicpuctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_AICPU)); | |||
aicpuctx->set_software_ctx_type(0); | |||
InitAicpuCtxCtx(aicpuctx); | |||
domi::FftsPlusAicpuCtxDef *aicpudef = aicpuctx->mutable_aicpu_ctx(); | |||
InitAicpuCtxCtx(aicpudef); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
aicpuctx->aicpu_ctx().add_successor_context_id(1); | |||
aicpudef->add_successor_context_id(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
aicpuctx->aicpu_ctx().add_user_data(1); | |||
aicpudef->add_user_data(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | |||
domi::FftsPlusCtxDef *datactx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
datactx->set_op_index(0); | |||
datactx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_FLUSH_DATA)); | |||
datactx->set_software_ctx_type(0); | |||
InitDataCtx(datactx); | |||
domi::FftsPlusDataCtxDef *datadef = datactx->mutable_data_ctx(); | |||
InitDataCtx(datadef); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
datactx->data_ctx().add_successor_list(1); | |||
datadef->add_successor_list(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | |||
domi::FftsPlusCtxDef *caseswitchctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
caseswitchctx->set_op_index(0); | |||
caseswitchctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD)); | |||
caseswitchctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_CASE_SWITCH)); | |||
InitCaseSwitchCtx(caseswitchctx); | |||
domi::FftsPlusCaseSwitchCtxDef *caseswitchdef = caseswitchctx->mutable_case_switch_ctx(); | |||
InitCaseSwitchCtx(caseswitchdef); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
caseswitchctx->case_switch_ctx().add_successor_list(1); | |||
caseswitchdef->add_successor_list(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | |||
domi::FftsPlusCtxDef *candswitchctx = ffts_plus_task_def->add_ffts_plus_ctx(); | |||
candswitchctx->set_op_index(0); | |||
candswitchctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD)); | |||
candswitchctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_COND_SWITCH)); | |||
InitCondSwitchCtx(candswitchctx); | |||
domi::FftsPlusCondSwitchCtxDef *candswitchdef = candswitchctx->mutable_cond_switch_ctx(); | |||
InitCondSwitchCtx(candswitchdef); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
candswitchctx->cond_switch_ctx().add_true_successor_list(1); | |||
candswitchdef->add_true_successor_list(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
candswitchctx->cond_switch_ctx().add_false_successor_list(1); | |||
candswitchdef->add_false_successor_list(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | |||
} | |||
@@ -625,10 +624,11 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_hardware_ctx_ex) { | |||
casesdefaultctx->set_op_index(0); | |||
casesdefaultctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD)); | |||
casesdefaultctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_CASE_SWITCH)); | |||
InitCaseDefaultCtx(casesdefaultctx); | |||
domi::FftsPlusCaseDefaultCtxDef *casesdefaultdef = casesdefaultctx->mutable_case_default_ctx(); | |||
InitCaseDefaultCtx(casesdefaultdef); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED); | |||
casesdefaultctx->case_default_ctx().add_successor_list(1); | |||
casesdefaultdef->add_successor_list(1); | |||
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS); | |||
} | |||
// test FftsPlusTaskInfo UpdateArgs | |||
@@ -0,0 +1,63 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include <memory> | |||
#define protected public | |||
#define private public | |||
#include "graph/manager/graph_var_manager.h" | |||
#include "graph/ge_context.h" | |||
#undef protected | |||
#undef private | |||
namespace ge { | |||
// Fixture for the VarManager memory-size tests below. It declares no members | |||
// and performs no work in either hook. | |||
class UtestGraphVarManagerTest : public testing::Test { | |||
 protected: | |||
  // Intentionally empty; `override` makes the compiler verify the signature | |||
  // against testing::Test. | |||
  void SetUp() override {} | |||
  // Intentionally empty; see SetUp. | |||
  void TearDown() override {} | |||
}; | |||
// Calls GetTotalMemorySize on VarManager::Instance(0) and expects the output | |||
// argument to equal 1024 * 1024 * 1024 and the returned status to be SUCCESS. | |||
TEST_F(UtestGraphVarManagerTest, test_get_total_memory_size) { | |||
  const auto expected_total = 1024UL * 1024UL * 1024UL; | |||
  size_t reported_total = 0; | |||
  const Status status = VarManager::Instance(0)->GetTotalMemorySize(reported_total); | |||
  EXPECT_EQ(reported_total, expected_total); | |||
  EXPECT_EQ(status, SUCCESS); | |||
} | |||
// Passes an empty option map to SetMemoryMallocSize and expects SUCCESS, with | |||
// graph_mem_max_size_ equal to 26/32 and var_mem_max_size_ equal to 5/32 of | |||
// 1024 * 1024 * 1024. | |||
TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_no_related_option) { | |||
  const auto expected_graph_max = floor(1024UL * 1024UL * 1024UL * (26.0f / 32.0f)); | |||
  const auto expected_var_max = floor(1024UL * 1024UL * 1024UL * (5.0f / 32.0f)); | |||
  const map<string, string> empty_options{}; | |||
  const Status status = VarManager::Instance(0)->SetMemoryMallocSize(empty_options); | |||
  EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, expected_graph_max); | |||
  EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, expected_var_max); | |||
  EXPECT_EQ(status, SUCCESS); | |||
} | |||
// Supplies only "ge.graphMemoryMaxSize" = "536870912" and expects SUCCESS, with | |||
// graph_mem_max_size_ equal to half of 1024 * 1024 * 1024 and | |||
// var_mem_max_size_ still equal to 5/32 of it. | |||
TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_graph_mem_max_size) { | |||
  const auto expected_graph_max = floor(1024UL * 1024UL * 1024UL / 2); | |||
  const auto expected_var_max = floor(1024UL * 1024UL * 1024UL * (5.0f / 32.0f)); | |||
  const map<string, string> graph_limit_options{{"ge.graphMemoryMaxSize", "536870912"}}; | |||
  const Status status = VarManager::Instance(0)->SetMemoryMallocSize(graph_limit_options); | |||
  EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, expected_graph_max); | |||
  EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, expected_var_max); | |||
  EXPECT_EQ(status, SUCCESS); | |||
} | |||
// Supplies only "ge.variableMemoryMaxSize" = "536870912" and expects SUCCESS, | |||
// with var_mem_max_size_ equal to half of 1024 * 1024 * 1024 and | |||
// graph_mem_max_size_ still equal to 26/32 of it. | |||
TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_var_mem_max_size) { | |||
  const auto expected_graph_max = floor(1024UL * 1024UL * 1024UL * (26.0f / 32.0f)); | |||
  const auto expected_var_max = floor(1024UL * 1024UL * 1024UL / 2); | |||
  const map<string, string> var_limit_options{{"ge.variableMemoryMaxSize", "536870912"}}; | |||
  const Status status = VarManager::Instance(0)->SetMemoryMallocSize(var_limit_options); | |||
  EXPECT_EQ(VarManager::Instance(0)->graph_mem_max_size_, expected_graph_max); | |||
  EXPECT_EQ(VarManager::Instance(0)->var_mem_max_size_, expected_var_max); | |||
  EXPECT_EQ(status, SUCCESS); | |||
} | |||
} // namespace ge |