Browse Source

Merge branch '9e3c73ed'

pull/1925/head
liudingyan 3 years ago
parent
commit
ff970adb69
13 changed files with 226 additions and 330 deletions
  1. +2
    -6
      CMakeLists.txt
  2. +10
    -5
      cmake/external_libs/gflags.cmake
  3. +5
    -3
      ge/CMakeLists.txt
  4. +7
    -7
      ge/graph/load/model_manager/task_info/ffts_plus_task_info.cc
  5. +41
    -11
      ge/graph/manager/graph_var_manager.cc
  6. +3
    -0
      ge/graph/manager/graph_var_manager.h
  7. +14
    -217
      ge/hybrid/model/hybrid_model_builder.cc
  8. +0
    -5
      ge/hybrid/model/hybrid_model_builder.h
  9. +2
    -4
      ge/offline/CMakeLists.txt
  10. +6
    -0
      tests/depends/runtime/src/runtime_stub.cc
  11. +1
    -0
      tests/ut/ge/CMakeLists.txt
  12. +72
    -72
      tests/ut/ge/graph/load/ffts_plus_task_info_unittest.cc
  13. +63
    -0
      tests/ut/ge/graph/manager/graph_var_manager_unittest.cc

+ 2
- 6
CMakeLists.txt View File

@@ -88,11 +88,9 @@ else ()
find_module(hccl libhccl.so ${GE_LIB_PATH})
find_module(adump_server libadump_server.a ${GE_LIB_PATH})
find_module(runtime libruntime.so ${GE_LIB_PATH})
find_module(runtime_compile libruntime_compile.so ${GE_LIB_PATH})
find_module(resource libresource.so ${GE_LIB_PATH})
find_module(ascend_hal_stub libascend_hal.so ${GE_LIB_PATH})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${GE_LIB_PATH})
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH})
else()
find_module(slog libalog.so ${ASCEND_ATC_DIR})
find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR})
@@ -108,7 +106,6 @@ else ()
elseif(PLATFORM STREQUAL "inference")
find_module(adump_server libadump_server.a ${ASCEND_ACL_DIR})
find_module(runtime libruntime.so ${ASCEND_ACL_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
if(PRODUCT STREQUAL "flr3")
elseif(PRODUCT STREQUAL "flr1")
@@ -120,10 +117,9 @@ else ()
endif()
elseif(PLATFORM STREQUAL "all")
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR})
find_module(runtime libruntime.so ${ASCEND_RUNTIME_DIR})
find_module(runtime libruntime.so ${ASCEND_ATC_DIR})
find_module(msprofiler_fwk_ext libmsprofiler_fwk.a ${ASCEND_RUNTIME_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR})
find_module(runtime_compile libruntime_compile.so ${ASCEND_ATC_DIR})
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_ATC_DIR}/stub)
find_module(msprofiler_ext libmsprofiler.a ${ASCEND_ACL_DIR})
else()
message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!")


+ 10
- 5
cmake/external_libs/gflags.cmake View File

@@ -10,12 +10,17 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.")
endif()

if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz")
set(MD5 "")
if (GE_PB_PKG)
set(REQ_URL "${GE_PB_PKG}/libs/gflags/v2.2.2.tar.gz")
set(MD5 "1a865b93bacfa963201af3f75b7bd64c")
else()
set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz")
set(MD5 "")
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz")
set(MD5 "")
else()
set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz")
set(MD5 "1a865b93bacfa963201af3f75b7bd64c")
endif ()
endif ()

set (gflags_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -Dgoogle=ascend_private")


+ 5
- 3
ge/CMakeLists.txt View File

@@ -112,6 +112,8 @@ set(EXECUTOR_SRC_LIST
"common/dump/dump_op.cc"
"common/dump/exception_dumper.cc"
"common/dump/opdebug_register.cc"
"common/ge/op_tiling_manager.cc"
"common/ge/plugin_manager.cc"
"common/profiling/ge_profiling.cc"
"common/profiling/profiling_manager.cc"
"executor/ge_executor.cc"
@@ -259,6 +261,8 @@ set(EXECUTOR_SRC_LIST
set(COMPILER_SRC_LIST
"analyzer/analyzer.cc"
"common/dump/dump_op.cc"
"common/ge/op_tiling_manager.cc"
"common/ge/plugin_manager.cc"
"common/helper/model_cache_helper.cc"
"common/profiling/profiling_manager.cc"
"engine_manager/dnnengine_manager.cc"
@@ -619,7 +623,6 @@ target_compile_definitions(ge_compiler PRIVATE
REUSE_MEMORY=1
FMK_SUPPORT_DUMP
FMK_HOST_INFER
COMPILE_OMG_PACKAGE
google=ascend_private
FUNC_VISIBILITY
$<$<STREQUAL:${ENABLE_OPEN_SRC},True>:ONLY_COMPILE_OPEN_SRC>
@@ -681,8 +684,7 @@ target_link_libraries(ge_compiler PRIVATE
c_sec
error_manager
slog
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime>>
$<$<BOOL:${ENABLE_OPEN_SRC}>:$<BUILD_INTERFACE:runtime_compile>>
runtime
opt_feature
-Wl,--as-needed
json


+ 7
- 7
ge/graph/load/model_manager/task_info/ffts_plus_task_info.cc View File

@@ -350,7 +350,7 @@ Status FftsPlusTaskInfo::InitAicAivCtx(const domi::FftsPlusAicAivCtxDef &ctx_def
i_cache_prefetch_cnt_2));
ctx->tailTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_task_start_pc) & 0XFFFFFFFF);
ctx->tailTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_task_start_pc) >> 32) & 0X0000FFFF);
uint32_t i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2);
uint32_t i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2);
ctx->icachePrefetchCnt = static_cast<uint16_t>(i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111

if (ctx_def.src_slot_size() != kSrcSlotNum) {
@@ -526,8 +526,7 @@ Status FftsPlusTaskInfo::InitMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &c
ctx->tailAicTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_aic_task_start_pc) & 0XFFFFFFFF);
ctx->tailAicTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_aic_task_start_pc) >> 32) &
0X0000FFFF);
uint32_t aic_i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2);
// TODO
uint32_t aic_i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_1, i_cache_prefetch_cnt_2);
ctx->icachePrefetchCnt = static_cast<uint16_t>(aic_i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111

uint32_t i_cache_prefetch_cnt_3;
@@ -545,9 +544,10 @@ Status FftsPlusTaskInfo::InitMixAicAivCtx(const domi::FftsPlusMixAicAivCtxDef &c
ctx->tailAivTaskStartPcL = static_cast<uint32_t>(reinterpret_cast<uintptr_t>(tail_aiv_task_start_pc) & 0XFFFFFFFF);
ctx->tailAivTaskStartPcH = static_cast<uint16_t>((reinterpret_cast<uintptr_t>(tail_aiv_task_start_pc) >> 32) &
0X0000FFFF);
uint32_t aiv_i_cache_prefetch_cnt = std::max(i_cache_prefetch_cnt_3, i_cache_prefetch_cnt_4);
uint32_t aiv_i_cache_prefetch_cnt = std::min(i_cache_prefetch_cnt_3, i_cache_prefetch_cnt_4);
// TODO
ctx->icachePrefetchCnt = static_cast<uint16_t>(aiv_i_cache_prefetch_cnt & 0X0000001F); // 5 bits, 0001,1111
ctx->icachePrefetchCnt = static_cast<uint16_t>(
std::min(aic_i_cache_prefetch_cnt, aiv_i_cache_prefetch_cnt) & 0X0000001F); // 5 bits, 0001,1111

if (ctx_def.src_slot_size() != kSrcSlotNum) {
REPORT_INNER_ERROR("E19999", "Size of src_slot in FftsPlusMixAicAivCtxDef should be %d, but %d exactly",
@@ -913,11 +913,11 @@ void FftsPlusTaskInfo::SetAdditionalDatatoCtx(const domi::FftsPlusTaskDef &task_

Status FftsPlusTaskInfo::UpdateMixAicAivCtxParam(const domi::FftsPlusMixAicAivCtxDef &ctx_def, size_t ctx_idx) {
if (ctx_additional_data_.count(ctx_idx) == 0) {
GELOGD("ctx idx:%d not in ctx additional data");
GELOGD("ctx idx:%zu not in ctx additional data");
return SUCCESS;
}
if (ctx_additional_data_[ctx_idx].count(kModeInArgsFirstField) == 0) {
GELOGD("ctx idx:%d need not to save mode in args first field");
GELOGD("ctx idx:%zu need not to save mode in args first field");
return SUCCESS;
}
if (rtApp_addr_ == 0) {


+ 41
- 11
ge/graph/manager/graph_var_manager.cc View File

@@ -20,6 +20,7 @@
#include "graph/manager/graph_mem_manager.h"
#include "graph/manager/trans_var_data_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/ge_context.h"

using std::map;
using std::string;
@@ -767,25 +768,52 @@ Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &grap
return var_resource_->GetChangedGraphId(var_name, graph_id);
}

/// Queries the total HBM size of the device selected in the GE context.
///
/// The device is bound with rtSetDevice for the duration of the query and is
/// released again with rtDeviceReset, including when the query itself fails.
///
/// @param total_mem_size [out] total size reported by rtMemGetInfoEx for
///                       RT_MEMORYINFO_HBM; only meaningful when SUCCESS is returned.
/// @return SUCCESS when every runtime call succeeds, RT_FAILED otherwise.
Status VarManager::GetTotalMemorySize(size_t &total_mem_size) {
  // Read the device id once so that set/reset always act on the same device.
  const auto device_id = GetContext().DeviceId();

  rtError_t rt_ret = rtSetDevice(device_id);
  if (rt_ret != RT_ERROR_NONE) {
    REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret);
    GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret);
    return RT_FAILED;
  }

  // The free size is required by the runtime API but is not used by the caller.
  size_t free_mem = 0;
  rt_ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free_mem, &total_mem_size);
  if (rt_ret != RT_ERROR_NONE) {
    REPORT_CALL_ERROR("E19999", "Call rtMemGetInfoEx failed, ret:0x%X", rt_ret);
    GELOGE(RT_FAILED, "[Call][RtMemGetInfoEx] failed, ret:0x%X", rt_ret);
    // Best-effort release of the device bound above; the query failure is the
    // error reported to the caller, so the reset result is intentionally ignored.
    (void)rtDeviceReset(device_id);
    return RT_FAILED;
  }

  rt_ret = rtDeviceReset(device_id);
  if (rt_ret != RT_ERROR_NONE) {
    REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", device_id, rt_ret);
    GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", device_id, rt_ret);
    return RT_FAILED;
  }
  return SUCCESS;
}

Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
auto it = options.find(GRAPH_MEMORY_MAX_SIZE);
if (it == options.end()) {
graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize;
} else {
string graph_memory_manager_malloc_max_size = it->second;
size_t total_mem_size = 0;
GE_CHK_STATUS_RET_NOLOG(VarManager::GetTotalMemorySize(total_mem_size));
GEEVENT("Total memory size is %zu", total_mem_size);

graph_mem_max_size_ = floor(total_mem_size * kGraphMemoryManagerMallocRatio);
var_mem_max_size_ = floor(total_mem_size * kVarMemoryManagerMallocRatio);

auto it1 = options.find(GRAPH_MEMORY_MAX_SIZE);
if (it1 != options.end()) {
string graph_memory_manager_malloc_max_size = it1->second;
ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_);
if (ret != SUCCESS) {
GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_);
return ge::GE_GRAPH_OPTIONS_INVALID;
}
GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_);
}

it = options.find(VARIABLE_MEMORY_MAX_SIZE);
if (it == options.end()) {
var_mem_max_size_ = kMemoryVarManagerMallocSize;
} else {
string memory_var_manager_malloc_size = it->second;
auto it2 = options.find(VARIABLE_MEMORY_MAX_SIZE);
if (it2 != options.end()) {
string memory_var_manager_malloc_size = it2->second;
ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_);
if (ret != SUCCESS) {
GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_);
@@ -793,6 +821,8 @@ Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
}
}

GEEVENT("The graph_mem_max_size is %zu and the var_mem_max_size is %zu", graph_mem_max_size_, var_mem_max_size_);

var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer;
if (var_mem_logic_base_ > kMaxMemorySize) {
REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid",


+ 3
- 0
ge/graph/manager/graph_var_manager.h View File

@@ -43,6 +43,8 @@ const size_t kMaxMemorySize = 256UL * 1024UL * 1024UL * 1024UL;
const char kEnvGeuseStaticMemory[] = "GE_USE_STATIC_MEMORY";
const uint64_t kSessionMemAlignSize = 512;
const size_t kSessionMemAlignUnit = 2;
// Share (26/32) of the total device memory used as the default graph memory limit;
// VarManager::SetMemoryMallocSize multiplies the queried total size by this ratio.
const double kGraphMemoryManagerMallocRatio = 26.0 / 32.0;
// Share (5/32) of the total device memory used as the default variable memory limit;
// applied the same way as the graph ratio above.
const double kVarMemoryManagerMallocRatio = 5.0 / 32.0;

enum MemStatus {
NORMAL = 0,
@@ -301,6 +303,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {
mutable std::recursive_mutex mutex_;

Status ParseMemoryMallocSize(std::string &memory_size, size_t &my_size);
Status GetTotalMemorySize(size_t &total_mem_size);
};

class VarManagerPool {


+ 14
- 217
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -60,7 +60,6 @@ const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE";
const char *const kForceInfershape = "_force_infershape_when_running";

const std::set<std::string> kExecutionDependentTypes{ IF, STATELESSIF, CASE, STREAMSWITCH };
const std::set<std::string> kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP };
const std::set<std::string> kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION };

Status SetOutputNameAttr(ComputeGraph &graph) {
@@ -519,170 +518,6 @@ Status HybridModelBuilder::UpdateAnchorStatus(const NodePtr &node) {
return SUCCESS;
}

// Removes the data edge from out_data_anchor to in_data_anchor.
// The Unlink result is checked through GE_CHK_GRAPH_STATUS_RET (presumably an early
// return with an error status on failure -- confirm against the macro definition).
// On success the removed edge is logged at debug level and SUCCESS is returned.
// NOTE(review): both anchors and their owner nodes are dereferenced without null
// checks here; callers are expected to have validated them.
Status HybridModelBuilder::DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor,
const InDataAnchorPtr &in_data_anchor) {
GE_CHK_GRAPH_STATUS_RET(out_data_anchor->Unlink(in_data_anchor),
"[Invoke][Unlink] failed to unlink %s:%d from %s:%d",
out_data_anchor->GetOwnerNode()->GetName().c_str(), out_data_anchor->GetIdx(),
in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetIdx());

GELOGD("Succeeded in unlinking %s:%d from %s:%d",
out_data_anchor->GetOwnerNode()->GetName().c_str(),
out_data_anchor->GetIdx(),
in_data_anchor->GetOwnerNode()->GetName().c_str(),
in_data_anchor->GetIdx());
return SUCCESS;
}

// Adds a data edge from out_data_anchor to in_data_anchor.
// The LinkTo result is checked through GE_CHK_GRAPH_STATUS_RET (presumably an early
// return with an error status on failure -- confirm against the macro definition).
// On success the new edge is logged at debug level and SUCCESS is returned.
// NOTE(review): both anchors and their owner nodes are dereferenced without null
// checks here; callers are expected to have validated them.
Status HybridModelBuilder::DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor) {
GE_CHK_GRAPH_STATUS_RET(out_data_anchor->LinkTo(in_data_anchor), "[Invoke][LinkTo]Failed to link %s:%d to %s:%d",
out_data_anchor->GetOwnerNode()->GetName().c_str(),
out_data_anchor->GetIdx(),
in_data_anchor->GetOwnerNode()->GetName().c_str(),
in_data_anchor->GetIdx());

GELOGD("Succeeded in linking %s:%d to %s:%d",
out_data_anchor->GetOwnerNode()->GetName().c_str(),
out_data_anchor->GetIdx(),
in_data_anchor->GetOwnerNode()->GetName().c_str(),
in_data_anchor->GetIdx());
return SUCCESS;
}

// Rewires the inputs of a subgraph so that consumers of its Data nodes are fed
// directly by the producers connected to the wrapping (parent) node.
//
// Steps performed on `graph`:
//   1. Every non-Data node without data inputs is recorded as a root node.
//   2. For every Data node, ATTR_NAME_PARENT_NODE_INDEX selects the parent node's
//      input anchor; the producer on that anchor replaces the Data node as the
//      source of each consumer edge. Consumers fed only by Data nodes are also
//      recorded as root nodes.
//   3. Control inputs of the parent node are re-attached to the root nodes, then
//      removed from the parent node.
//
// Returns FAILED when a Data node lacks ATTR_NAME_PARENT_NODE_INDEX, otherwise
// SUCCESS unless one of the checking macros returns early.
// NOTE(review): wrapped_node (graph.GetParentNode()) and the control anchors are
// dereferenced without null checks in this function; UnfoldSubgraph checks the
// parent node before calling it.
Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) {
const auto &wrapped_node = graph.GetParentNode();
// Nodes that will receive the parent node's control inputs.
std::set<NodePtr> root_nodes;
for (const auto &node : graph.GetDirectNode()) {
GE_CHECK_NOTNULL(node);
if (node->GetType() != DATA_TYPE) {
// A non-Data node with no data inputs starts execution inside the subgraph.
if (node->GetInDataNodes().empty()) {
root_nodes.emplace(node);
}

continue;
}

auto data_op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(data_op_desc);

// Index of the parent node input that this Data node stands for.
uint32_t parent_index = 0;
if (!AttrUtils::GetInt(data_op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
GELOGE(FAILED, "[Invoke][GetInt] failed, node:[%s] attr:[%s]",
data_op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
REPORT_CALL_ERROR("E19999", "GetInt failed, node:[%s] attr:[%s]",
data_op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
return FAILED;
}

auto wrapped_node_in_anchor = wrapped_node->GetInDataAnchor(parent_index);
GE_CHECK_NOTNULL(wrapped_node_in_anchor);
auto src_out_anchor = wrapped_node_in_anchor->GetPeerOutAnchor();
// Parent input is not connected to any producer: nothing to rewire for this Data node.
if (src_out_anchor == nullptr || src_out_anchor->GetOwnerNode() == nullptr) {
continue;
}
wrapped_node_in_anchor->UnlinkAll();

// link src to outputs of DataNode
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
GE_CHECK_NOTNULL(out_data_anchor);
for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
auto dst_node = peer_in_data_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);
const auto in_nodes = dst_node->GetInDataNodes();
// A consumer whose data inputs are all Data nodes is treated as a root node.
if (std::all_of(in_nodes.begin(), in_nodes.end(), [](const NodePtr &n) { return n->GetType() == DATA; })) {
root_nodes.emplace(dst_node);
}
// Replace the edge Data -> consumer with producer -> consumer.
GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor));
GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor));
}
}
}

// transfer in control edges to all root nodes
for (auto &root_node : root_nodes) {
auto in_nodes = root_node->GetInAllNodes();
std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end());
for (auto &in_control_node : wrapped_node->GetInControlNodes()) {
// Skip nodes already feeding the root node, and root nodes whose type is in kMergeInputSkipTypes.
if (in_node_set.count(in_control_node) == 0 && kMergeInputSkipTypes.count(root_node->GetType()) == 0) {
GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str());
GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor());
// The link result is deliberately discarded.
(void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor());
}
}
}

// The parent node's control inputs have been handed over to the root nodes.
wrapped_node->GetInControlAnchor()->UnlinkAll();
return SUCCESS;
}

// Rewires the outputs of a subgraph so that consumers of the parent node are fed
// directly by the producers connected to the subgraph's NetOutput node.
//
// Steps performed on `graph`:
//   1. Locate the first NETOUTPUT node; when there is none, return SUCCESS.
//   2. Snapshot the NetOutput inputs and the parent node outputs, then drop the
//      NetOutput control inputs and the parent node control outputs.
//   3. For every NetOutput data input, detach its producer and, using
//      ATTR_NAME_PARENT_NODE_INDEX on the input desc, connect that producer to
//      every consumer of the matching parent output.
//   4. Add control edges from every former NetOutput input node to every former
//      parent output node that it does not already feed.
//
// Returns INTERNAL_ERROR when an input desc is missing, otherwise SUCCESS unless
// one of the checking macros returns early.
// NOTE(review): parent_node and the control anchors are dereferenced without null
// checks in this function; UnfoldSubgraph checks the parent node before calling it.
Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) {
const auto &parent_node = graph.GetParentNode();
const NodePtr &net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT);
if (net_output_node == nullptr) {
GELOGD("Graph has no netoutput no need to merge");
return SUCCESS;
}
const auto &net_output_desc = net_output_node->GetOpDesc();
GE_CHECK_NOTNULL(net_output_desc);

// Capture the neighbours before any edge is removed; they are used for the control edges below.
auto all_in_nodes = net_output_node->GetInAllNodes();
auto all_out_nodes = parent_node->GetOutAllNodes();
net_output_node->GetInControlAnchor()->UnlinkAll();
parent_node->GetOutControlAnchor()->UnlinkAll();

for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) {
auto src_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(src_out_anchor);
GE_CHECK_NOTNULL(src_out_anchor->GetOwnerNode());
GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(src_out_anchor, in_data_anchor));

auto index = in_data_anchor->GetIdx();
auto input_desc = net_output_desc->MutableInputDesc(index);
if (input_desc == nullptr) {
GELOGE(INTERNAL_ERROR, "[Invoke][MutableInputDesc][%s] Failed to get input desc[%d]",
net_output_desc->GetName().c_str(), index);
REPORT_CALL_ERROR("E19999", "[%s] Failed to get input desc[%d].", net_output_desc->GetName().c_str(), index);
return INTERNAL_ERROR;
}

// Index of the parent node output that this NetOutput input stands for.
uint32_t parent_index = 0;
if (!AttrUtils::GetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
// Inputs without the attribute are only warned about and skipped (already unlinked above).
GELOGW("SubGraph: %s NetOutput input tensor %d, attr %s not found.",
graph.GetName().c_str(), index, ATTR_NAME_PARENT_NODE_INDEX.c_str());
continue;
}

const OutDataAnchorPtr &parent_out_anchor = parent_node->GetOutDataAnchor(parent_index);
GE_CHECK_NOTNULL(parent_out_anchor);
for (InDataAnchorPtr &dst_in_anchor : parent_out_anchor->GetPeerInDataAnchors()) {
if (dst_in_anchor == nullptr) {
continue;
}

GE_CHECK_NOTNULL(dst_in_anchor->GetOwnerNode());
// Replace the edge parent -> consumer with producer -> consumer.
GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(parent_out_anchor, dst_in_anchor));
GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, dst_in_anchor));
}
}

// transfer out control edges
std::set<NodePtr> in_node_set(all_in_nodes.begin(), all_in_nodes.end());
std::set<NodePtr> out_node_set(all_out_nodes.begin(), all_out_nodes.end());
for (auto &src_node : in_node_set) {
GELOGD("[%s] process in node.", src_node->GetName().c_str());
auto out_nodes = src_node->GetOutAllNodes();
std::set<NodePtr> node_set(out_nodes.begin(), out_nodes.end());
for (auto &dst_node : out_node_set) {
// Only add a control edge when src_node does not already feed dst_node.
if (node_set.count(dst_node) == 0) {
// NOTE(review): the LinkTo result is not checked here.
src_node->GetOutControlAnchor()->LinkTo(dst_node->GetInControlAnchor());
GELOGD("[%s] Restore control edge to [%s]", src_node->GetName().c_str(), dst_node->GetName().c_str());
}
}
}

return SUCCESS;
}

Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) {
merged_graph = MakeShared<ComputeGraph>("MergedGraph");
merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag());
@@ -716,9 +551,21 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeG
}
}
}
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph),

const auto &filter = [](const ComputeGraphPtr &graph) {
const auto &parent_node = graph->GetParentNode();
if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) {
return false;
}
if ((parent_node->GetType() != PARTITIONEDCALL) ||
(parent_node->GetOpDesc()->GetSubgraphInstanceNames().size() != 1)) {
return false;
}
return graph->GetGraphUnknownFlag();
};
GE_CHK_GRAPH_STATUS_RET(GraphUtils::UnfoldSubgraph(subgraph, filter),
"[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph.",
subgraph->GetName().c_str());
subgraph->GetName().c_str())
}

// invoke before adding subgraphs. in case modify node id in known-shaped subgraphs.
@@ -744,56 +591,6 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeG
return SUCCESS;
}

// Inlines `sub_graph` into `parent_graph` and removes it from `root_graph`.
//
// The subgraph's Data and NetOutput nodes are merged away first (MergeInputNodes /
// MergeNetOutputNode). Every remaining node is then moved into parent_graph, except
// PARTITIONEDCALL nodes whose subgraph has the unknown flag set: those are unfolded
// recursively instead of being moved.
//
// @param root_graph   graph from which the unfolded subgraph is removed by name
// @param parent_graph graph that receives the subgraph's nodes
// @param sub_graph    subgraph to unfold; must have a parent node
// @return SUCCESS, or the status propagated by one of the checking macros.
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph,
ComputeGraphPtr &parent_graph,
ComputeGraph &sub_graph) {
auto parent_node = sub_graph.GetParentNode();
GE_CHECK_NOTNULL(parent_node);

GE_CHK_STATUS_RET(MergeInputNodes(sub_graph),
"[Invoke][MergeInputNodes][%s] Failed to merge data nodes for subgraph",
sub_graph.GetName().c_str());
GE_CHK_STATUS_RET(MergeNetOutputNode(sub_graph),
"[Invoke][MergeNetOutputNode][%s] Failed to merge net output nodes for subgraph",
sub_graph.GetName().c_str());
GELOGD("[%s] Done merging subgraph inputs and outputs successfully", sub_graph.GetName().c_str());

for (auto &sub_node : sub_graph.GetDirectNode()) {
auto sub_op_type = sub_node->GetType();
// Data and NetOutput nodes were bypassed by the merge steps above and are not moved.
if (sub_op_type == DATA_TYPE || sub_op_type == NETOUTPUT) {
continue;
}
if (sub_op_type == PARTITIONEDCALL) {
auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, kSubgraphIndex);
GE_CHECK_NOTNULL(sub_sub_graph);
// Nested subgraphs with the unknown flag are unfolded in place of their PartitionedCall node.
if (sub_sub_graph->GetGraphUnknownFlag()) {
GE_CHK_STATUS_RET(UnfoldSubgraph(root_graph, parent_graph, *sub_sub_graph),
"[Invoke][UnfoldSubgraph][%s] Failed to merge subgraph",
sub_sub_graph->GetName().c_str());
continue;
}
}

// Subgraphs kept under a moved node must now point at the new parent graph.
// NOTE(review): sub_node->GetOpDesc() is dereferenced without a null check.
if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) {
for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) {
auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i);
GE_CHECK_NOTNULL(sub_sub_graph);
sub_sub_graph->SetParentGraph(parent_graph);
}
}
parent_graph->AddNode(sub_node);
GELOGD("[%s::%s] added to parent graph: [%s].",
sub_graph.GetName().c_str(),
sub_node->GetName().c_str(),
parent_graph->GetName().c_str());
sub_node->SetOwnerComputeGraph(parent_graph);
}

GELOGD("[%s] Done merging subgraph. remove it from root graph", sub_graph.GetName().c_str());
root_graph->RemoveSubgraph(sub_graph.GetName());
return SUCCESS;
}

Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item,
const NodeItem &node_item,
bool is_root_graph) {


+ 0
- 5
ge/hybrid/model/hybrid_model_builder.h View File

@@ -39,16 +39,11 @@ class HybridModelBuilder {

private:
static Status UpdateAnchorStatus(const NodePtr &node);
static Status DoUnlinkDataAnchors(const OutDataAnchorPtr &out_data_anchor, const InDataAnchorPtr &in_data_anchor);
static Status DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, InDataAnchorPtr &in_data_anchor);
static NodePtr GetPeerNode(const InDataAnchorPtr &in_data_anchor);
static Status GetParentNodeOutputIndex(const OpDesc &op_desc, int index, uint32_t &out_index);
static Status GetPeerNodeAcrossSubGraphs(const NodePtr &data_node, NodePtr &peer_node, int &peer_out_index);
static Status HandleDtString(const GeTensor &tensor, void *var_addr);
static Status MergeInputNodes(ComputeGraph &compute_graph);
static Status MergeNetOutputNode(ComputeGraph &compute_graph);
static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph);
static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph);
static Status BuildInputMapping(GraphItem &graph_item,
std::vector<NodeItem *> &data_nodes,
bool is_root_graph);


+ 2
- 4
ge/offline/CMakeLists.txt View File

@@ -22,7 +22,6 @@ target_compile_options(atc_atc.bin PRIVATE

target_compile_definitions(atc_atc.bin PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
COMPILE_OMG_PACKAGE
google=ascend_private
LOG_CPP
FUNC_VISIBILITY
@@ -48,6 +47,7 @@ target_include_directories(atc_atc.bin PRIVATE

target_link_options(atc_atc.bin PRIVATE
-Wl,-Bsymbolic
-Wl,-rpath-link,${ASCEND_ATC_DIR}/stub
)

target_link_libraries(atc_atc.bin PRIVATE
@@ -62,8 +62,7 @@ target_link_libraries(atc_atc.bin PRIVATE
parser_common
gflags
json
$<$<NOT:$<BOOL:${ENABLE_OPEN_SRC}>>:$<BUILD_INTERFACE:runtime>>
$<$<BOOL:${ENABLE_OPEN_SRC}>:$<BUILD_INTERFACE:runtime_compile>>
runtime
slog
static_mmpa
-lrt
@@ -92,7 +91,6 @@ target_compile_options(fwk_atc.bin PRIVATE

target_compile_definitions(fwk_atc.bin PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
COMPILE_OMG_PACKAGE
google=ascend_private
LOG_CPP
FUNC_VISIBILITY


+ 6
- 0
tests/depends/runtime/src/runtime_stub.cc View File

@@ -193,6 +193,12 @@ rtError_t rtMemGetInfo(size_t *free, size_t *total) {
return RT_ERROR_NONE;
}

// Test stub for the runtime memory query: ignores memInfoType and always reports
// 512 MiB free out of 1 GiB total, then signals success.
rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *free, size_t *total) {
  constexpr size_t kBytesPerMiB = 1024UL * 1024UL;
  constexpr size_t kStubFreeBytes = 512UL * kBytesPerMiB;
  constexpr size_t kStubTotalBytes = 1024UL * kBytesPerMiB;

  *free = kStubFreeBytes;
  *total = kStubTotalBytes;
  return RT_ERROR_NONE;
}

rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag) { return RT_ERROR_NONE; }

rtError_t rtMemFreeManaged(void *ptr) { return RT_ERROR_NONE; }


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -692,6 +692,7 @@ set(MULTI_PARTS_TEST_FILES
"graph/manager/run_graph_unittest.cc"
"graph/partition/dynamic_shape_partition_unittest.cc"
"graph/manager/graph_manager_unittest.cc"
"graph/manager/graph_var_manager_unittest.cc"
"graph/optimize/mem_rw_conflict_optimize_unittest.cc"
"graph/optimize/graph_optimize_unittest.cc"
"session/omg_omg_unittest.cc"


+ 72
- 72
tests/ut/ge/graph/load/ffts_plus_task_info_unittest.cc View File

@@ -79,13 +79,12 @@ public:
additionaldata1->add_context_id(5);
}

void InitAicAivCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusAicAivCtxDef *ctxdef = fftsplusctxdef->mutable_aic_aiv_ctx();
void InitAicAivCtx(domi::FftsPlusAicAivCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_aten(1);
ctxdef->set_pred_cnt_init(1);
ctxdef->set_pred_cnt(1);
for (int i = 0; i < RT_CTX_SUCCESSOR_NUM; ++i) {
for (int i = 1; i < RT_CTX_SUCCESSOR_NUM; ++i) {
ctxdef->add_successor_list(1); // 16 bits, len = 26
}
ctxdef->set_stat(1);
@@ -113,8 +112,7 @@ public:
}
}

void InitMixAicAivCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusMixAicAivCtxDef *ctxdef = fftsplusctxdef->mutable_mix_aic_aiv_ctx();
void InitMixAicAivCtx(domi::FftsPlusMixAicAivCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_aten(1);
ctxdef->set_pred_cnt_init(1);
@@ -153,8 +151,7 @@ public:
}
}

void InitSdmaCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusSdmaCtxDef *ctxdef = fftsplusctxdef->mutable_sdma_ctx();
void InitSdmaCtx(domi::FftsPlusSdmaCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_aten(1);
ctxdef->set_pred_cnt_init(1);
@@ -184,8 +181,7 @@ public:
ctxdef->set_tail_data_len(1);
}

void InitNotifyCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusNotifyCtxDef *ctxdef = fftsplusctxdef->mutable_notify_ctx();
void InitNotifyCtx(domi::FftsPlusNotifyCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_aten(1);
ctxdef->set_pred_cnt_init(1);
@@ -201,8 +197,7 @@ public:
ctxdef->set_notify_id_base(1);
}
void InitWriteValueCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusWriteValueCtxDef *ctxdef = fftsplusctxdef->mutable_write_value_ctx();
void InitWriteValueCtx(domi::FftsPlusWriteValueCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_aten(1);
ctxdef->set_pred_cnt_init(1);
@@ -227,8 +222,7 @@ public:
ctxdef->add_write_value(1);
}
void InitAicpuCtxCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusAicpuCtxDef *ctxdef = fftsplusctxdef->mutable_aicpu_ctx();
void InitAicpuCtxCtx(domi::FftsPlusAicpuCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_aten(1);
ctxdef->set_pred_cnt_init(1);
@@ -260,8 +254,7 @@ public:
ctxdef->set_task_param_offset(32);
}
void InitDataCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusDataCtxDef *ctxdef = fftsplusctxdef->mutable_data_ctx();
void InitDataCtx(domi::FftsPlusDataCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_aten(1);
ctxdef->set_cnt_init(1);
@@ -293,8 +286,7 @@ public:
ctxdef->set_tail_stride_inner(1);
}
void InitAtStartCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusAtStartCtxDef *ctxdef = fftsplusctxdef->mutable_at_start_ctx();
void InitAtStartCtx(domi::FftsPlusAtStartCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_aten(1);
ctxdef->set_pred_cnt_init(1);
@@ -309,8 +301,7 @@ public:
ctxdef->set_thread_window_size(1);
}
void InitAtEndCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusAtEndCtxDef *ctxdef = fftsplusctxdef->mutable_at_end_ctx();
void InitAtEndCtx(domi::FftsPlusAtEndCtxDef *ctxdef) {
ctxdef->set_at_start_slot_num(12);
ctxdef->set_out_label_slot_num(12);
ctxdef->set_aten(1);
@@ -325,8 +316,7 @@ public:
ctxdef->set_thread_id(1);
}
void InitLabelCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusLabelCtxDef *ctxdef = fftsplusctxdef->mutable_label_ctx();
void InitLabelCtx(domi::FftsPlusLabelCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_pred_cnt_init(1);
ctxdef->set_pred_cnt(1);
@@ -335,8 +325,7 @@ public:
}
}
void InitCaseSwitchCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusCaseSwitchCtxDef *ctxdef = fftsplusctxdef->mutable_case_switch_ctx();
void InitCaseSwitchCtx(domi::FftsPlusCaseSwitchCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_aten(32);
ctxdef->set_start_label_id(32);
@@ -366,8 +355,7 @@ public:
ctxdef->set_load_addr1_offset(32);
}
void InitCaseDefaultCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusCaseDefaultCtxDef *ctxdef = fftsplusctxdef->mutable_case_default_ctx();
void InitCaseDefaultCtx(domi::FftsPlusCaseDefaultCtxDef *ctxdef) {
ctxdef->set_successor_num(26);
ctxdef->set_aten(32);
ctxdef->set_start_label_id(1);
@@ -379,8 +367,7 @@ public:
}
}
void InitCondSwitchCtx(domi::FftsPlusCtxDef *fftsplusctxdef) {
domi::FftsPlusCondSwitchCtxDef *ctxdef = fftsplusctxdef->mutable_cond_switch_ctx();
void InitCondSwitchCtx(domi::FftsPlusCondSwitchCtxDef *ctxdef) {
ctxdef->set_true_successor_num(12);
ctxdef->set_false_successor_num(14);
ctxdef->set_aten(32);
@@ -444,35 +431,38 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_software_ctx) {
InitTaskSQEInfo(ffts_plus_task_def);
InitTaskAdditionalDataInfo(ffts_plus_task_def);

domi::FftsPlusCtxDef *startctx = ffts_plus_task_def->add_ffts_plus_ctx();
startctx->set_op_index(0);
startctx->set_hardware_ctx_type(0);
startctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_START));
InitAtStartCtx(startctx);
domi::FftsPlusCtxDef *fftsplusstartctx = ffts_plus_task_def->add_ffts_plus_ctx();
fftsplusstartctx->set_op_index(0);
fftsplusstartctx->set_hardware_ctx_type(0);
fftsplusstartctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_START));
domi::FftsPlusAtStartCtxDef *startctxdef = fftsplusstartctx->mutable_at_start_ctx();
InitAtStartCtx(startctxdef);

EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED);
startctx->at_start_ctx().add_successor_list(1);
startctxdef->add_successor_list(1);
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS);

domi::FftsPlusCtxDef *endctx = ffts_plus_task_def->add_ffts_plus_ctx();
endctx->set_op_index(0);
endctx->set_hardware_ctx_type(0);
endctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_END));
InitAtEndCtx(endctx);
domi::FftsPlusCtxDef *fftsplusendctx = ffts_plus_task_def->add_ffts_plus_ctx();
fftsplusendctx->set_op_index(0);
fftsplusendctx->set_hardware_ctx_type(0);
fftsplusendctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_AT_END));
domi::FftsPlusAtEndCtxDef *endctxdef = fftsplusendctx->mutable_at_end_ctx();
InitAtEndCtx(endctxdef);

EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED);
endctx->at_end_ctx().add_succ_at_start_slot(1);
endctxdef->add_succ_at_start_slot(1);
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED);
endctx->at_end_ctx().add_succ_out_label_slot(1);
endctxdef->add_succ_out_label_slot(1);
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS);

domi::FftsPlusCtxDef *labelctx = ffts_plus_task_def->add_ffts_plus_ctx();
labelctx->set_op_index(0);
labelctx->set_hardware_ctx_type(0);
labelctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_LABEL));
InitLabelCtx(labelctx);
domi::FftsPlusCtxDef *fftspluslabelctx = ffts_plus_task_def->add_ffts_plus_ctx();
fftspluslabelctx->set_op_index(0);
fftspluslabelctx->set_hardware_ctx_type(0);
fftspluslabelctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_LABEL));
domi::FftsPlusLabelCtxDef *labelctxdef = fftsplusctxdef->mutable_label_ctx();
InitLabelCtx(labelctxdef);
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), FAILED);
labelctx->label_ctx().add_successor_list(1);
labelctxdef->add_successor_list(1);
EXPECT_EQ(ffts_plus_task_info.Init(task_def, &davinci_model), SUCCESS);
}

@@ -501,102 +491,111 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_hardware_ctx) {
aicaivctx->set_op_index(0);
aicaivctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_AIV));
aicaivctx->set_software_ctx_type(0);
InitAicAivCtx(aicaivctx);
domi::FftsPlusAicAivCtxDef *aicaivdef = aicaivctx->mutable_aic_aiv_ctx();
InitAicAivCtx(aicaivdef);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
aicaivctx->aic_aiv_ctx().add_successor_list(1);
aicaivdef->add_successor_list(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
aicaivctx->aic_aiv_ctx().add_kernel_name("aivtest");
aicaivdef->add_kernel_name("aivtest");
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
aicaivctx->aic_aiv_ctx().add_src_slot(1);
aicaivdef->add_src_slot(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS);
domi::FftsPlusCtxDef *mixaicaivctx = ffts_plus_task_def->add_ffts_plus_ctx();
mixaicaivctx->set_op_index(0);
mixaicaivctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_MIX_AIC));
mixaicaivctx->set_software_ctx_type(0);
InitMixAicAivCtx(mixaicaivctx);
domi::FftsPlusMixAicAivCtxDef *mixctxdef = mixaicaivctx->mutable_mix_aic_aiv_ctx();
InitMixAicAivCtx(mixctxdef);

EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
mixaicaivctx->mix_aic_aiv_ctx().add_successor_list(1);
mixctxdef->add_successor_list(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
mixaicaivctx->mix_aic_aiv_ctx().add_kernel_name("mixaiv");
mixctxdef->add_kernel_name("mixaiv");
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
mixaicaivctx->mix_aic_aiv_ctx().add_src_slot(1);
mixctxdef->add_src_slot(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS);

domi::FftsPlusCtxDef *notifyctx = ffts_plus_task_def->add_ffts_plus_ctx();
notifyctx->set_op_index(0);
notifyctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_NOTIFY_WAIT));
notifyctx->set_software_ctx_type(0);
InitNotifyCtx(notifyctx);
domi::FftsPlusNotifyCtxDef *notifydef = notifyctx->mutable_notify_ctx();
InitNotifyCtx(notifydef);

EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
notifyctx->notify_ctx().add_successor_list(1);
notifydef->add_successor_list(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS);

domi::FftsPlusCtxDef *sdmactx = ffts_plus_task_def->add_ffts_plus_ctx();
sdmactx->set_op_index(0);
sdmactx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_SDMA));
sdmactx->set_software_ctx_type(0);
InitSdmaCtx(sdmactx);
domi::FftsPlusSdmaCtxDef *smdadef = sdmactx->mutable_sdma_ctx();
InitSdmaCtx(smdadef);

EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
sdmactx->sdma_ctx().add_successor_list(1);
smdadef->add_successor_list(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS);

domi::FftsPlusCtxDef *writevalctx = ffts_plus_task_def->add_ffts_plus_ctx();
writevalctx->set_op_index(0);
writevalctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_WRITE_VALUE));
writevalctx->set_software_ctx_type(0);
InitWriteValueCtx(writevalctx);
domi::FftsPlusWriteValueCtxDef *writedef = writevalctx->mutable_write_value_ctx();
InitWriteValueCtx(writedef);

EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
writevalctx->write_value_ctx().add_successor_list(1);
writedef->add_successor_list(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS);

domi::FftsPlusCtxDef *aicpuctx = ffts_plus_task_def->add_ffts_plus_ctx();
aicpuctx->set_op_index(0);
aicpuctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_AICPU));
aicpuctx->set_software_ctx_type(0);
InitAicpuCtxCtx(aicpuctx);
domi::FftsPlusAicpuCtxDef *aicpudef = aicpuctx->mutable_aicpu_ctx();
InitAicpuCtxCtx(aicpudef);

EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
aicpuctx->aicpu_ctx().add_successor_context_id(1);
aicpudef->add_successor_context_id(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
aicpuctx->aicpu_ctx().add_user_data(1);
aicpudef->add_user_data(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS);

domi::FftsPlusCtxDef *datactx = ffts_plus_task_def->add_ffts_plus_ctx();
datactx->set_op_index(0);
datactx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_FLUSH_DATA));
datactx->set_software_ctx_type(0);
InitDataCtx(datactx);
domi::FftsPlusDataCtxDef *datadef = datactx->mutable_data_ctx();
InitDataCtx(datadef);

EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
datactx->data_ctx().add_successor_list(1);
datadef->add_successor_list(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS);

domi::FftsPlusCtxDef *caseswitchctx = ffts_plus_task_def->add_ffts_plus_ctx();
caseswitchctx->set_op_index(0);
caseswitchctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD));
caseswitchctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_CASE_SWITCH));
InitCaseSwitchCtx(caseswitchctx);
domi::FftsPlusCaseSwitchCtxDef *caseswitchdef = caseswitchctx->mutable_case_switch_ctx();
InitCaseSwitchCtx(caseswitchdef);

EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
caseswitchctx->case_switch_ctx().add_successor_list(1);
caseswitchdef->add_successor_list(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS);

domi::FftsPlusCtxDef *candswitchctx = ffts_plus_task_def->add_ffts_plus_ctx();
candswitchctx->set_op_index(0);
candswitchctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD));
candswitchctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_COND_SWITCH));
InitCondSwitchCtx(candswitchctx);
domi::FftsPlusCondSwitchCtxDef *candswitchdef = candswitchctx->mutable_cond_switch_ctx();
InitCondSwitchCtx(candswitchdef);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
candswitchctx->cond_switch_ctx().add_true_successor_list(1);
candswitchdef->add_true_successor_list(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
candswitchctx->cond_switch_ctx().add_false_successor_list(1);
candswitchdef->add_false_successor_list(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS);
}

@@ -625,10 +624,11 @@ TEST_F(UtestFftsPlusTaskInfo, success_ffts_plus_task_info_hardware_ctx_ex) {
casesdefaultctx->set_op_index(0);
casesdefaultctx->set_hardware_ctx_type(static_cast<uint32_t>(RT_HW_CTX_TYPE_LOAD));
casesdefaultctx->set_software_ctx_type(static_cast<uint32_t>(RT_SOFT_CTX_TYPE_CASE_SWITCH));
InitCaseDefaultCtx(casesdefaultctx);
domi::FftsPlusCaseDefaultCtxDef *casesdefaultdef = casesdefaultctx->mutable_case_default_ctx();
InitCaseDefaultCtx(casesdefaultdef);

EXPECT_EQ(task_info.Init(task_def, &davinci_model), FAILED);
casesdefaultctx->case_default_ctx().add_successor_list(1);
casesdefaultdef->add_successor_list(1);
EXPECT_EQ(task_info.Init(task_def, &davinci_model), SUCCESS);
}
// test FftsPlusTaskInfo UpdateArgs


+ 63
- 0
tests/ut/ge/graph/manager/graph_var_manager_unittest.cc View File

@@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <memory>

#define protected public
#define private public
#include "graph/manager/graph_var_manager.h"
#include "graph/ge_context.h"
#undef protected
#undef private

namespace ge {
// Test fixture shared by the VarManager unit tests in this file.
// It carries no state of its own; both hooks are intentionally empty.
class UtestGraphVarManagerTest : public testing::Test {
 protected:
  // Runs before each test case; nothing to prepare.
  void SetUp() {}

  // Runs after each test case; nothing to release.
  void TearDown() {}
};

// Checks that VarManager::GetTotalMemorySize() succeeds and reports 1 GiB.
// NOTE(review): the 1 GiB figure presumably comes from the stubbed runtime used by the unit tests - confirm.
TEST_F(UtestGraphVarManagerTest, test_get_total_memory_size) {
  constexpr size_t kOneGiB = 1024UL * 1024UL * 1024UL;

  const auto var_manager = VarManager::Instance(0);
  // Abort early instead of dereferencing a null manager.
  ASSERT_NE(var_manager, nullptr);

  size_t total_mem_size = 0;
  const Status status = var_manager->GetTotalMemorySize(total_mem_size);
  // The out-parameter is only meaningful when the call succeeded, so verify the status first;
  // otherwise a failing call would surface as a misleading size mismatch.
  ASSERT_EQ(status, SUCCESS);
  EXPECT_EQ(total_mem_size, kOneGiB);
}

// Calls VarManager::SetMemoryMallocSize() with an empty option map and expects the graph and
// variable memory limits to equal 26/32 and 5/32 of 1 GiB respectively.
TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_no_related_option) {
  // Expected limits are computed with exact integer arithmetic (1 GiB is divisible by 32), so the
  // comparison against the integer members is never carried out in floating point.
  constexpr size_t kOneGiB = 1024UL * 1024UL * 1024UL;
  constexpr size_t kDefaultGraphMemMaxSize = kOneGiB / 32UL * 26UL;
  constexpr size_t kDefaultVarMemMaxSize = kOneGiB / 32UL * 5UL;

  const auto var_manager = VarManager::Instance(0);
  // Abort early instead of dereferencing a null manager.
  ASSERT_NE(var_manager, nullptr);

  const map<string, string> options{};
  const Status status = var_manager->SetMemoryMallocSize(options);
  // The limits are only worth inspecting when the call itself succeeded.
  ASSERT_EQ(status, SUCCESS);

  EXPECT_EQ(var_manager->graph_mem_max_size_, kDefaultGraphMemMaxSize);
  EXPECT_EQ(var_manager->var_mem_max_size_, kDefaultVarMemMaxSize);
}

// Calls VarManager::SetMemoryMallocSize() with only "ge.graphMemoryMaxSize" set (512 MiB) and
// expects the graph limit to take that value while the variable limit equals 5/32 of 1 GiB.
TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_graph_mem_max_size) {
  // Expected limits are computed with exact integer arithmetic (1 GiB is divisible by 32), so the
  // comparison against the integer members is never carried out in floating point.
  constexpr size_t kOneGiB = 1024UL * 1024UL * 1024UL;
  constexpr size_t kUserGraphMemMaxSize = kOneGiB / 2UL;
  constexpr size_t kDefaultVarMemMaxSize = kOneGiB / 32UL * 5UL;
  static_assert(kUserGraphMemMaxSize == 536870912UL, "must match the option string passed below");

  const auto var_manager = VarManager::Instance(0);
  // Abort early instead of dereferencing a null manager.
  ASSERT_NE(var_manager, nullptr);

  const map<string, string> options{{"ge.graphMemoryMaxSize", "536870912"}};
  const Status status = var_manager->SetMemoryMallocSize(options);
  // The limits are only worth inspecting when the call itself succeeded.
  ASSERT_EQ(status, SUCCESS);

  EXPECT_EQ(var_manager->graph_mem_max_size_, kUserGraphMemMaxSize);
  EXPECT_EQ(var_manager->var_mem_max_size_, kDefaultVarMemMaxSize);
}

// Calls VarManager::SetMemoryMallocSize() with only "ge.variableMemoryMaxSize" set (512 MiB) and
// expects the variable limit to take that value while the graph limit equals 26/32 of 1 GiB.
TEST_F(UtestGraphVarManagerTest, test_set_memory_malloc_size_with_user_specify_var_mem_max_size) {
  // Expected limits are computed with exact integer arithmetic (1 GiB is divisible by 32), so the
  // comparison against the integer members is never carried out in floating point.
  constexpr size_t kOneGiB = 1024UL * 1024UL * 1024UL;
  constexpr size_t kDefaultGraphMemMaxSize = kOneGiB / 32UL * 26UL;
  constexpr size_t kUserVarMemMaxSize = kOneGiB / 2UL;
  static_assert(kUserVarMemMaxSize == 536870912UL, "must match the option string passed below");

  const auto var_manager = VarManager::Instance(0);
  // Abort early instead of dereferencing a null manager.
  ASSERT_NE(var_manager, nullptr);

  const map<string, string> options{{"ge.variableMemoryMaxSize", "536870912"}};
  const Status status = var_manager->SetMemoryMallocSize(options);
  // The limits are only worth inspecting when the call itself succeeded.
  ASSERT_EQ(status, SUCCESS);

  EXPECT_EQ(var_manager->graph_mem_max_size_, kDefaultGraphMemMaxSize);
  EXPECT_EQ(var_manager->var_mem_max_size_, kUserVarMemMaxSize);
}
} // namespace ge

Loading…
Cancel
Save