Browse Source

Merge branch 'development' of gitee.com:mindspore/graphengine into single_op

pull/588/head
储星 Gitee 4 years ago
parent
commit
4f91e07621
62 changed files with 927 additions and 994 deletions
  1. +2
    -1
      ge/CMakeLists.txt
  2. +2
    -1
      ge/common/debug/memory_dumper.cc
  3. +7
    -1
      ge/common/ge/plugin_manager.cc
  4. +2
    -1
      ge/common/helper/model_helper.cc
  5. +2
    -1
      ge/common/helper/om_file_helper.cc
  6. +2
    -1
      ge/executor/ge_executor.cc
  7. +1
    -1
      ge/ge_inference.mk
  8. +1
    -1
      ge/ge_runner.mk
  9. +2
    -1
      ge/graph/build/graph_builder.cc
  10. +7
    -5
      ge/graph/build/memory/binary_block_mem_assigner.cc
  11. +325
    -191
      ge/graph/build/memory/block_mem_assigner.cc
  12. +30
    -6
      ge/graph/build/memory/block_mem_assigner.h
  13. +6
    -4
      ge/graph/build/memory/graph_mem_assigner.cc
  14. +0
    -1
      ge/graph/build/model_builder.cc
  15. +2
    -1
      ge/graph/build/stream_allocator.cc
  16. +2
    -1
      ge/graph/load/graph_loader.cc
  17. +13
    -11
      ge/graph/load/new_model_manager/davinci_model.cc
  18. +2
    -1
      ge/graph/load/new_model_manager/model_manager.cc
  19. +2
    -2
      ge/graph/load/new_model_manager/model_utils.cc
  20. +2
    -1
      ge/graph/load/new_model_manager/task_info/hccl_task_info.cc
  21. +2
    -2
      ge/graph/load/new_model_manager/task_info/kernel_task_info.cc
  22. +11
    -8
      ge/graph/manager/graph_manager.cc
  23. +2
    -1
      ge/graph/manager/util/hcom_util.cc
  24. +79
    -24
      ge/graph/passes/atomic_addr_clean_pass.cc
  25. +5
    -0
      ge/graph/passes/atomic_addr_clean_pass.h
  26. +23
    -5
      ge/graph/passes/attach_stream_label_pass.cc
  27. +3
    -1
      ge/graph/passes/attach_stream_label_pass.h
  28. +1
    -1
      ge/graph/passes/base_pass.cc
  29. +1
    -2
      ge/graph/passes/common_subexpression_elimination_pass.cc
  30. +0
    -55
      ge/graph/passes/const_pass.cc
  31. +0
    -29
      ge/graph/passes/const_pass.h
  32. +0
    -64
      ge/graph/passes/dimension_adjust_pass.cc
  33. +0
    -4
      ge/graph/passes/dimension_adjust_pass.h
  34. +7
    -41
      ge/graph/passes/enter_pass.cc
  35. +1
    -2
      ge/graph/passes/enter_pass.h
  36. +4
    -1
      ge/graph/passes/folding_pass.cc
  37. +10
    -0
      ge/graph/passes/merge_to_stream_merge_pass.cc
  38. +173
    -89
      ge/graph/passes/next_iteration_pass.cc
  39. +13
    -3
      ge/graph/passes/next_iteration_pass.h
  40. +4
    -4
      ge/graph/passes/subgraph_pass.cc
  41. +3
    -1
      ge/graph/passes/transop_breadth_fusion_pass.cc
  42. +2
    -1
      ge/graph/preprocess/graph_preprocess.cc
  43. +58
    -343
      ge/graph/preprocess/multi_batch_copy_graph.cc
  44. +1
    -15
      ge/graph/preprocess/multi_batch_copy_graph.h
  45. +6
    -2
      ge/host_kernels/ssd_prior_box_kernel.cc
  46. +4
    -2
      ge/hybrid/executor/hybrid_model_async_executor.cc
  47. +2
    -1
      ge/hybrid/executor/worker/shape_inference_engine.cc
  48. +12
    -6
      ge/hybrid/model/hybrid_model.cc
  49. +2
    -1
      ge/hybrid/model/hybrid_model.h
  50. +12
    -28
      ge/ir_build/ge_ir_build.cc
  51. +25
    -7
      ge/offline/CMakeLists.txt
  52. +7
    -6
      ge/offline/atc
  53. +2
    -2
      ge/offline/module.mk
  54. +20
    -0
      ge/offline/single_op_parser.cc
  55. +6
    -0
      ge/offline/single_op_parser.h
  56. +4
    -4
      ge/opskernel_manager/ops_kernel_manager.cc
  57. +2
    -1
      ge/session/omg.cc
  58. +2
    -1
      ge/single_op/single_op.cc
  59. +2
    -1
      ge/single_op/single_op_model.cc
  60. +2
    -1
      ge/single_op/task/tbe_task_builder.cc
  61. +1
    -1
      inc/framework/common/taskdown_common.h
  62. +1
    -1
      metadef

+ 2
- 1
ge/CMakeLists.txt View File

@@ -154,7 +154,6 @@ set(TRAIN_SRC_LIST
"graph/passes/compile_nodes_pass.cc"
"graph/passes/constant_folding_pass.cc"
"graph/passes/constant_fuse_same_pass.cc"
"graph/passes/control_trigger_pass.cc"
"graph/passes/dimension_adjust_pass.cc"
"graph/passes/dimension_compute_pass.cc"
@@ -202,6 +201,7 @@ set(TRAIN_SRC_LIST
"host_kernels/sub_kernel.cc"
"host_kernels/transdata_kernel.cc"
"host_kernels/unpack_kernel.cc"
"host_kernels/reformat_kernel.cc"
"graph/passes/folding_pass.cc"
"graph/passes/get_original_format_pass.cc"
"graph/passes/guarantee_const_pass.cc"
@@ -488,6 +488,7 @@ set(INFER_SRC_LIST
"host_kernels/slice_d_kernel.cc"
"host_kernels/dynamic_stitch_kernel.cc"
"host_kernels/identity_kernel.cc"
"host_kernels/reformat_kernel.cc"
"graph/passes/stop_gradient_pass.cc"
"graph/passes/prevent_gradient_pass.cc"
"graph/passes/identity_pass.cc"


+ 2
- 1
ge/common/debug/memory_dumper.cc View File

@@ -139,7 +139,8 @@ int MemoryDumper::OpenFile(const char *filename) {
GE_IF_BOOL_EXEC(
-1 != path_split_pos, string prefix_path = std::string(filename).substr(0, path_split_pos);
string last_path = std::string(filename).substr(path_split_pos, strlen(filename) - 1);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH, return kInvalidFd, "Prefix path is too long!");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(prefix_path.length() >= MMPA_MAX_PATH,
return kInvalidFd, "Prefix path is too long!");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmRealPath(prefix_path.c_str(), tmp_path, MMPA_MAX_PATH) != EN_OK, return kInvalidFd,
"Dir %s does not exit.", prefix_path.c_str());
real_path = std::string(tmp_path) + last_path;)


+ 7
- 1
ge/common/ge/plugin_manager.cc View File

@@ -123,7 +123,10 @@ Status PluginManager::LoadSo(const string &path, const vector<string> &func_chec
if (handle == nullptr) {
const char *error = mmDlerror();
GE_IF_BOOL_EXEC(error == nullptr, error = "");
GELOGE(GE_PLGMGR_PATH_INVALID, "Failed to dlopen %s!", error);
ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
{"mmDlopen", "shared library path is " + FmtToStr(file_path_dlopen) + ". Errormessage" + FmtToStr(error)});
GELOGE(GE_PLGMGR_PATH_INVALID, "Failed to dlopen the shared library path[%s]. Errormessage[%s]!",
file_path_dlopen.c_str(), error);
continue;
}

@@ -132,6 +135,9 @@ Status PluginManager::LoadSo(const string &path, const vector<string> &func_chec
for (const auto &func_name : func_check_list) {
auto real_fn = (void (*)())mmDlsym(handle, const_cast<char *>(func_name.c_str()));
if (real_fn == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"},
{"mmDlsym", FmtToStr(func_name) + " is skipped since function" +
FmtToStr(func_name) + " is not existed!"});
GELOGE(GE_PLGMGR_PATH_INVALID, "%s is skipped since function %s is not existed!", func_name.c_str(),
func_name.c_str());
is_valid = false;


+ 2
- 1
ge/common/helper/model_helper.cc View File

@@ -189,7 +189,8 @@ Status ModelHelper::SaveModelHeader(std::shared_ptr<OmFileSaveHelper> &om_file_s
err = memcpy_s(model_header.platform_version, PLATFORM_VERSION_LEN, platform_version.c_str(),
platform_version.size() + 1);
if (err != EOK) {
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "ModelHelper SaveModel failed while allocating memory for platform_version.");
GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION,
"ModelHelper SaveModel failed while allocating memory for platform_version.");
return ACL_ERROR_GE_MEMORY_ALLOCATION;
}
string version = reinterpret_cast<char *>(model_header.platform_version);


+ 2
- 1
ge/common/helper/om_file_helper.cc View File

@@ -180,7 +180,8 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint
context_.partition_datas_.push_back(partition);

if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) {
GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.",
GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID,
"The partition size %zu is greater than the model data size %u.",
partition.size + mem_offset, model_data_size);
return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID;
}


+ 2
- 1
ge/executor/ge_executor.cc View File

@@ -639,7 +639,8 @@ Status GeExecutor::UnloadModel(uint32_t model_id) {
return ACL_ERROR_GE_INTERNAL_ERROR;
}

std::shared_ptr<hybrid::HybridDavinciModel> hybrid_davinci_model = ModelManager::GetInstance()->GetHybridModel(model_id);
std::shared_ptr<hybrid::HybridDavinciModel> hybrid_davinci_model =
ModelManager::GetInstance()->GetHybridModel(model_id);
if (hybrid_davinci_model != nullptr) {
uint64_t session_id = hybrid_davinci_model->GetSessionId();
VarManagerPool::Instance().RemoveVarManager(session_id);


+ 1
- 1
ge/ge_inference.mk View File

@@ -164,6 +164,7 @@ OMG_HOST_SRC_FILES := \
host_kernels/slice_d_kernel.cc \
host_kernels/dynamic_stitch_kernel.cc \
host_kernels/identity_kernel.cc \
host_kernels/reformat_kernel.cc \
graph/passes/stop_gradient_pass.cc \
graph/passes/prevent_gradient_pass.cc \
graph/passes/identity_pass.cc \
@@ -189,7 +190,6 @@ OMG_HOST_SRC_FILES := \
graph/passes/control_trigger_pass.cc \
graph/passes/cond_pass.cc \
graph/passes/cond_remove_pass.cc \
graph/passes/const_pass.cc \
graph/passes/for_pass.cc \
graph/passes/enter_pass.cc \
graph/passes/assign_pass.cc \


+ 1
- 1
ge/ge_runner.mk View File

@@ -123,7 +123,6 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/compile_nodes_pass.cc \
graph/passes/constant_folding_pass.cc \
graph/passes/constant_fuse_same_pass.cc \
graph/passes/const_pass.cc \
graph/passes/control_trigger_pass.cc \
graph/passes/dimension_adjust_pass.cc \
graph/passes/dimension_compute_pass.cc \
@@ -171,6 +170,7 @@ LIBGE_LOCAL_SRC_FILES := \
host_kernels/sub_kernel.cc \
host_kernels/transdata_kernel.cc \
host_kernels/unpack_kernel.cc \
host_kernels/reformat_kernel.cc \
graph/passes/folding_pass.cc \
graph/passes/get_original_format_pass.cc \
graph/passes/guarantee_const_pass.cc \


+ 2
- 1
ge/graph/build/graph_builder.cc View File

@@ -349,7 +349,8 @@ static Status GenerateTaskForConstant(const std::shared_ptr<ComputeGraph> &graph
GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str());
std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy";
if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) {
GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str());
GELOGE(FAILED, "Insert memcpy between %s and %s failed.",
in_node->GetName().c_str(), node->GetName().c_str());
return FAILED;
}
}


+ 7
- 5
ge/graph/build/memory/binary_block_mem_assigner.cc View File

@@ -21,8 +21,8 @@
namespace {
const uint32_t kRangeCeilInterval = 2;
const uint32_t kLogBase = 2;
const int64_t kLargeBlockSize = 8 * 1024 * 1024;
const int64_t kLargeBlockRangeSize = 10;
const int64_t kLargeBlockSize = 8 * 1024 * 1024; // 8M
const int64_t kLargeBlockRangeSize = 2;
} // namespace

namespace ge {
@@ -73,15 +73,17 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector<int64_t> &range_ceils) {
GELOGE(FAILED, "dividend is 0!");
return FAILED;
}
// Memory size is 512 aligned, so it is not necessary to take less than 512
int64_t min_memory_size = (all_memory_size.back() > MEM_ALIGN_SIZE) ? MEM_ALIGN_SIZE : all_memory_size.front();
auto range_number = static_cast<size_t>(
ceil(log(all_memory_size.back() / static_cast<double>(all_memory_size.front())) / log(kLogBase)));
ceil(log(all_memory_size.back() / static_cast<double>(min_memory_size)) / log(kLogBase)));
range_number = (range_number == 0) ? 1 : range_number;
GELOGD("Range number: %zu", range_number);

vector<vector<int64_t>> ranges(range_number);
GE_CHK_BOOL_EXEC((range_number != 0), return PARAM_INVALID, "range_number can't be 0.");
size_t range_number_limit = all_memory_size.size() / range_number;
int64_t range_ceil = all_memory_size[0];
int64_t range_ceil = min_memory_size;
for (size_t i = 1; i <= range_number; i++) {
GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(static_cast<uint64_t>(range_ceil), kRangeCeilInterval),
GELOGE(FAILED, "Multiply result is out of range.");
@@ -114,7 +116,7 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector<int64_t> &range_ceils) {
range_ceils.push_back(range.back());
}
}
GELOGD("Range ceils: %s", ToString(range_ceils).c_str());
GELOGI("Range ceils: %s", ToString(range_ceils).c_str());

return SUCCESS;
}


+ 325
- 191
ge/graph/build/memory/block_mem_assigner.cc View File

@@ -65,6 +65,98 @@ void AlignMemOffset(size_t &mem_align_size) {
mem_align_size = (mem_align_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE;
}

static bool CompareLifeTime(const NodeTypeIndex &left, const NodeTypeIndex &right) {
auto left_node_op_desc = left.node->GetOpDesc();
auto right_node_op_desc = right.node->GetOpDesc();
if ((left_node_op_desc != nullptr) && (right_node_op_desc != nullptr)
&& (left_node_op_desc->GetId() < right_node_op_desc->GetId())) {
return true;
}
return false;
}

void GetLifeList(const MemoryBlock &block, std::vector<NodeTypeIndex> &life_list, bool child) {
for (auto &node : block.NodeTypeIndexList()) {
life_list.emplace_back(node);
}

if (child) {
for (auto child_block : block.ChildBlockList()) {
if (child_block == nullptr) {
continue;
}
if (block.stream_id_ != child_block->stream_id_ || !block.same_stream_ || !child_block->same_stream_) {
life_list.clear();
return;
}
GetLifeList(*child_block, life_list, child);
}
}
}

bool CrossLifeTime(const NodeTypeIndex &left, const NodeTypeIndex &right) {
if ((left.node == nullptr) || (right.node == nullptr)) {
return true;
}
auto left_node_op_desc = left.node->GetOpDesc();
auto right_node_op_desc = right.node->GetOpDesc();
if ((left_node_op_desc != nullptr) && (right_node_op_desc != nullptr)) {
if (left_node_op_desc->GetId() < right_node_op_desc->GetId()) {
if (left.life_time_end >= static_cast<size_t>(right_node_op_desc->GetId())) {
return true;
}
} else if (left_node_op_desc->GetId() == right_node_op_desc->GetId()) {
return true;
} else {
if (right.life_time_end >= static_cast<size_t>(left_node_op_desc->GetId())) {
return true;
}
}
}
return false;
}

///
/// When child block's life time are not cross with parent block, they can be reused(only same stream).
/// |-----------------------------parent block---------------------|
/// |------child block1--------------||------child block2------|
/// |--child block1-1-|
///
bool CanIntervalLifeReuse(MemoryBlock &parent_block, MemoryBlock &child_block) {
// judge by interval life time, only same stream can be judged by interval life time
if (parent_block.stream_id_ != child_block.stream_id_ || !parent_block.same_stream_ || !child_block.same_stream_
|| parent_block.NodeTypeIndexList().empty() || child_block.NodeTypeIndexList().empty()) {
return false;
}

// quick judge by front and back node
if (CrossLifeTime(parent_block.NodeTypeIndexList().front(), child_block.NodeTypeIndexList().front())) {
return false;
}
if (CrossLifeTime(parent_block.NodeTypeIndexList().back(), child_block.NodeTypeIndexList().back())) {
return false;
}

std::vector<NodeTypeIndex> life_list;
GetLifeList(parent_block, life_list, false);
GetLifeList(child_block, life_list, true);
if (life_list.empty()) {
return false;
}
std::sort(life_list.begin(), life_list.end(), CompareLifeTime);
size_t pre_life_end = 0;
for (auto &node : life_list) {
auto node_op_desc = node.node->GetOpDesc();
if (node_op_desc != nullptr && pre_life_end >= static_cast<size_t>(node_op_desc->GetId())) {
// life time cross
return false;
}
pre_life_end = node.life_time_end;
}
GELOGI("Block size[%zu, %zu] life time are not cross.", parent_block.Size(), child_block.Size());
return true;
}

void MemoryBlock::SetHeadOffset(size_t offset) {
head_offset_ = offset;
size_t child_offset = head_offset_;
@@ -125,20 +217,12 @@ size_t MemoryBlock::AlignSize() const {
return align_block_size;
}

bool MemoryBlock::IsSameLabel(std::string &first_batch_label) {
if (node_type_index_list_.empty()) {
bool MemoryBlock::IsSameBatchLabel() {
// only same batch label can reuse
if (batch_label_.empty() || node_type_index_list_.empty()) {
return false;
}

auto node_op_desc = node_type_index_list_[0].node->GetOpDesc();
if (node_op_desc == nullptr) {
return false;
}
// not all op has ATTR_NAME_BATCH_LABEL, no need check return value, only check out parameter
(void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, first_batch_label);
if (first_batch_label.empty()) {
return false;
}
bool all_same_label = true;
for (size_t index = 1; index < node_type_index_list_.size(); ++index) {
if (node_type_index_list_[index].node == nullptr) {
@@ -147,8 +231,9 @@ bool MemoryBlock::IsSameLabel(std::string &first_batch_label) {
std::string batch_label;
auto index_op_desc = node_type_index_list_[index].node->GetOpDesc();
GE_IF_BOOL_EXEC(index_op_desc == nullptr, continue);
// not all op has ATTR_NAME_BATCH_LABEL, no need check return value, only check out parameter
(void)ge::AttrUtils::GetStr(index_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (first_batch_label != batch_label) {
if (batch_label_ != batch_label) {
all_same_label = false;
break;
}
@@ -197,7 +282,7 @@ void MemoryBlock::AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLi
}

void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) {
if (CanNotLifeReuse(this) || CanNotLifeReuse(block)) {
if (CanNotLifeReuse(this) || CanNotLifeReuse(block) || (batch_label_ != block->batch_label_)) {
return;
}
if (block->continuous_block_) {
@@ -207,16 +292,27 @@ void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_
MemoryBlock *parent = nullptr;
MemoryBlock *child = nullptr;
// merge small block to large block
if (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd()) {
if ((child_offset_ + block->AlignSize()) <= AlignSize()) {
parent = this;
child = block;
} else if ((block->child_offset_ + AlignSize()) <= block->AlignSize()) {
parent = block;
child = this;
// noalign size 802816 + 802816 = 1605632 can reuse
// after 32 align size 802848 + 802848 > 1605664 can't reuse
// after 512 align size 803328 + 803328 > 1606144 can't reuse
// so 803328 + 803328 = 1606144 + 512 can reuse
if ((child_offset_ + block->AlignSize()) <= (AlignSize() + MEM_ALIGN_SIZE)) {
parent = this;
child = block;
} else if ((block->child_offset_ + AlignSize()) <= (block->AlignSize() + MEM_ALIGN_SIZE)) {
parent = block;
child = this;
}

if ((parent != nullptr) && (child != nullptr)) {
// Different streams must use stream dependency to judge the life cycle
// In case same stream if it has child block, can judge all the child block's life time in CanIntervalLifeReuse
bool can_block_life_reuse = (child->child_blocks_.empty()
&& (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd()));
if (!can_block_life_reuse && !CanIntervalLifeReuse(*parent, *child)) {
return;
}
}
if ((parent != nullptr) && (child != nullptr) && child->child_blocks_.empty()) {

parent->child_blocks_.emplace_back(child);
parent->child_offset_ += child->AlignSize();
child->deleted_block_ = true;
@@ -261,6 +357,7 @@ size_t MemoryBlock::GetDependLifeBegin(int64_t stream_id, DependStreamLife &tota
void AddDependLife(const ge::NodePtr &org_node, const ge::NodePtr &node, int64_t stream_id,
std::map<int64_t, size_t> &depend_stream_life, DependStreamLife &total_node_depend_stream_life) {
GE_CHECK_NOTNULL_EXEC(node, return);
GE_CHECK_NOTNULL_EXEC(org_node, return);
auto node_desc = node->GetOpDesc();
GE_CHECK_NOTNULL_EXEC(node_desc, return);
auto node_id = node_desc->GetId();
@@ -415,12 +512,60 @@ BlockMemAssigner::~BlockMemAssigner() {
}
}

void GetMaxBatchAllMemorySize(std::map<std::string, vector<int64_t>> &batch_all_memory_size,
std::map<std::string, int64_t> batch_total_size, vector<int64_t> &all_memory_size,
std::string &max_batch_label) {
// use max batch all memory size for reuse range
int64_t max_batch_size = 0;
for (const auto &it : batch_total_size) {
GELOGI("Batch[%s] total memory size[%ld]", it.first.c_str(), it.second);
// no batch label
if (it.first.empty()) {
continue;
}
if (it.second > max_batch_size) {
max_batch_size = it.second;
max_batch_label = it.first;
}
}
GELOGI("Max batch[%s] total memory size[%ld]", max_batch_label.c_str(), max_batch_size);

for (const auto &it : batch_all_memory_size) {
if (it.first.empty() || (it.first == max_batch_label)) {
all_memory_size.insert(all_memory_size.end(), it.second.begin(), it.second.end());
}
}
// all_memory_size can't be empty
if (all_memory_size.empty()) {
all_memory_size.emplace_back(MEM_ALIGN_SIZE);
}
sort(all_memory_size.begin(), all_memory_size.end());
GELOGD("All memory size: %s", ToString(all_memory_size).c_str());

for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) {
if (*iter == 0) {
iter = all_memory_size.erase(iter);
} else {
++iter;
}
}
}

void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) {
vector<int64_t> temp;
std::map<std::string, vector<int64_t>> batch_all_memory_size;
std::map<std::string, int64_t> batch_total_size;
for (const NodePtr &n : compute_graph_->GetAllNodes()) {
auto node_op_desc = n->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue);

if (CheckIsZeroMemNodeType(node_op_desc->GetType())) {
continue;
}

std::string batch_label;
(void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);

if (node_op_desc->GetType() == ATOMICADDRCLEAN) {
atomic_addr_clean_id_ = node_op_desc->GetId();
}
@@ -434,9 +579,14 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) {
if (!reuse_input) {
int64_t size = 0;
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(output_desc, size) != SUCCESS, GELOGI("Get size failed"));
if (anchor_to_symbol_.empty()) {
all_memory_size.emplace_back(size);
batch_all_memory_size[batch_label].emplace_back(size);
if (batch_total_size.find(batch_label) == batch_total_size.end()) {
batch_total_size[batch_label] = size;
} else {
batch_total_size[batch_label] += size;
}

if (!anchor_to_symbol_.empty()) {
auto iter1 = anchor_to_symbol_.find(NodeIndexIO(n, out_anchor->GetIdx(), kOut).ToString());
if (iter1 == anchor_to_symbol_.end()) {
continue;
@@ -452,23 +602,11 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) {
}
}
temp.clear();
GetNodeWorkSpaceSize(n, temp);
all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end());
}
for (const auto &pair : symbol_size_) {
all_memory_size.emplace_back(pair.second);
}
sort(all_memory_size.begin(), all_memory_size.end());
GELOGD("All memory size: %s", ToString(all_memory_size).c_str());

for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) {
if (*iter == 0) {
iter = all_memory_size.erase(iter);
} else {
++iter;
}
GetNodeWorkSpaceSize(n, temp, batch_total_size[batch_label]);
batch_all_memory_size[batch_label].insert(batch_all_memory_size[batch_label].end(), temp.begin(), temp.end());
}

GELOGI("The last atomic_addr_clean node id: %ld", atomic_addr_clean_id_);
GetMaxBatchAllMemorySize(batch_all_memory_size, batch_total_size, all_memory_size, max_batch_label_);
InitReuseFlag();
PrintSymbolMap();
}
@@ -529,16 +667,6 @@ bool CanReuseBySize(const map<string, uint64_t> &reusable_block_counts, const Me
bool can_reuse = false;
if (reusable_block.Size() == block_size) {
can_reuse = true;
} else {
string key = std::to_string(reusable_block.Size());
key += "_" + std::to_string(reusable_block.stream_id_);
key += "_" + std::to_string(reusable_block.memory_type_);
auto it = reusable_block_counts.find(key);
GE_IF_BOOL_EXEC((it != reusable_block_counts.end() && (it->second > kReuseMaxCount)) &&
(reusable_block.Size() > block_size),
can_reuse = true;
GELOGD("Less size mem reuse, reuse block size:%zu, current block size:%zu",
reusable_block.Size(), block_size););
}
return can_reuse;
}
@@ -860,17 +988,26 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size,
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null.");
auto node_op_desc = n->GetOpDesc();
GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr);
std::string batch_label;
(void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (batch_label.empty() || (batch_label == max_batch_label_)) {
size_t align_size = real_size;
AlignMemOffset(align_size);
theory_memory_size_ += align_size;
if (theory_memory_size_ > theory_min_memory_size_) {
theory_min_memory_size_ = theory_memory_size_;
}
}

bool is_reuse_memory = false;
string ge_disable_reuse_mem_env = "0";
(void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env);
if (ge_disable_reuse_mem_env != "1") {
if (ge_disable_reuse_mem_env_ != "1") {
bool reuse_mem_flag = (mem_type == kOutput) ? IsPreReuse(n, out_index) :
!((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]);
is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) &&
!node_op_desc->HasAttr(kOpNoReuseMem) && reuse_mem_flag && is_op_reuse_mem;
auto stream_id = node_op_desc->GetStreamId();
if (is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty()) {
bool do_reuse = is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty();
if (do_reuse) {
auto stream_id = node_op_desc->GetStreamId();
for (auto it = reusable_blocks_[memory_type][stream_id].rbegin();
it != reusable_blocks_[memory_type][stream_id].rend(); ++it) {
MemoryBlock *reusable_block = *it;
@@ -879,15 +1016,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size,
GELOGI("Unreusable block.");
continue;
}
std::string batch_label;
if (reusable_block->IsSameLabel(batch_label)) {
std::string op_label;
(void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, op_label);
if (batch_label != op_label) {
GELOGI("label diff, op name %s", node_op_desc->GetName().c_str());
continue;
}
}
GE_IF_BOOL_EXEC(reusable_block->batch_label_ != batch_label, continue);

// A node can reuse blocks of the same stream and preorder streams
if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) {
@@ -914,10 +1043,11 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size,
// Data and netoutput need zero copy block
block->is_zero_copy_ = IsZeroCopyBlock(n, continuous);

block->Init(real_size, mem_type, n, out_index, no_align_size);
block->Init(real_size, mem_type, n, out_index, no_align_size, node_op_desc->GetStreamId());
block->stream_id_ = node_op_desc->GetStreamId();
block->ref_count_++;
block->continuous_block_ = continuous;
block->batch_label_ = batch_label;
if (mem_type == kOutput) {
auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString());
if (iter != anchor_to_symbol_.end()) {
@@ -945,6 +1075,11 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec
return nullptr;
}

if (CheckIsZeroMemNodeType(n->GetType())) {
zero_memory_list_.emplace_back(n, kOutput, index);
continue;
}

int64_t size = 0;
if (ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS) {
GELOGI("Get size failed");
@@ -957,9 +1092,7 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec
// only apply total size in first block
if (index != 0) {
zero_memory_list_.emplace_back(n, kOutput, index);
}

if (index == 0) {
} else {
NodeIndexIO node_index_io(n, index, kOut);
auto iter = anchor_to_symbol_.find(node_index_io.ToString());
if (iter != anchor_to_symbol_.end()) {
@@ -972,6 +1105,10 @@ MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vec
}
}

if (total_size == 0) {
return nullptr;
}

auto block_size = GetBlockSize(total_size, ranges);
GELOGI("Node[%s] continuous out memory size[%ld] block size[%zu]", node_op_desc->GetName().c_str(),
total_size, block_size);
@@ -1119,15 +1256,28 @@ bool IsKnownSubgraphData(const NodePtr &node) {
return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
}

void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory) {
void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory,
bool same_stream) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null.");
GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory");
GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory");
--to_release->ref_count_;
if (!same_stream) {
to_release->same_stream_ = false;
}
if (to_release->ref_count_ == 0) {
to_release->SetLifeTimeEnd(life_time_);
reusable_memory.emplace_back(to_release);
AddReusableBlockCount(*to_release, reusable_block_counts_);
if (to_release->reuse_mem_ && !to_release->RealSizeList().empty()) {
if (to_release->batch_label_.empty() || (to_release->batch_label_ == max_batch_label_)) {
size_t align_size = to_release->RealSizeList().back();
AlignMemOffset(align_size);
theory_memory_size_ -= align_size;
}
}
if (to_release->same_stream_) {
to_release->SetLifeTimeEnd(life_time_);
reusable_memory.emplace_back(to_release);
AddReusableBlockCount(*to_release, reusable_block_counts_);
}
}
}

@@ -1167,10 +1317,9 @@ void BlockMemAssigner::ReleaseInputNodeOutMemory(const unordered_map<string, vec
node_type_indexs.back().node->GetName().c_str());

if ((node_type_indexs.back().node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) &&
(node_type_indexs.back().index == static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx())) &&
(node->GetOpDesc()->GetStreamId() == block->stream_id_)) {
ReleaseMemory(block, reusable_memory);
if (block->ref_count_ == 0) {
(node_type_indexs.back().index == static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx()))) {
ReleaseMemory(block, reusable_memory, (node->GetOpDesc()->GetStreamId() == block->stream_id_));
if (block->ref_count_ == 0 && block->same_stream_) {
SetLastUsedInputMemAttr(node, in_anchor->GetIdx());
}
}
@@ -1267,7 +1416,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector
bool no_need_assign_memory = ((size == 0) || CheckIsZeroMemNodeType(node->GetType()));
if (!no_need_assign_memory) {
out_node_set_continuous_input =
IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index, no_need_assign_memory, reset_zero_copy_flag);
IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index,
no_need_assign_memory, reset_zero_copy_flag);
GE_IF_BOOL_EXEC(!no_need_assign_memory,
no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input););
}
@@ -1328,7 +1478,8 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) {
iter->second[stream_id].clear();
}
vector<int64_t> temp;
GetNodeWorkSpaceSize(n, temp);
int64_t tatal_size = 0;
GetNodeWorkSpaceSize(n, temp, tatal_size);
vector<int64_t> workspace_bytes;
vector<int64_t> tvm_workspace_memory_type;
bool has_tvm_workspace_mem_type_attr =
@@ -1349,7 +1500,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) {
bool workspace_skip_flag = false;
if (has_tvm_workspace_mem_type_attr && tvm_workspace_memory_type[i] == RT_MEMORY_L1) {
GELOGI(
"fusion: node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]",
"fusion:node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]",
node_op_desc->GetName().c_str(), i, tvm_workspace_memory_type[i]);
workspace_skip_flag = true;
}
@@ -1380,9 +1531,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) {
(void)mem_block; // Fix warning
}

bool merge_dynamic_batch = false;
GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), merge_dynamic_batch = MergeDynamicBatchBlocks());
GE_IF_BOOL_EXEC((!(ge_disable_reuse_mem_env_ == "1") && !merge_dynamic_batch), ReuseBlocksByLifeTime(ranges.size()));
GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), ReuseBlocksByLifeTime(ranges.size()));
AssignContinuousBlocks();
ResizeMemoryBlocks();

@@ -1402,92 +1551,19 @@ void BlockMemAssigner::CheckWorkspaceReuse(const vector<bool> &workspace_reuse_f
}
}

void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector<int64_t> &workspace_memory) {
void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector<int64_t> &workspace_memory,
int64_t &total_size) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node->GetOpDesc() == nullptr, return, "Op desc is null.");
vector<int64_t> workspace_byte_nums = node->GetOpDesc()->GetWorkspaceBytes();

GELOGD("node[%s] size:%zu", node->GetOpDesc()->GetName().c_str(), workspace_byte_nums.size());
for (int64_t byte_size : workspace_byte_nums) {
workspace_memory.emplace_back(byte_size);
total_size += byte_size;
GELOGD("push back size:%ld", byte_size);
}
}

// descending order
static bool CompareBlockMaxSize(MemoryBlock *left, MemoryBlock *right) {
if (left == nullptr || right == nullptr) {
return false;
}
auto left_max_size = std::max_element(left->RealSizeList().begin(), left->RealSizeList().end());
if (left_max_size != left->RealSizeList().end()) {
auto right_max_size = std::max_element(right->RealSizeList().begin(), right->RealSizeList().end());
if (right_max_size == right->RealSizeList().end() || (*left_max_size > *right_max_size)) {
return true;
}
}
return false;
}

void MergeBlocks(std::vector<MemoryBlock *> &dest, std::vector<MemoryBlock *> &src) {
for (size_t i = 0; i < dest.size(); ++i) {
if (i >= src.size()) {
return;
}
if (dest[i] != nullptr && src[i] != nullptr) {
if (!dest[i]->reuse_mem_ || !src[i]->reuse_mem_) {
GELOGD("Diff batch's workspace can't be reused, i: %zu, dest[i]: %s, stream: %ld, src[i]: %s, stream: %ld.",
i, dest[i]->String().c_str(), dest[i]->stream_id_, src[i]->String().c_str(), src[i]->stream_id_);
continue;
}
for (auto &symbol : src[i]->SymbolList()) {
dest[i]->AddSymbol(symbol);
}
for (size_t j = 0; j < src[i]->NodeTypeIndexList().size(); ++j) {
dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j],
src[i]->RealSizeList()[j],
src[i]->NoAlignSizeList()[j]);
src[i]->deleted_block_ = true;
}
}
}
}

bool BlockMemAssigner::MergeDynamicBatchBlocks() {
bool merged = false;
std::map<std::string, std::vector<MemoryBlock *>> dynamic_batch_blocks;
for (auto block : memory_blocks_) {
if (block == nullptr) {
continue;
}
std::string batch_label;
if (block->IsSameLabel(batch_label)) {
dynamic_batch_blocks[batch_label].emplace_back(block);
}
}

auto it = dynamic_batch_blocks.begin();
auto it_max = it;

// find max block counts
for (; it != dynamic_batch_blocks.end(); ++it) {
if (it->second.size() > it_max->second.size()) {
it_max = it;
}
std::sort(it->second.begin(), it->second.end(), CompareBlockMaxSize);
}
if (it_max != dynamic_batch_blocks.end()) {
GELOGD("MergeDynamicBatch %s block counts %zu", it_max->first.c_str(), it_max->second.size());
}
for (it = dynamic_batch_blocks.begin(); it != dynamic_batch_blocks.end(); ++it) {
if (it != it_max) {
GELOGD("MergeDynamicBatch from %s to %s", it->first.c_str(), it_max->first.c_str());
MergeBlocks(it_max->second, it->second);
merged = true;
}
}
return merged;
}

// asending order
static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) {
if (left == nullptr || right == nullptr) {
@@ -1597,38 +1673,93 @@ void BlockMemAssigner::ReuseBlocksByLifeTime(size_t range_size) {
}
}

void AddBlockMemOffset(size_t &mem_offset, size_t &p2p_mem_offset, MemoryBlock &block) {
if (block.memory_type_ == RT_MEMORY_HBM) {
if (block.first_continuous_block_) {
mem_offset += MEM_ALIGN_SIZE;
}
block.Resize();
block.SetHeadOffset(mem_offset);
mem_offset += block.Size();
block.SetTailOffset(mem_offset - 1);
} else if (block.memory_type_ == RT_MEMORY_P2P_DDR) {
if (block.first_continuous_block_) {
p2p_mem_offset += MEM_ALIGN_SIZE;
}
block.Resize();
block.SetHeadOffset(p2p_mem_offset);
p2p_mem_offset += block.Size();
block.SetTailOffset(p2p_mem_offset - 1);
}
}

bool DynamicBatchBlockReuse(MemoryBlock &block) {
return (block.IsSameBatchLabel() && block.reuse_mem_);
}

///
/// @ingroup domi_omg
/// @brief traverse memory size, resize, calculate offset
/// @brief get max batch memory size, others reuse this block memory
/// @param [in&out] memory_blocks_ memory block, after calculating offset
/// |-dynamic batch block batch1|
/// |-dynamic batch block batch2----|
/// |-dynamic batch block batch3--|
///
void BlockMemAssigner::ResizeMemoryBlocks() {
for (auto &memory_block : memory_blocks_) {
if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_) {
void BlockMemAssigner::ResizeDynamicBatchBlocks() {
std::map<std::string, std::vector<MemoryBlock *>> dynamic_batch_blocks;
for (auto block : memory_blocks_) {
if (block == nullptr) {
continue;
}
if (memory_block->memory_type_ == RT_MEMORY_HBM) {
if (memory_block->first_continuous_block_) {
mem_offset_ += MEM_ALIGN_SIZE;
}
// when memory is not reuseable, it can't be reused by different branch
if (DynamicBatchBlockReuse(*block)) {
dynamic_batch_blocks[block->batch_label_].emplace_back(block);
}
}

memory_block->Resize();
memory_block->SetHeadOffset(mem_offset_);
mem_offset_ += memory_block->Size();
memory_block->SetTailOffset(mem_offset_ - 1);
} else if (memory_block->memory_type_ == RT_MEMORY_P2P_DDR) {
if (memory_block->first_continuous_block_) {
p2p_mem_offset_ += MEM_ALIGN_SIZE;
size_t max_mem_offset = mem_offset_;
size_t max_p2p_mem_offset = p2p_mem_offset_;
for (auto &batch_blocks : dynamic_batch_blocks) {
size_t mem_offset = mem_offset_;
size_t p2p_mem_offset = p2p_mem_offset_;
for (auto block : batch_blocks.second) {
if (block == nullptr || block->deleted_block_ || block->is_zero_copy_) {
continue;
}
AddBlockMemOffset(mem_offset, p2p_mem_offset, *block);
}
if (mem_offset > max_mem_offset) {
max_mem_offset = mem_offset;
}
if (p2p_mem_offset > max_p2p_mem_offset) {
max_p2p_mem_offset = p2p_mem_offset;
}
GELOGI("Batch[%s] offset[%zu] p2p_offset[%zu]", batch_blocks.first.c_str(), mem_offset, p2p_mem_offset);
}
mem_offset_ = max_mem_offset;
p2p_mem_offset_ = max_p2p_mem_offset;
}

memory_block->Resize();
memory_block->SetHeadOffset(p2p_mem_offset_);
p2p_mem_offset_ += memory_block->Size();
memory_block->SetTailOffset(p2p_mem_offset_ - 1);
///
/// @ingroup domi_omg
/// @brief traverse memory size, resize, calculate offset
/// @param [in&out] memory_blocks_ memory block, after calculating offset
/// |-not dynamic batch block-||-dynamic batch block batch1| |-zero copy block-|
/// |-not dynamic batch block-||-dynamic batch block batch2----||-zero copy block-|
/// |-not dynamic batch block-||-dynamic batch block batch3--| |-zero copy block-|
///
void BlockMemAssigner::ResizeMemoryBlocks() {
for (auto &memory_block : memory_blocks_) {
if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_
|| DynamicBatchBlockReuse(*memory_block)) {
continue;
}

AddBlockMemOffset(mem_offset_, p2p_mem_offset_, *memory_block);
}
GELOGD("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu.",
mem_offset_, p2p_mem_offset_);
ResizeDynamicBatchBlocks();
GELOGI("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu,"
"theory_min_memory_size %zu", mem_offset_, p2p_mem_offset_, theory_min_memory_size_);
}

///
@@ -1641,7 +1772,7 @@ void BlockMemAssigner::ResizeMemoryBlocks() {
/// @return Status result
///
void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block,
size_t real_size, size_t no_align_size, bool child_block) {
size_t real_size, size_t no_align_size, int32_t child_block_level) {
ge::OpDescPtr op_desc = node_type.node->GetOpDesc();
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null.");
string graph_name = node_type.node->GetOwnerComputeGraph()->GetName();
@@ -1689,14 +1820,15 @@ void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block,
}
op_desc->SetWorkspace(workspace_list);
}
GELOGI("[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu]"
" noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d] isref[%d].", graph_name.c_str(),
GELOGI("[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu] noalignsize[%zu] "
"life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d:%d] isref[%d] batch[%s]", graph_name.c_str(),
op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, op_desc->GetStreamId(),
block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block, block->reuse_mem_,
block->continuous_block_, block->deleted_block_, node_type.ref_input);
block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block_level, block->reuse_mem_,
block->continuous_block_, block->is_zero_copy_, block->same_stream_, node_type.ref_input,
block->batch_label_.c_str());
}

void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) {
void SetBlockOpMemOffset(MemoryBlock *block, int32_t child_block_level) {
if (block == nullptr) {
return;
}
@@ -1709,9 +1841,14 @@ void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) {
real_size = block->RealSizeList()[index];
no_align_size = block->NoAlignSizeList()[index];
}
SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block);
SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block_level);
index++;
}

child_block_level++;
for (MemoryBlock *child_block : block->ChildBlockList()) {
SetBlockOpMemOffset(child_block, child_block_level);
}
}

void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) {
@@ -1724,16 +1861,13 @@ void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) {
continue;
}

SetBlockOpMemOffset(memory_block, false);
for (MemoryBlock *child_block : memory_block->ChildBlockList()) {
SetBlockOpMemOffset(child_block, true);
}
SetBlockOpMemOffset(memory_block, 0);
}

if (!is_zero_copy) {
for (const NodeTypeIndex &node_type_index : zero_memory_list_) {
MemoryBlock block(0, 0);
SetOffsetSize(node_type_index, &block, 0, 0, false);
SetOffsetSize(node_type_index, &block, 0, 0, 0);
}
}
}


+ 30
- 6
ge/graph/build/memory/block_mem_assigner.h View File

@@ -65,6 +65,7 @@ class MemoryBlock {
stream_id_(stream_id),
deleted_block_(false),
reuse_mem_(reuse_mem),
same_stream_(true),
input_index_(0),
continuous_block_(false),
first_continuous_block_(false),
@@ -85,10 +86,14 @@ class MemoryBlock {
symbol_list_.clear();
}

void Init(size_t real_size, OpMemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size) {
void Init(size_t real_size, OpMemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size,
int64_t stream_id) {
real_size_list_.emplace_back(real_size);
no_align_size_list_.emplace_back(no_align_size);
node_type_index_list_.emplace_back(node, type, out_index, false);
if (stream_id != stream_id_) {
same_stream_ = false;
}
}
size_t Size() const { return block_size_; }

@@ -106,6 +111,12 @@ class MemoryBlock {
node_type_index_list_.emplace_back(node_type_index);
real_size_list_.emplace_back(real_size);
no_align_size_list_.emplace_back(no_align_size);
if ((node_type_index.node != nullptr) && (node_type_index.node->GetOpDesc() != nullptr)) {
auto stream_id = node_type_index.node->GetOpDesc()->GetStreamId();
if (stream_id != stream_id_) {
same_stream_ = false;
}
}
}

void AddSymbol(const std::string &symbol) {
@@ -122,7 +133,7 @@ class MemoryBlock {

std::string String();

bool IsSameLabel(std::string &first_batch_label);
bool IsSameBatchLabel();

void AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life);

@@ -142,6 +153,7 @@ class MemoryBlock {
int64_t stream_id_;
bool deleted_block_;
bool reuse_mem_;
bool same_stream_;
uint32_t input_index_;
bool continuous_block_;
bool first_continuous_block_;
@@ -149,6 +161,7 @@ class MemoryBlock {
bool is_zero_copy_;
std::map<int64_t, size_t> depend_stream_life_;
int64_t memory_type_;
std::string batch_label_;
private:
size_t block_size_;
std::vector<size_t> real_size_list_;
@@ -209,7 +222,7 @@ class BlockMemAssigner : public MemAssigner {

void GetOutAndWorkSpaceMem(std::vector<int64_t> &all_memory_size);

void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector<int64_t> &workspace_memory);
void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector<int64_t> &workspace_memory, int64_t &total_size);

///
/// @ingroup GE
@@ -353,7 +366,7 @@ class BlockMemAssigner : public MemAssigner {
/// @return void
/// @author
///
void ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory);
void ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory, bool same_stream = true);

///
/// @ingroup GE
@@ -379,11 +392,11 @@ class BlockMemAssigner : public MemAssigner {

///
/// @ingroup GE
/// @brief Merge memory blocks between different batchs
/// @brief Resize memory blocks for each batchs
/// @return merge or not
/// @author
///
bool MergeDynamicBatchBlocks();
void ResizeDynamicBatchBlocks();

void AssignContinuousBlocks();

@@ -436,6 +449,17 @@ class BlockMemAssigner : public MemAssigner {

int64_t atomic_addr_clean_id_ = 0;

size_t theory_min_memory_size_ = 0;

size_t theory_memory_size_ = 0;

std::string max_batch_label_;

///
/// @ [stream1][nodeid]
/// @[nodeid] [stream2][nodeid]
/// @ [stream2][nodeid]
///
DependStreamLife total_node_depend_stream_life_;
};
} // namespace ge


+ 6
- 4
ge/graph/build/memory/graph_mem_assigner.cc View File

@@ -419,7 +419,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node,
GE_IF_BOOL_EXEC(is_peer_output_continuous && (peer_output_size != 1),
std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) +
" requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) +
" requires continuous output. There may be conflict between the two. This node is not supported now.";
" requires continuous output. There may be conflict between the two." +
"This node is not supported now.";
GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str());
return PARAM_INVALID;);

@@ -429,7 +430,8 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node,
GE_IF_BOOL_EXEC(is_peer_reference,
std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) +
" requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) +
" requires continuous output. There may be conflict between the two. This node is not supported now.";
" requires continuous output. There may be conflict between the two." +
"This node is not supported now.";
GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str());
return PARAM_INVALID;);

@@ -1646,9 +1648,9 @@ ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &node, const ve
}
string atomic_mem_size_str = ss.str();

GELOGI("[IMAS]SetAtomicCleanAttr : Set graph[%s] atomic_node[%s] output offset [%s] size[%s] streamid[%ld]",
GELOGI("[IMAS]SetAtomicCleanAttr : Set %s atomic_node name[%s] output[0] offset to [%s] streamid[%ld] size[%s]",
node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(),
atomic_mem_start_str.c_str(), atomic_mem_size_str.c_str(), node->GetOpDesc()->GetStreamId());
atomic_mem_start_str.c_str(), node->GetOpDesc()->GetStreamId(), atomic_mem_size_str.c_str());
}
return SUCCESS;
}


+ 0
- 1
ge/graph/build/model_builder.cc View File

@@ -224,7 +224,6 @@ Status ModelBuilder::AdjustConstWeightSize(const ge::NodePtr &node, size_t &mem_
GeTensorDesc &tensor_desc = weight->MutableTensorDesc();
size_t output_size = weight->GetData().size();
TensorUtils::SetDataOffset(tensor_desc, mem_offset);
GELOGD("Node: %s, weight size: %zu.", node->GetName().c_str(), output_size);
mem_offset += output_size;
}
return SUCCESS;


+ 2
- 1
ge/graph/build/stream_allocator.cc View File

@@ -49,7 +49,8 @@ inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string &
}

bool IsHcclOp(const string &op_type) {
const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE});
const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER,
ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE});
return hccl_op_types.find(op_type) != hccl_op_types.end();
}
} // namespace


+ 2
- 1
ge/graph/load/graph_loader.cc View File

@@ -283,7 +283,8 @@ Status GraphLoader::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asyn
std::vector<GeTensorDesc> &output_desc) {
auto model_manager = ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);
Status ret = model_manager->ExecuteModel(model_id, stream, async_mode, input_data, input_desc, output_data, output_desc);
Status ret = model_manager->ExecuteModel(model_id, stream, async_mode,
input_data, input_desc, output_data, output_desc);
if (ret != SUCCESS) {
GELOGE(ret, "Execute model failed, model_id:%u.", model_id);
return ret;


+ 13
- 11
ge/graph/load/new_model_manager/davinci_model.cc View File

@@ -83,7 +83,7 @@ const uint32_t kAddrLen = sizeof(void *);
const int kDecimal = 10;
const int kBytes = 8;
const uint32_t kDataMemAlignSizeCompare = 64;
const uint32_t kDumpL1FusionOpMByteSize = 2 * 1024 * 1024;
const uint32_t kDumpL1FusionOpMByteSize = 2 * 1024 * 1024; // 2M
const uint32_t kDumpFlagOfL1Fusion = 0;
const char *const kDefaultBatchLable = "Batch_default";
const char *const kGetDynamicDimsName = "ascend_mbatch_get_dynamic_dims_node";
@@ -330,8 +330,8 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) {
GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size);
return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED;
}
GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id,
mem_base_, data_size);
GEEVENT("[IMAS]InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]",
runtime_param_.graph_id, mem_base_, data_size);

if (!is_inner_weight_base_) {
weights_mem_base_ = mem_base_;
@@ -1543,7 +1543,8 @@ Status DavinciModel::LoadWithQueue() {
}

if (output_queue_ids_.size() != new_output_data_info_.size()) {
GELOGE(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID, "Output queue ids not match model: output_queue=%zu output_data=%zu",
GELOGE(ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID,
"Output queue ids not match model: output_queue=%zu output_data=%zu",
output_queue_ids_.size(), new_output_data_info_.size());
return ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID;
}
@@ -2202,7 +2203,7 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data
void *mem_addr = data.second.GetBasicAddr();
void *data_buf_addr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(data_buf.data));
uint64_t data_buf_length = data_buf.length;
GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]",
GELOGI("CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]",
runtime_param_.graph_id, data.first, mem_addr, data_buf_addr, data_size, data_buf_length);
GE_CHK_RT_RET(rtMemcpy(mem_addr, data_size, data_buf_addr, data_buf_length, kind));
}
@@ -3391,14 +3392,14 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64
///
Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic) {
if (UpdateIoTaskArgs(new_input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) {
GELOGE(PARAM_INVALID, "[ZCPY] Update input data to model failed.");
return PARAM_INVALID;
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[ZCPY] Update input data to model failed.");
return ACL_ERROR_GE_PARAM_INVALID;
}

if (UpdateIoTaskArgs(new_output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) !=
SUCCESS) {
GELOGE(PARAM_INVALID, "[ZCPY] Update output data to model failed.");
return PARAM_INVALID;
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[ZCPY] Update output data to model failed.");
return ACL_ERROR_GE_PARAM_INVALID;
}

for (ZeroCopyTask &task : zero_copy_tasks_) {
@@ -3861,7 +3862,8 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa
if (!is_async_mode_) {
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_START));
ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Copy Output data to user failed.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ACL_ERROR_GE_INTERNAL_ERROR,
"Copy Output data to user failed.");
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), SetProfileTime(MODEL_AFTER_PROC_END));
}

@@ -4061,7 +4063,7 @@ void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &compute_graph) {
data_dumper_.SetDeviceId(device_id);

// set loop count addr
auto get_var_addr = [](const OpDescPtr &op, const RuntimeParam &runtime_param) -> void * {
auto get_var_addr = [](const OpDescPtr &op, const RuntimeParam &runtime_param) -> void *{
if (op != nullptr) {
auto v_output_size = ModelUtils::GetOutputSize(op);
auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param, op);


+ 2
- 1
ge/graph/load/new_model_manager/model_manager.cc View File

@@ -1254,7 +1254,8 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy
}

std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid model id %u.", model_id);
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, ACL_ERROR_GE_EXEC_MODEL_ID_INVALID,
"Invalid model id %u, check weather model has been loaded or not.", model_id);

if (davinci_model->NeedDestroyAicpuKernel()) {
GELOGI("Start to destroy specified aicpu kernel.");


+ 2
- 2
ge/graph/load/new_model_manager/model_utils.cc View File

@@ -61,7 +61,7 @@ vector<int64_t> ModelUtils::GetInputSize(ConstOpDescPtr op_desc) {
GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i);
continue);

GELOGI("[IMAS]GetInputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size);
GELOGI("GetInputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size);
v_input_size.push_back(tensor_size);
}

@@ -96,7 +96,7 @@ vector<int64_t> ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) {
GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i);
continue);

GELOGI("[IMAS]GetOutputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size);
GELOGI("GetOutputSize op: %s, index: %zu, size:%ld", op_desc->GetName().c_str(), i, tensor_size);
v_output_size.push_back(tensor_size);
}



+ 2
- 1
ge/graph/load/new_model_manager/task_info/hccl_task_info.cc View File

@@ -281,7 +281,8 @@ Status HcclTaskInfo::SetAddrs(const std::shared_ptr<OpDesc> &op_desc,
kernel_hccl_infos[i].inputDataAddr = input_data_addr;
if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) {
kernel_hccl_infos[i].outputDataAddr = output_data_addr;
} else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE || hccl_type == HCOMREDUCE) {
} else if (hccl_type == HCOMALLREDUCE ||
hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE || hccl_type == HCOMREDUCE) {
GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type),
"davinci_model: GetHcomOperationType fail!");
kernel_hccl_infos[i].outputDataAddr = output_data_addr;


+ 2
- 2
ge/graph/load/new_model_manager/task_info/kernel_task_info.cc View File

@@ -1172,8 +1172,8 @@ Status KernelTaskInfo::CceUpdateKernelArgs(const domi::KernelContext &context, u
}
ccStatus_t cc_ret;
std::string update_kernel_args = "ccUpdateKernelArgs";
auto cceUpdateKernelArgs = (ccStatus_t(*)(ccOpContext &, uint64_t, uint64_t, uint64_t, void *, uint64_t,
void *))mmDlsym(handle, const_cast<char *>(update_kernel_args.c_str()));
auto cceUpdateKernelArgs = (ccStatus_t(*)(ccOpContext &, uint64_t, uint64_t,
uint64_t, void *, uint64_t, void *))mmDlsym(handle, const_cast<char *>(update_kernel_args.c_str()));
if (cceUpdateKernelArgs == nullptr) {
GELOGE(FAILED, "Failed to invoke function ccUpdateKernelArgs");
if (mmDlclose(handle) != 0) {


+ 11
- 8
ge/graph/manager/graph_manager.cc View File

@@ -56,7 +56,6 @@
#include "graph/passes/cond_remove_pass.h"
#include "graph/passes/constant_folding_pass.h"
#include "graph/passes/constant_fuse_same_pass.h"
#include "graph/passes/const_pass.cc"
#include "graph/passes/control_trigger_pass.h"
#include "graph/passes/ctrl_edge_transfer_pass.h"
#include "graph/passes/dimension_adjust_pass.h"
@@ -550,8 +549,13 @@ Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_gr
if (!op_compile_strategy.empty()) {
(void) AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy);
}
std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this,
compute_graph->GetGraphID(), subgraph, compute_graph, session_id, GetThreadLocalContext());
std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads,
this,
compute_graph->GetGraphID(),
subgraph,
compute_graph,
session_id,
GetThreadLocalContext());
if (!f.valid()) {
GELOGE(FAILED, "Future is invalid");
return FAILED;
@@ -2138,7 +2142,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
TransposeTransDataPass transpose_transdata_pass;
TransOpSymmetryEliminationPass symmetry_elimination_pass;
DimensionComputePass dimension_compute_pass;
ConstPass const_pass;
names_to_passes.emplace_back("EnterPass", &enter_pass);
names_to_passes.emplace_back("AddNPass", &addn_pass);
names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination);
@@ -2152,7 +2155,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass);
names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass);
names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass);
names_to_passes.emplace_back("ConstPass", &const_pass);
GE_TIMESTAMP_START(names_to_passes);
ret = GEPass(compute_graph).Run(names_to_passes);
GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2");
@@ -2193,8 +2195,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass",
new (std::nothrow) VariableRefUselessControlOutDeletePass))
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass))
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::CommonSubexpressionEliminationPass",
new (std::nothrow) CommonSubexpressionEliminationPass));
if (options_.train_graph_flag) {
// Priority: The GlobalStepInsertPass should work before graph partitioner.
// Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory
@@ -2471,7 +2471,6 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager
GetContext().SetSessionId(session_id);
GetThreadLocalContext() = ge_context;
graph_manager->UpdateLocalOmgContext(root_graph_id);

ComputeGraphPtr compute_graph_tmp = sub_graph_info_ptr->GetSubGraph();
const std::string &engine_name = sub_graph_info_ptr->GetEngineName();
GELOGD("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu",
@@ -2479,6 +2478,10 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager
pthread_self());
GE_DUMP(compute_graph_tmp, "OptimizeSubGraphBefore");
GE_CHECK_NOTNULL(compute_graph_tmp);
if (!AttrUtils::SetInt(*compute_graph_tmp, ATTR_NAME_ROOT_GRAPH_ID, root_graph_id)) {
GELOGE(FAILED, "Failed to set attr ATTR_NAME_ROOT_GRAPH_ID for subgraph, graph_id: %u.", root_graph_id);
return FAILED;
}
compute_graph_tmp->SetSessionID(session_id);
Status ret = graph_manager->GetCompilerStages(root_graph_id).optimizer.OptimizeSubGraph(compute_graph_tmp,
compute_graph,


+ 2
- 1
ge/graph/manager/util/hcom_util.cc View File

@@ -263,7 +263,8 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro
Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc,
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) {
GE_CHECK_NOTNULL(op_desc);
if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) {
if (op_desc->GetType() == HCOMBROADCAST ||
op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) {
GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str());
int64_t root_id = 0;
Status dmrt = GetHcclRootId(op_desc, root_id);


+ 79
- 24
ge/graph/passes/atomic_addr_clean_pass.cc View File

@@ -74,10 +74,87 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) {
return SUCCESS;
}

// just hccl may mark atomic from ops kernel now, and hccl's atomic if for all input
bool AtomicAddrCleanPass::CheckAtomicFromOpsKernel(const NodePtr &node) {
// 1.Check if isAtomic attrs exist for HCOM
std::shared_ptr<GELib> instance_ptr = GELib::GetInstance();
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
GELOGW("GELib not initialized, atomic from ops kernel judge false, node_name: %s", node->GetName().c_str());
return false;
}

OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(node->GetType());
for (const auto &op_info : op_info_vec) {
if (op_info.isAtomic) {
// check peer input is DATA
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
if (in_data_anchor->GetPeerOutAnchor() != nullptr &&
in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) {
auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode();
if (peer_in_node->GetType() == DATA) {
GELOGI("Recognized atomic op %s from %s engine and input is DATA.", node->GetName().c_str(), op_info.engine.c_str());
return false;
}
}
}
GELOGI("Recognized atomic op %s from %s engine.", node->GetName().c_str(), op_info.engine.c_str());
hcom_node_vec_.push_back(node);
return true;
}
}
return false;
}

bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int64_t output_index) {
auto out_data_anchor = node->GetAllOutDataAnchors().at(output_index);
if (out_data_anchor == nullptr) {
return false;
}

for (auto input_anchor : out_data_anchor->GetPeerInDataAnchors()) {
auto output_node = input_anchor->GetOwnerNode();
// just hccl may mark atomic from ops kernel now, and hccl's atomic if for all input
// hccl's attr ATOMIC_ATTR_INPUT_INDEX mark on CalcOpRunningParam, can't be get here
if (CheckAtomicFromOpsKernel(output_node)) {
return true;
}
}
return false;
}

bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) {
OpDescPtr op_desc = node->GetOpDesc();
std::map<string, std::map<int, int>> node_workspace_offset;
bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX);
bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX);
node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset);
if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) {
std::vector<int64_t> atomic_output_index;
(void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index);
bool is_all_output_peer_also_atomic = true;
for (const auto &output_index : atomic_output_index) {
if (!IsOutputIndexPeerInputAtomic(node, output_index)) {
is_all_output_peer_also_atomic = false;
break;
}
}
if (is_all_output_peer_also_atomic) {
GELOGI("all out peer node input atomic, skip this out atomic process, node name: %s", node->GetName().c_str());
return true;
}
}
return false;
}

Status AtomicAddrCleanPass::HandleLoopGraph(ComputeGraphPtr &graph, const vector<NodePtr> &atomic_node_vec) {
// Loop graph , insert clean node follow atomic node
int index = 0;
for (const auto &node : atomic_node_vec) {
if (CheckSkipInsertInLoopGraph(node)) {
continue;
}

// Insert atomic clean op
NodePtr clean_addr_node = InsertAtomicAddrCleanNode(graph);
if (clean_addr_node == nullptr) {
@@ -249,32 +326,10 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) {
return false;
}
// 1.Check if isAtomic attrs exist for HCOM
std::shared_ptr<GELib> instance_ptr = GELib::GetInstance();
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
GELOGW("GELib not initialized");
return false;
if (CheckAtomicFromOpsKernel(node)) {
return true;
}

OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
for (const auto &op_info : op_info_vec) {
if (op_info.isAtomic) {
GELOGI("Recognized atomic op %s from DNN_HCCL engine.", op_desc->GetName().c_str());
// check peer input is DATA
for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
if (in_data_anchor->GetPeerOutAnchor() != nullptr &&
in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) {
auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode();
if (peer_in_node->GetType() == DATA) {
GELOGI("Recognized atomic op %s from DNN_HCCL engine and input is DATA.", op_desc->GetName().c_str());
return false;
}
}
}
hcom_node_vec_.push_back(node);
return true;
}
}
// 2.Check atomic attr in node
std::map<string, std::map<int, int>> node_workspace_offset;
bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX);


+ 5
- 0
ge/graph/passes/atomic_addr_clean_pass.h View File

@@ -84,6 +84,11 @@ class AtomicAddrCleanPass : public GraphPass {
Status HandleDispersedAtomicNodes(ComputeGraphPtr &graph, const std::vector<NodePtr> &atomic_node_vec,
std::vector<NodePtr> &common_atomic_nodes);

bool CheckAtomicFromOpsKernel(const NodePtr &node);

bool IsOutputIndexPeerInputAtomic(const NodePtr &node, int64_t output_index);

bool CheckSkipInsertInLoopGraph(const NodePtr &node);

vector<NodePtr> hcom_node_vec_;
bool is_loop_graph_ = false;


+ 23
- 5
ge/graph/passes/attach_stream_label_pass.cc View File

@@ -18,8 +18,6 @@
#include "ge/ge_api_types.h"
#include "graph/common/omg_util.h"

using std::string;

namespace ge {
Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) {
GELOGD("AttachStreamLabelPass Enter.");
@@ -189,10 +187,21 @@ Status AttachStreamLabelPass::UpdateEnterNode() {
}

std::stack<NodePtr> enter_nodes;
std::string batch_label;
for (const auto &enter_node : pair.second) {
enter_nodes.emplace(enter_node);
std::string tmp_label;
(void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
if (!tmp_label.empty()) {
if (batch_label.empty()) {
batch_label = tmp_label;
} else if (batch_label != tmp_label) {
GELOGE(FAILED, "multi batch_label exist, label1=%s, label2=%s.", batch_label.c_str(), tmp_label.c_str());
return FAILED;
}
}
}
if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) {
if (UpdateLoopBranch(enter_nodes, active_label_list[0], batch_label) != SUCCESS) {
GELOGE(FAILED, "Update stream_label for loop_branch failed.");
return FAILED;
}
@@ -217,7 +226,10 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
}

for (const auto &enter_node : enter_nodes) {
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed.");
GE_CHECK_NOTNULL(enter_node->GetOpDesc());
if (enter_node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL)) {
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed.");
}
}
return SUCCESS;
}
@@ -229,7 +241,8 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
/// @param [in] batch_label
/// @return Status
///
Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const string &stream_label) {
Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label,
const std::string &batch_label) {
std::stack<NodePtr> nodes(enter_nodes);
NodePtr cur_node = nullptr;
while (!nodes.empty()) {
@@ -238,6 +251,11 @@ Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_
for (const NodePtr &out_node : cur_node->GetOutAllNodes()) {
OpDescPtr out_desc = out_node->GetOpDesc();
GE_CHECK_NOTNULL(out_desc);
std::string tmp_label;
(void)AttrUtils::GetStr(out_desc, ATTR_NAME_BATCH_LABEL, tmp_label);
if (!tmp_label.empty() && (tmp_label != batch_label)) {
continue;
}
std::string out_type = out_desc->GetType();
bool need_skip =
out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) ||


+ 3
- 1
ge/graph/passes/attach_stream_label_pass.h View File

@@ -58,9 +58,11 @@ class AttachStreamLabelPass : public GraphPass {
/// @brief Update stream_label for loop_branch
/// @param [in] enter_nodes
/// @param [in] stream_label
/// @param [in] batch_label
/// @return Status
///
static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label);
static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label,
const std::string &batch_label);

///
/// @brief Update stream_label start with enter nodes


+ 1
- 1
ge/graph/passes/base_pass.cc View File

@@ -96,7 +96,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder
node->GetName().c_str(), node->GetType().c_str());
continue;
}
if (node_to_re_pass->IsAllInNodesSeen(nodes_seen) || node_to_re_pass->GetType() == ENTER) {
if (node_to_re_pass->IsAllInNodesSeen(nodes_seen)) {
GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str());
nodes_re_pass.insert(node_to_re_pass);
} else {


+ 1
- 2
ge/graph/passes/common_subexpression_elimination_pass.cc View File

@@ -58,8 +58,7 @@ std::string GetCseKey(const NodePtr &node) {
/// To avoid delete wrong nodes(e.g. stateful nodes),
/// only nodes have folding kernel will be considered for the CSE process
bool IsNodeSupportCse(const NodePtr &node) {
if (HostCpuEngine::CheckSupported(NodeUtils::GetNodeType(*node)) || node->GetType() == CONSTANT ||
node->GetType() == CONSTANTOP) {
if (HostCpuEngine::CheckSupported(NodeUtils::GetNodeType(*node))) {
return true;
}
return folding_pass::GetKernelByType(node) != nullptr;


+ 0
- 55
ge/graph/passes/const_pass.cc View File

@@ -1,55 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/passes/const_pass.h"

#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"

namespace ge {
Status ConstPass::Run(NodePtr &node) {
GE_CHECK_NOTNULL(node);

if ((node->GetType() != CONSTANT) && (node->GetType() != CONSTANTOP)) {
return SUCCESS;
}
GELOGD("ConstPass running, node: %s.", node->GetName().c_str());

// const has no control input
if (node->GetInControlNodes().empty()) {
auto out_ctrl_anchor = node->GetOutControlAnchor();
if (out_ctrl_anchor != nullptr) {
GELOGD("Node: %s unlink all out control edge.", node->GetName().c_str());
out_ctrl_anchor->UnlinkAll();
}

if (node->GetOutAllNodes().empty()) {
// it is an isolated const, just remove it.
GELOGD("Delete isolated const: %s.", node->GetName().c_str());
auto graph = node->GetOwnerComputeGraph();
if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove const %s failed.", node->GetName().c_str());
return FAILED;
}
AddNodeDeleted(node);
}
}

return SUCCESS;
}
} // namespace ge

+ 0
- 29
ge/graph/passes/const_pass.h View File

@@ -1,29 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_GRAPH_PASSES_CONST_PASS_H_
#define GE_GRAPH_PASSES_CONST_PASS_H_

#include "graph/passes/base_pass.h"

namespace ge {
class ConstPass : public BaseNodePass {
public:
Status Run(NodePtr &node) override;
};
} // namespace ge

#endif // GE_GRAPH_PASSES_CONST_PASS_H_

+ 0
- 64
ge/graph/passes/dimension_adjust_pass.cc View File

@@ -80,71 +80,7 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) {
}
}

ret = DealWithInNodes(node);
if (ret != SUCCESS) {
GELOGE(ret, "DealWithInNodes of %s failed.", node->GetName().c_str());
return ret;
}

std::vector<int> data_relink_io_map = {kDataInputIndex};
return IsolateAndDeleteNode(node, data_relink_io_map);
}

Status DimensionAdjustPass::DealWithInNodes(NodePtr &node) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
auto graph = node->GetOwnerComputeGraph();
auto in_data_anchors = node->GetAllInDataAnchors();
for (auto &in_data_anchor : in_data_anchors) {
if (in_data_anchor == nullptr) {
continue;
}
auto in_node_anchor = in_data_anchor->GetPeerOutAnchor();
if (in_node_anchor == nullptr) {
continue;
}
auto in_node = in_node_anchor->GetOwnerNode();
if (in_node->GetType() == SWITCHN) {
GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str());
auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx());
auto identity =
AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph);
GE_CHECK_NOTNULL(identity);
GELOGI("Create new identity node[%s] success.", identity->GetName().c_str());
GE_CHK_STATUS_RET(GraphUtils::AddEdge(in_node_anchor, identity->GetInDataAnchor(0)))
GE_CHECK_NOTNULL(identity->GetOutControlAnchor());
if (identity->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor())) {
continue;
}
GE_CHK_STATUS_RET(GraphUtils::AddEdge(identity->GetOutControlAnchor(), node->GetInControlAnchor()))
}
}

return SUCCESS;
}

NodePtr DimensionAdjustPass::AddIdentityNodeToGraph(const string &name, const GeTensorDesc &tensor,
ComputeGraphPtr &graph) {
if (graph == nullptr) {
GELOGE(INTERNAL_ERROR, "Comput graph ptr is null in creating identity node.");
return nullptr;
}

OpDescPtr desc = MakeShared<OpDesc>("", "");
if (desc == nullptr) {
GELOGE(MEMALLOC_FAILED, "Failed to create op desc.");
return nullptr;
}

desc->SetName(name);
desc->SetType(IDENTITY);
auto ret = desc->AddInputDesc(tensor);
auto ret2 = desc->AddOutputDesc(tensor);
if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) {
GELOGE(INTERNAL_ERROR, "Failed to add input/output desc in creating identity.");
return nullptr;
}

return graph->AddNodeFront(desc);
}
} // namespace ge

+ 0
- 4
ge/graph/passes/dimension_adjust_pass.h View File

@@ -34,10 +34,6 @@ namespace ge {
class DimensionAdjustPass : public BaseNodePass {
public:
Status Run(ge::NodePtr &node) override;

private:
Status DealWithInNodes(ge::NodePtr &node);
NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tensor, ComputeGraphPtr &graph);
};
} // namespace ge



+ 7
- 41
ge/graph/passes/enter_pass.cc View File

@@ -23,7 +23,6 @@

namespace {
const size_t kOutNodesNum = 1;
const size_t kInCtrlNodesNum = 1;
}

namespace ge {
@@ -56,7 +55,6 @@ Status EnterPass::Run(NodePtr &node) {
if (out_ctrl_node == nullptr) {
continue;
}
GELOGD("Remove control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str());
if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(),
out_ctrl_node->GetName().c_str());
@@ -64,12 +62,8 @@ Status EnterPass::Run(NodePtr &node) {
}
}
} else {
if (OptimizeEnterWithOnlyOutData(node, in_node) != SUCCESS) {
GELOGE(FAILED, "Optimize enter node[%s] with only out data node failed.", node->GetName().c_str());
return FAILED;
}
if (UnlinkCtrlEdgeBeforeConst(node) != SUCCESS) {
GELOGE(FAILED, "Unlink control edge before const of node[%s]'s out nodes failed.", node->GetName().c_str());
if (OptimizeEnter(node, in_node) != SUCCESS) {
GELOGE(FAILED, "Optimize enter node[%s] failed.", node->GetName().c_str());
return FAILED;
}
}
@@ -78,7 +72,7 @@ Status EnterPass::Run(NodePtr &node) {
return SUCCESS;
}

Status EnterPass::OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node) {
Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) {
if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) {
return SUCCESS;
}
@@ -89,45 +83,17 @@ Status EnterPass::OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node)
}

GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0)))
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0)));
const auto &out_data_anchor = node->GetOutDataAnchor(0);
GE_CHECK_NOTNULL(out_data_anchor);
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor))
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor))
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor));
}
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node))
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node));
AddNodeDeleted(node);
AddRePassNodesWithInOut(in_node);

return SUCCESS;
}

Status EnterPass::UnlinkCtrlEdgeBeforeConst(NodePtr &node) {
auto out_ctrl_nodes = node->GetOutControlNodes();
if (out_ctrl_nodes.empty()) {
return SUCCESS;
}
auto out_ctrl_anchor = node->GetOutControlAnchor();
GE_CHECK_NOTNULL(out_ctrl_anchor);

for (auto &out_ctrl_node : out_ctrl_nodes) {
GE_CHECK_NOTNULL(out_ctrl_node);
if ((out_ctrl_node->GetType() != CONSTANT) && (out_ctrl_node->GetType() != CONSTANTOP)) {
continue;
}
auto in_ctrl_nodes = out_ctrl_node->GetInControlNodes();
if (in_ctrl_nodes.size() != kInCtrlNodesNum) {
continue;
}
GE_CHK_STATUS_RET(out_ctrl_anchor->Unlink(out_ctrl_node->GetInControlAnchor()))
auto out_nodes_of_const = out_ctrl_node->GetOutAllNodes();
for (auto &out_node_of_const : out_nodes_of_const) {
if (!out_ctrl_anchor->IsLinkedWith(out_node_of_const->GetInControlAnchor())) {
GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(out_node_of_const->GetInControlAnchor()))
}
}
}
return SUCCESS;
}
} // namespace ge

+ 1
- 2
ge/graph/passes/enter_pass.h View File

@@ -25,8 +25,7 @@ class EnterPass : public BaseNodePass {
Status Run(NodePtr &node) override;

private:
Status OptimizeEnterWithOnlyOutData(NodePtr &node, NodePtr &in_node);
Status UnlinkCtrlEdgeBeforeConst(NodePtr &node);
Status OptimizeEnter(NodePtr &node, NodePtr &in_node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_ENTER_PASS_H_

+ 4
- 1
ge/graph/passes/folding_pass.cc View File

@@ -173,7 +173,10 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) {
continue;
}
auto in_node = in_node_anchor->GetOwnerNode();
if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) {
if (in_node == nullptr) {
continue;
}
if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) {
GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str());
auto ret = in_node_anchor->Unlink(in_data_anchor);
if (ret != SUCCESS) {


+ 10
- 0
ge/graph/passes/merge_to_stream_merge_pass.cc View File

@@ -89,6 +89,16 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co
GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed");
}

if (merge_op_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) {
string batch_label;
(void)AttrUtils::GetStr(merge_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (!batch_label.empty()) {
auto stream_merge_desc = stream_merge->GetOpDesc();
GE_CHECK_NOTNULL(stream_merge_desc);
(void)AttrUtils::SetStr(stream_merge_desc, ATTR_NAME_BATCH_LABEL, batch_label);
}
}

return AddActiveNodes(graph, stream_merge);
}



+ 173
- 89
ge/graph/passes/next_iteration_pass.cc View File

@@ -19,8 +19,6 @@
#include "common/ge/ge_util.h"
#include "graph/common/omg_util.h"

using std::string;

namespace ge {
Status NextIterationPass::Run(ComputeGraphPtr graph) {
GELOGD("NextIterationPass Enter");
@@ -37,6 +35,10 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) {
return INTERNAL_ERROR;
}
}
if (GroupWithNoBatch(graph) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr.");
return INTERNAL_ERROR;
}

if (FindWhileGroups() != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Find while groups failed.");
@@ -71,22 +73,75 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
return FAILED;
}

string batch_label;
if (ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
frame_name += batch_label;
std::string batch_label;
(void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (batch_label.empty()) {
auto frame_iter = frame_enter_map_.find(frame_name);
if (frame_iter == frame_enter_map_.end()) {
std::vector<NodePtr> enter_nodes;
enter_nodes.emplace_back(enter_node);
frame_enter_map_[frame_name] = enter_nodes;
} else {
frame_iter->second.emplace_back(enter_node);
}
return SUCCESS;
}

auto iter = loop_group_map_.find(frame_name);
if (iter == loop_group_map_.end()) {
auto group_iter = loop_group_map_.find(frame_name);
if (group_iter == loop_group_map_.end()) {
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
if (loop_group == nullptr) {
GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
return FAILED;
}
loop_group->enter_nodes.emplace_back(enter_node);
loop_group_map_[frame_name] = loop_group;
loop_group_map_[frame_name][batch_label] = loop_group;
} else {
iter->second->enter_nodes.emplace_back(enter_node);
auto batch_iter = group_iter->second.find(batch_label);
if (batch_iter == group_iter->second.end()) {
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
if (loop_group == nullptr) {
GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
return FAILED;
}
loop_group->enter_nodes.emplace_back(enter_node);
group_iter->second[batch_label] = loop_group;
} else {
batch_iter->second->enter_nodes.emplace_back(enter_node);
}
}

return SUCCESS;
}

///
/// @brief Group Enter nodes without batch_label attr
/// @param [in] compute_graph
/// @return Status
///
Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) {
if (frame_enter_map_.empty()) {
GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str());
return SUCCESS;
}
for (const auto &item : frame_enter_map_) {
const std::string &frame_name = item.first;
auto iter = loop_group_map_.find(frame_name);
if (iter == loop_group_map_.end()) {
LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
if (loop_group == nullptr) {
GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
return FAILED;
}
loop_group->enter_nodes = item.second;
loop_group_map_[frame_name][""] = loop_group;
} else {
for (auto &batch_item : iter->second) {
for (const auto &enter_node : item.second) {
batch_item.second->enter_nodes.emplace_back(enter_node);
}
}
}
}

return SUCCESS;
@@ -99,39 +154,55 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
Status NextIterationPass::FindWhileGroups() {
for (const auto &loop_group_iter : loop_group_map_) {
const std::string &frame_name = loop_group_iter.first;
for (const auto &enter_node : loop_group_iter.second->enter_nodes) {
for (const auto &out_node : enter_node->GetOutAllNodes()) {
const string &type = out_node->GetType();
if ((type != MERGE) && (type != REFMERGE)) {
continue;
}

NodePtr next_node = nullptr;
if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s", frame_name.c_str());
return INTERNAL_ERROR;
}
loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node));

NodePtr switch_node = nullptr;
if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str());
return INTERNAL_ERROR;
}
if (switch_node == nullptr) {
continue;
}

NodePtr loop_cond = nullptr;
if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str());
return INTERNAL_ERROR;
}
if (loop_group_iter.second->loop_cond == nullptr) {
loop_group_iter.second->loop_cond = loop_cond;
} else if (loop_group_iter.second->loop_cond != loop_cond) {
GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str());
return FAILED;
for (const auto &batch_iter : loop_group_iter.second) {
const std::string &batch_label = batch_iter.first;
for (const auto &enter_node : batch_iter.second->enter_nodes) {
for (const auto &out_node : enter_node->GetOutAllNodes()) {
GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(),
frame_name.c_str(), batch_label.c_str());
if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) {
continue;
}
std::string tmp_label;
GE_CHECK_NOTNULL(out_node->GetOpDesc());
(void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label));
if (need_skip) {
continue;
}

NodePtr next_node = nullptr;
if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR,
"Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s",
out_node->GetName().c_str());
return INTERNAL_ERROR;
}
batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node));

NodePtr switch_node = nullptr;
if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s",
out_node->GetName().c_str());
return INTERNAL_ERROR;
}
if (switch_node == nullptr) {
continue;
}

NodePtr loop_cond = nullptr;
if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) {
GELOGE(INTERNAL_ERROR,
"Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s",
switch_node->GetName().c_str());
return INTERNAL_ERROR;
}
if (batch_iter.second->loop_cond == nullptr) {
batch_iter.second->loop_cond = loop_cond;
} else if (batch_iter.second->loop_cond != loop_cond) {
GELOGE(FAILED, "Multi LoopCond nodes exist.");
return FAILED;
}
}
}
}
@@ -152,17 +223,19 @@ bool NextIterationPass::VerifyWhileGroup() {
GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty.");
return false;
}
if (loop_group_iter.second->loop_cond == nullptr) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str());
return false;
}

for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) {
if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.",
frame_name.c_str());
for (const auto &batch_iter : loop_group_iter.second) {
if (batch_iter.second->loop_cond == nullptr) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str());
return false;
}

for (const auto &pair_iter : batch_iter.second->merge_next_pairs) {
if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) {
GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.",
frame_name.c_str());
return false;
}
}
}
}

@@ -176,53 +249,56 @@ bool NextIterationPass::VerifyWhileGroup() {
///
Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
for (const auto &loop_cond_iter : loop_group_map_) {
const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName();
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());

// Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
if ((enter_active == nullptr) || (next_active == nullptr)) {
GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str());
return INTERNAL_ERROR;
}

for (const auto &enter_node : loop_cond_iter.second->enter_nodes) {
// Enter --> Active
if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge from %s to %s failed.", enter_node->GetName().c_str(),
enter_active->GetName().c_str());
for (const auto &batch_iter : loop_cond_iter.second) {
const std::string &cond_name = batch_iter.second->loop_cond->GetName();
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());

// Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
if ((enter_active == nullptr) || (next_active == nullptr)) {
GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str());
return INTERNAL_ERROR;
}
}

for (const auto &pair : loop_cond_iter.second->merge_next_pairs) {
NodePtr merge_node = pair.first;
NodePtr next_node = pair.second;
// Active --> Merge
if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
for (const auto &enter_node : batch_iter.second->enter_nodes) {
// Enter --> Active
if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) !=
GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
}
}

// NextIteration --> Active
if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
for (const auto &pair : batch_iter.second->merge_next_pairs) {
NodePtr merge_node = pair.first;
NodePtr next_node = pair.second;
// Active --> Merge
if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) !=
GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
}

// NextIteration --> Active
if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Add control edge failed.");
return INTERNAL_ERROR;
}

// break link between NextIteration and Merge
if (BreakNextIteration(next_node, merge_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
return INTERNAL_ERROR;
}
}

// break link between NextIteration and Merge
if (BreakNextIteration(next_node, merge_node) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
(SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) {
GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
return INTERNAL_ERROR;
}
}

if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
(SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) {
GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
return INTERNAL_ERROR;
}
}

return SUCCESS;
@@ -289,11 +365,12 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr &
/// @param [in] node
/// @param [in] target_type
/// @param [in] is_input
/// @param [in] batch_label
/// @param [out] target_node
/// @return Status
///
Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
NodePtr &target_node) {
const std::string &batch_label, NodePtr &target_node) {
if (node == nullptr) {
GELOGE(PARAM_INVALID, "node is null.");
return PARAM_INVALID;
@@ -310,6 +387,12 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string
}

for (const auto &tmp_node : nodes) {
std::string tmp_label;
(void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label));
if (need_skip) {
continue;
}
const std::string type = tmp_node->GetType();
if ((target_type == LOOPCOND) && (type == target_type)) {
target_node = tmp_node;
@@ -332,6 +415,7 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string
/// @return SUCCESS
///
Status NextIterationPass::ClearStatus() {
frame_enter_map_.clear();
loop_group_map_.clear();
return SUCCESS;
}


+ 13
- 3
ge/graph/passes/next_iteration_pass.h View File

@@ -47,6 +47,13 @@ class NextIterationPass : public GraphPass {
Status GroupEnterNode(const NodePtr &enter_node);

///
/// @brief Group Enter nodes without batch_label attr
/// @param [in] compute_graph
/// @return Status
///
Status GroupWithNoBatch(const ComputeGraphPtr &graph);

///
/// @brief Find while groups
/// @return Status
///
@@ -90,10 +97,13 @@ class NextIterationPass : public GraphPass {
/// @param [out] target_node
/// @return Status
///
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node);
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
const std::string &batch_label, NodePtr &target_node);

// map<frame_name, LoopCondGroup>
std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_;
// map<frame_name, vector<enter_node>>
std::unordered_map<std::string, std::vector<NodePtr>> frame_enter_map_;
// map<frame_name, map<batch_label, LoopCondGroup>>
std::unordered_map<std::string, std::unordered_map<std::string, LoopCondGroupPtr>> loop_group_map_;
};
} // namespace ge
#endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_

+ 4
- 4
ge/graph/passes/subgraph_pass.cc View File

@@ -149,10 +149,10 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node
// 5. While->NetOutput in known subgraph
std::string op_type;
bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) ||
IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) ||
((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) ||
(!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) &&
(kWhileOpTypes.count(in_node->GetType()) != 0));
IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) ||
((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) ||
(!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) &&
(kWhileOpTypes.count(in_node->GetType()) != 0));
if (insert_flag) {
GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str());
std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy";


+ 3
- 1
ge/graph/passes/transop_breadth_fusion_pass.cc View File

@@ -70,8 +70,10 @@ std::string TransOpBreadthFusionPass::GetNodeId(const int anchor_index, const No
trans_data_type = true;
trans_format = true;
trans_shape = true;
} else if (node->GetType() == RESHAPE) {
} else if (node->GetType() == RESHAPE || node->GetType() == EXPANDDIMS || node->GetType() == SQUEEZE) {
trans_shape = true;
} else if (node->GetType() == REFORMAT) {
trans_format = true;
}

id << node->GetType() << '-' << anchor_index;


+ 2
- 1
ge/graph/preprocess/graph_preprocess.cc View File

@@ -1621,7 +1621,8 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) {

for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) {
if (desc.GetShape().GetDim(i) < 0) {
std::string situation = "data dim[" + std::to_string(i) + "][" + std::to_string(desc.GetShape().GetDim(i)) + "]" ;
std::string situation = "data dim[" + std::to_string(i) + "][" +
std::to_string(desc.GetShape().GetDim(i)) + "]" ;
std::string reason = "it need >= 0";
ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {situation, reason});
GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= 0, real:%ld.", i,


+ 58
- 343
ge/graph/preprocess/multi_batch_copy_graph.cc View File

@@ -44,8 +44,6 @@
using std::set;
using std::string;
using std::vector;
using std::map;
using std::queue;

namespace ge {
namespace multibatch {
@@ -59,15 +57,10 @@ const int kDataInIndex = 0;
const int kMergeDataOutIndex = 0;
const int kStaticOutput = -1;
const int kDivisionConst = 2;
const int32_t kOneInDataNode = 1;
const int32_t kFindNoMatch = 0;


inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); }

inline bool IsEnterType(const string &node_type) { return (node_type == ENTER) || (node_type == REFENTER); }
const set<string> unchange_types({CONSTANT, CONSTANTOP, ENTER, REFENTER});

inline bool IsGetNextType(const NodePtr &node) {
std::string original_type;
GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS,
@@ -225,6 +218,12 @@ Status MultiBatchGraphCopyer::CopyGraph() {
return ret;
}

ret = InsertIdentityAfterSwitchN();
if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node.");
return INTERNAL_ERROR;
}

GELOGI("Begin to remove useless nodes by prune pass after copy process");
PrunePass prune_pass;
ret = prune_pass.Run(graph_);
@@ -241,18 +240,6 @@ Status MultiBatchGraphCopyer::Init() {
return ret;
}

ret = RelinkConstCtrlEdge();
if (ret != SUCCESS) {
GELOGE(FAILED, "Relink const's control edge failed.");
return FAILED;
}

ret = ExtractUnchangedStructureOutofCycle();
if (ret != SUCCESS) {
GELOGE(FAILED, "Extract unchanged structure out of cycle failed.");
return FAILED;
}

for (auto &node : graph_->GetAllNodes()) {
origin_all_nodes_.emplace_back(node);
if (IsDataLikeType(node->GetType())) {
@@ -265,281 +252,6 @@ Status MultiBatchGraphCopyer::Init() {
return SUCCESS;
}

Status MultiBatchGraphCopyer::RelinkConstCtrlEdge() {
for (auto &node : graph_->GetAllNodes()) {
GE_CHECK_NOTNULL(node);
if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) {
if (node->GetOutDataNodes().empty()) {
continue;
}
if (!node->GetInControlNodes().empty()) {
auto in_ctrl_nodes = node->GetInControlNodes();
auto out_nodes = node->GetOutAllNodes();
bool has_merge = false;
for (const auto &out_node : out_nodes) {
GE_CHECK_NOTNULL(out_node);
if (out_node->GetType() == MERGE || out_node->GetType() == REFMERGE) {
has_merge = true;
break;
}
}
if (has_merge) {
continue;
}
auto in_ctrl_anchor = node->GetInControlAnchor();
GE_CHECK_NOTNULL(in_ctrl_anchor);
in_ctrl_anchor->UnlinkAll();
for (auto &in_ctrl_node : in_ctrl_nodes) {
auto out_ctrl_anchor_of_in_ctrl_node = in_ctrl_node->GetOutControlAnchor();
GE_CHECK_NOTNULL(out_ctrl_anchor_of_in_ctrl_node);
for (auto &out_node : out_nodes) {
if (IsEnterType(out_node->GetType())) {
continue;
}
if (!out_ctrl_anchor_of_in_ctrl_node->IsLinkedWith(out_node->GetInControlAnchor())) {
GE_CHK_STATUS_RET(out_ctrl_anchor_of_in_ctrl_node->LinkTo(out_node->GetInControlAnchor()))
}
}
}
}
auto out_ctrl_anchor = node->GetOutControlAnchor();
if (out_ctrl_anchor != nullptr) {
out_ctrl_anchor->UnlinkAll();
}
}
}

return SUCCESS;
}

Status MultiBatchGraphCopyer::ExtractUnchangedStructureOutofCycle() {
map<string, vector<NodePtr>> frame_enter;
if (GetEnterNodesGroupByFrame(frame_enter) != SUCCESS) {
GELOGE(FAILED, "Get enter nodes grouped by frame_name failed.");
return FAILED;
}

queue<NodePtr> nodes_to_extract;
if (GetNodeNeedExtract(frame_enter, nodes_to_extract) != SUCCESS) {
GELOGE(FAILED, "Get nodes needed to extract failed.");
return FAILED;
}

while (!nodes_to_extract.empty()) {
auto node = nodes_to_extract.front();
nodes_to_extract.pop();
OpDescPtr enter_desc = nullptr;
if (MoveInEntersInDataAnchorDown(node, enter_desc) != SUCCESS) {
GELOGE(FAILED, "Move in enter nodes' in data anchors down of %s failed.", node->GetName().c_str());
return FAILED;
}
set<NodePtr> out_nodes;
if (InsertEnterAfterNode(node, enter_desc, out_nodes) != SUCCESS) {
GELOGE(FAILED, "Insert enter node after %s failed.", node->GetName().c_str());
return FAILED;
}

if (MoveCtrlEdgeToOutNodes(node, out_nodes) != SUCCESS) {
GELOGE(FAILED, "Move %s's control edge to out nodes failed.", node->GetName().c_str());
return FAILED;
}

for (auto &out_node : out_nodes) {
GE_CHECK_NOTNULL(out_node);
if (AllInDataNodesUnchangeAndNoMergeOut(out_node)) {
nodes_to_extract.push(out_node);
}
}
}

if (DeleteEnterWithoutDataOut() != SUCCESS) {
GELOGE(FAILED, "Delete enter node without out data nodes failed.");
return FAILED;
}

return SUCCESS;
}

Status MultiBatchGraphCopyer::GetEnterNodesGroupByFrame(map<string, vector<NodePtr>> &frame_enter) {
for (auto &node : graph_->GetAllNodes()) {
GE_CHECK_NOTNULL(node);
if (IsEnterType(node->GetType())) {
if (!node->GetInControlNodes().empty() || !node->GetOutControlNodes().empty()) {
continue;
}
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
string frame_name;
if (!AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) {
GELOGE(FAILED, "Get attr frame_name of enter[%] failed.", node->GetName().c_str());
return FAILED;
}
frame_enter[frame_name].emplace_back(node);
}
}

return SUCCESS;
}

Status MultiBatchGraphCopyer::GetNodeNeedExtract(const map<string, vector<NodePtr>> &frame_enter,
queue<NodePtr> &nodes_to_extract) {
for (const auto &one_group : frame_enter) {
auto enters = one_group.second;
for (const auto &enter : enters) {
auto out_data_nodes = enter->GetOutDataNodes();
for (const auto &out_data_node : out_data_nodes) {
GE_CHECK_NOTNULL(out_data_node);
if (AllInDataNodesUnchangeAndNoMergeOut(out_data_node)) {
nodes_to_extract.push(out_data_node);
}
}
}
}

return SUCCESS;
}

bool MultiBatchGraphCopyer::AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node) {
auto out_data_nodes = node->GetOutDataNodes();
for (const auto &out_data_node : out_data_nodes) {
if (out_data_node == nullptr) {
return false;
}

if (out_data_node->GetType() == MERGE || out_data_node->GetType() == REFMERGE) {
return false;
}
}

auto in_data_nodes = node->GetInDataNodes();
if (in_data_nodes.size() == kOneInDataNode) {
return true;
}

for (const auto &in_data_node : in_data_nodes) {
if (in_data_node == nullptr) {
return false;
}
if (unchange_types.count(in_data_node->GetType()) == kFindNoMatch) {
return false;
}
}

return true;
}

Status MultiBatchGraphCopyer::MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc) {
auto in_data_anchors = node->GetAllInDataAnchors();
for (auto &in_data_anchor : in_data_anchors) {
auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_data_anchor);
auto peer_in_data_node = peer_out_data_anchor->GetOwnerNode();
if (IsEnterType(peer_in_data_node->GetType())) {
GE_CHK_STATUS_RET(peer_out_data_anchor->Unlink(in_data_anchor))
GELOGD("Unlink data edge from %s to %s.", peer_in_data_node->GetName().c_str(), node->GetName().c_str());
auto enter_in_data_anchors = peer_in_data_node->GetAllInDataAnchors();
for (auto &enter_in_data_anchor : enter_in_data_anchors) {
auto peer_out_data_anchor_of_enter = enter_in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_out_data_anchor_of_enter);
if (peer_out_data_anchor_of_enter->IsLinkedWith(in_data_anchor)) {
continue;
}
GE_CHK_STATUS_RET(peer_out_data_anchor_of_enter->LinkTo(in_data_anchor))
GELOGD("Relink data edge from %s to %s.", peer_out_data_anchor_of_enter->GetOwnerNode()->GetName().c_str(),
node->GetName().c_str());
}
enter_desc = peer_in_data_node->GetOpDesc();
GE_CHECK_NOTNULL(enter_desc);
}
}

return SUCCESS;
}

Status MultiBatchGraphCopyer::InsertEnterAfterNode(NodePtr &node, const OpDescPtr &copy_desc, set<NodePtr> &out_nodes) {
if (copy_desc == nullptr) {
return SUCCESS;
}
map<OutDataAnchorPtr, vector<std::pair<InDataAnchorPtr, NodePtr>>> outanchors_inanchors_nodes;
auto out_data_anchors = node->GetAllOutDataAnchors();
for (auto &out_data_anchor : out_data_anchors) {
auto peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors();
for (auto peer_in_data_anchor : peer_in_data_anchors) {
GE_CHECK_NOTNULL(peer_in_data_anchor);
auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
out_nodes.emplace(peer_in_data_node);
outanchors_inanchors_nodes[out_data_anchor].emplace_back(std::make_pair(peer_in_data_anchor, peer_in_data_node));
}
}

int32_t i = 0;
auto node_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(node_desc);
// Insert one enter node after node's per out data anchor
for (auto &outanchor_inanchors_nodes : outanchors_inanchors_nodes) {
string name = node->GetName() + "_" + ENTER + "_" + std::to_string(i++);
GELOGD("Create Enter op %s after %s.", name.c_str(), node->GetName().c_str());
auto enter_desc = AttrUtils::CopyOpDesc(copy_desc);
enter_desc->SetName(name);
GE_CHK_STATUS_RET(
enter_desc->UpdateInputDesc("x", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx())))
GE_CHK_STATUS_RET(
enter_desc->UpdateOutputDesc("y", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx())))
auto enter_node = graph_->AddNode(enter_desc);
GE_CHECK_NOTNULL(enter_node);
GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->LinkTo(enter_node->GetInDataAnchor(kDataInIndex)))
GE_CHECK_NOTNULL(enter_node->GetOutDataAnchor(kDataInIndex));
for (auto &inanchor_node : outanchor_inanchors_nodes.second) {
GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->Unlink(inanchor_node.first))
GE_CHK_STATUS_RET(enter_node->GetOutDataAnchor(kDataInIndex)->LinkTo(inanchor_node.first))
GELOGD("Unlink from %s to %s, link from %s to %s then to %s.", node->GetName().c_str(),
inanchor_node.second->GetName().c_str(), node->GetName().c_str(), enter_node->GetName().c_str(),
inanchor_node.second->GetName().c_str());
}
}

return SUCCESS;
}

// Move node's in control edges to out data nodes
Status MultiBatchGraphCopyer::MoveCtrlEdgeToOutNodes(NodePtr &node, set<NodePtr> &out_nodes) {
auto in_ctrl_anchor = node->GetInControlAnchor();
GE_CHECK_NOTNULL(in_ctrl_anchor);
auto peer_out_ctrl_anchors = in_ctrl_anchor->GetPeerOutControlAnchors();
for (auto &peer_out_ctrl_anchor : peer_out_ctrl_anchors) {
GE_CHK_STATUS_RET(peer_out_ctrl_anchor->Unlink(in_ctrl_anchor))
GELOGD("Unlink control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
node->GetName().c_str());
for (auto &out_node : out_nodes) {
auto in_ctrl_anchor_of_out_node = out_node->GetInControlAnchor();
GE_CHECK_NOTNULL(in_ctrl_anchor_of_out_node);
if (!peer_out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor_of_out_node)) {
GE_CHK_STATUS_RET(peer_out_ctrl_anchor->LinkTo(in_ctrl_anchor_of_out_node))
GELOGD("Link control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
out_node->GetName().c_str());
}
}
}

return SUCCESS;
}

Status MultiBatchGraphCopyer::DeleteEnterWithoutDataOut() {
for (auto &node : graph_->GetAllNodes()) {
GE_CHECK_NOTNULL(node);
if (IsEnterType(node->GetType())) {
auto out_nodes = node->GetOutAllNodes();
if (out_nodes.empty()) {
GELOGD("Delete enter node: %s which has no output.", node->GetName().c_str());
GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {}))
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node))
}
}
}

return SUCCESS;
}

void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) {
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(),
@@ -585,9 +297,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() {
LabelStatusForGetNextSink(data);
}
}

map<string, vector<NodePtr>> frame_enters;
InitStatus(frame_enters);
bool changed = true;
// If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch
while (changed) {
@@ -597,13 +306,12 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() {
if (iter != origin_nodes_status_.end()) {
continue;
}
for (auto &in_node : node->GetInDataNodes()) {
if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end()) {
if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) {
origin_nodes_status_[node.get()] == kNodeInBatchBranch;
ResetEnterStatus(frame_enters, node);
changed = true;
}
for (auto &in_node : node->GetInAllNodes()) {
bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() &&
origin_nodes_status_[in_node.get()] == kNodeInBatchBranch;
if (is_in_batch) {
origin_nodes_status_[node.get()] = kNodeInBatchBranch;
changed = true;
break;
}
}
@@ -612,45 +320,6 @@ Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() {
return SUCCESS;
}

void MultiBatchGraphCopyer::InitStatus(map<string, vector<NodePtr>> &frame_enters) {
for (const auto &node : origin_all_nodes_) {
if (!IsEnterType(node->GetType())) {
continue;
}
auto op_desc = node->GetOpDesc();
if (op_desc == nullptr) {
continue;
}
string frame_name;
if (AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) {
frame_enters[frame_name].emplace_back(node);
}
}

for (const auto &data : origin_data_nodes_) {
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
if (!IsAllDimsPositive(data_shape.GetDims())) {
origin_nodes_status_[data.get()] = kNodeInBatchBranch;
}
}
}

void MultiBatchGraphCopyer::ResetEnterStatus(map<string, vector<NodePtr>> &frame_enters, const NodePtr &node) {
if (!IsEnterType(node->GetType())) {
return;
}

for (const auto &frame_enter : frame_enters) {
auto &enters = frame_enter.second;
if (std::find(enters.begin(), enters.end(), node) != enters.end()) {
for (const auto &enter : enters) {
origin_nodes_status_[enter.get()] = kNodeInBatchBranch;
}
break;
}
}
}

Status MultiBatchGraphCopyer::LabelStatus() {
if (LabelInBatchBranchStatus() != SUCCESS) {
GELOGE(PARAM_INVALID, "Failed to label no in batch branch");
@@ -1691,6 +1360,52 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) {
return SUCCESS;
}

Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() {
for (auto &node : graph_->GetAllNodes()) {
if (node->GetType() != SWITCHN) {
continue;
}
auto switchn_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(switchn_desc);
size_t i = 0;
for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
auto out_node = in_data_anchor->GetOwnerNode();
auto op_desc = out_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str());
continue;
}

auto identity_desc = MakeShared<OpDesc>(node->GetName() + "_identity_" + std::to_string(i), IDENTITY);
GE_CHECK_NOTNULL(identity_desc);

string batch_label;
if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str());
return FAILED;
}
}

auto data_desc = switchn_desc->GetOutputDesc(i);
i++;
GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc));
GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc));

auto identity_node = graph_->AddNode(identity_desc);
GE_CHECK_NOTNULL(identity_node);
GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0)));
GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor());
GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor()));
}
}
}

return SUCCESS;
}

Status ProcessMultiBatch(ComputeGraphPtr &graph) {
const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE");
if (multi_batch_with_case != nullptr) {


+ 1
- 15
ge/graph/preprocess/multi_batch_copy_graph.h View File

@@ -18,7 +18,6 @@
#include <map>
#include <queue>
#include <vector>
#include <set>

#include "external/ge/ge_api_error_codes.h"

@@ -65,26 +64,12 @@ class MultiBatchGraphCopyer {
private:
Status Init();
Status CheckArguments();
Status RelinkConstCtrlEdge();

Status ExtractUnchangedStructureOutofCycle();
Status GetEnterNodesGroupByFrame(std::map<std::string, std::vector<NodePtr>> &frame_enter);
Status GetNodeNeedExtract(const std::map<std::string, std::vector<NodePtr>> &frame_enter,
std::queue<NodePtr> &nodes_to_extract);
bool AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node);
Status MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc);
Status InsertEnterAfterNode(NodePtr &node, const OpDescPtr &enter_desc, std::set<NodePtr> &out_nodes);
Status MoveCtrlEdgeToOutNodes(NodePtr &node, std::set<NodePtr> &out_nodes);
Status DeleteEnterWithoutDataOut();

// label status for origin_all_nodes_
Status LabelStatus();
Status LabelInBatchBranchStatus();
void LabelStatusForData(const NodePtr &data);
void LabelStatusForGetNextSink(const NodePtr &data);
void InitStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters);
void ResetEnterStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters, const NodePtr &node);

// add nodes functions
Status CreateNewNodes();

@@ -96,6 +81,7 @@ class MultiBatchGraphCopyer {
Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index,
std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn);

Status InsertIdentityAfterSwitchN();
Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index);
Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index);



+ 6
- 2
ge/host_kernels/ssd_prior_box_kernel.cc View File

@@ -180,8 +180,12 @@ Status SsdPriorboxKernel::SetVariance(const vector<float> &variance, const int d
return SUCCESS;
}

Status SsdPriorboxKernel::GetNumPriorAndDimSize(uint32_t aspect_ratios_size, uint32_t min_sizes_size, uint32_t max_sizes_size,
int layer_width, int layer_height, int &num_priors,
Status SsdPriorboxKernel::GetNumPriorAndDimSize(uint32_t aspect_ratios_size,
uint32_t min_sizes_size,
uint32_t max_sizes_size,
int layer_width,
int layer_height,
int &num_priors,
int &dim_size) const {
if (ge::CheckUint32MulOverflow(min_sizes_size, aspect_ratios_size) != SUCCESS) {
return PARAM_INVALID;


+ 4
- 2
ge/hybrid/executor/hybrid_model_async_executor.cc View File

@@ -379,11 +379,13 @@ Status HybridModelAsyncExecutor::Execute(const std::vector<DataBuffer> &inputs,
}
if (output_real_size > 0) {
if (outputs[i].length < static_cast<uint64_t>(output_real_size)) {
GELOGE(FAILED, "output idx[%zu], the memory size of output[%lu] given by user should be greater than or equal to the real size of output[%ld]",
GELOGE(FAILED, "output idx[%zu], the memory size of output[%lu] given by "
"user should be greater than or equal to the real size of output[%ld]",
i, outputs[i].length, output_real_size);
return FAILED;
}
GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size, RT_MEMCPY_DEVICE_TO_DEVICE));
GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length,
args.outputs[i].GetData(), output_real_size, RT_MEMCPY_DEVICE_TO_DEVICE));
}
outputs[i].length = output_real_size;
}


+ 2
- 1
ge/hybrid/executor/worker/shape_inference_engine.cc View File

@@ -62,7 +62,8 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
{
std::lock_guard<std::mutex> lk(mu_);
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start");
GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true), "Invoke InferShapeAndType failed.");
GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndTypeForRunning(node_item.node, true),
"Invoke InferShapeAndType failed.");
RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End");
}
// Check again to make sure shape is valid after shape inference


+ 12
- 6
ge/hybrid/model/hybrid_model.cc View File

@@ -176,7 +176,8 @@ Status HybridModel::GetInputOutputDescInfo(vector<InputOutputDescInfo> &input_de
return SUCCESS;
}

void HybridModel::SetInputDimsAndShapeRangesInfo(const vector<int64_t> &model_input_dims, std::vector<std::pair<int64_t,int64_t>> &shape_ranges,
void HybridModel::SetInputDimsAndShapeRangesInfo(const vector<int64_t> &model_input_dims,
std::vector<std::pair<int64_t, int64_t>> &shape_ranges,
InputOutputDescInfo &input) {
for (auto model_input_dim : model_input_dims) {
input.shape_info.dims.push_back(model_input_dim);
@@ -245,7 +246,8 @@ Status HybridModel::GetInputDescInfo(vector<InputOutputDescInfo> &input_desc, st
return SUCCESS;
}

void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, InputOutputDescInfo &output_desc_info, uint32_t &format_result) {
void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc,
InputOutputDescInfo &output_desc_info, uint32_t &format_result) {
GE_IF_BOOL_EXEC(output_desc == nullptr, GELOGE(FAILED, "output desc ptr is nullptr"); return );
Format format = output_desc->GetFormat();
GeShape shape = output_desc->GetShape();
@@ -283,7 +285,8 @@ void HybridModel::CreateOutput(ConstGeTensorDescPtr &output_desc, InputOutputDes

Status HybridModel::GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, std::vector<uint32_t> &formats) {
std::vector<ConstGeTensorDescPtr> output_desc_list;
GE_CHK_STATUS_RET(root_graph_item_->GetOutputDescList(output_desc_list), "get output desc info failed"); // output_desc_list contains vaild input desc
// output_desc_list contains vaild input desc
GE_CHK_STATUS_RET(root_graph_item_->GetOutputDescList(output_desc_list), "get output desc info failed");

vector<std::string> out_node_names;
(void)ge::AttrUtils::GetListStr(ge_root_model_->GetRootGraph(), ATTR_MODEL_OUT_NODES_NAME, out_node_names);
@@ -293,7 +296,8 @@ Status HybridModel::GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc,
GE_CHECK_NOTNULL(op_desc);

auto out_size = static_cast<uint32_t>(op_desc->GetInputsSize());
GE_CHK_BOOL_RET_STATUS(out_size == output_desc_list.size(), FAILED, "output size[%u] not match output_desc_list size[%zu]", out_size, output_desc_list.size());
GE_CHK_BOOL_RET_STATUS(out_size == output_desc_list.size(),
FAILED, "output size[%u] not match output_desc_list size[%zu]", out_size, output_desc_list.size());

for (uint32_t index = 0; index < out_size; ++index) {
string output_name;
@@ -301,9 +305,11 @@ Status HybridModel::GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc,
std::vector<int64_t> src_index = op_desc->GetSrcIndex();
if (out_size == out_node_names.size()) {
bool contains_colon = out_node_names[index].find(":") != std::string::npos;
output_name = contains_colon ? out_node_names[index] : out_node_names[index] + ":" + std::to_string(src_index[index]);
output_name = contains_colon ? out_node_names[index] : out_node_names[index] +
":" + std::to_string(src_index[index]);
} else {
output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] + "_" + std::to_string(src_index[index]);
output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] +
"_" + std::to_string(src_index[index]);
}

InputOutputDescInfo output_desc_info;


+ 2
- 1
ge/hybrid/model/hybrid_model.h View File

@@ -104,7 +104,8 @@ class HybridModel {

void SetModelDescVersion(bool is_new_model_desc) { is_new_model_desc_ = is_new_model_desc; }

void SetInputDimsAndShapeRangesInfo(const vector<int64_t> &model_input_dims, std::vector<std::pair<int64_t, int64_t>> &shape_ranges,
void SetInputDimsAndShapeRangesInfo(const vector<int64_t> &model_input_dims,
std::vector<std::pair<int64_t, int64_t>> &shape_ranges,
InputOutputDescInfo &input);

private:


+ 12
- 28
ge/ir_build/ge_ir_build.cc View File

@@ -36,7 +36,6 @@
#include "model/ge_model.h"
#include "graph/shape_refiner.h"
#include "graph/opsproto_manager.h"
#include "graph/utils/type_utils.h"

using std::string;
using namespace std;
@@ -50,11 +49,8 @@ const std::string IR_OPTION_LOG_LEVEL_DEFAULT = "default";
const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize";
const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0";
const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false";

const std::string kInputShape = "input_shape";
const std::string kInputFormat = "input_format";
const std::string kReUseMemEnable = "1";
const std::string kReUseMemDisEnable = "0";
} // namespace

static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global_options) {
@@ -232,12 +228,12 @@ class Impl {
graphStatus CheckOptions(const std::map<std::string, std::string> &options);
graphStatus CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs);
graphStatus GetDefaultInputShape(const Graph &graph, string &default_shape);
graphStatus UpdateDataOpAttr(const Graph &graph);
graphStatus Init(const Graph &graph, const std::map<std::string, std::string> &options);
graphStatus BuildModel(const Graph &graph, const std::map<std::string, std::string> &options,
ModelBufferData &ge_models);
graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format,
bool is_dynamic_input);
graphStatus UpdateDataOpAttr(const Graph &graph);
void SetRtSocVersion();
void UpdateThreadContext();
void LoadOpsProto();
@@ -429,6 +425,7 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri

// for IR builder.Only support om mode, so here fixed;
options_.insert(std::pair<string, string>(string(IR_OPTION_MODE), to_string(0)));
options_.insert(std::pair<string, string>(string(IR_OPTION_TARGET), "mini"));
options_.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0)));
options_.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0)));
options_.insert(std::pair<string, string>(string(ge::SAVE_ORIGINAL_MODEL), to_string(0)));
@@ -468,52 +465,39 @@ void Impl::UpdateThreadContext() {
graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs) {
auto compute_graph = ge::GraphUtils::GetComputeGraph(graph);
GE_CHECK_NOTNULL(compute_graph);
int64_t index = 0;
for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(input_node);
ge::OpDescPtr op = input_node->GetOpDesc();
GE_CHECK_NOTNULL(op);
if (op->GetType() == DATA) {
(void)AttrUtils::SetInt(op, ATTR_NAME_INDEX, index++);
GELOGD("Data op inputDesc size: %zu", op->GetAllInputsDesc().size());
auto tensor = op->MutableInputDesc(0);
GE_CHECK_NOTNULL(tensor);
ge::GeTensorDesc tensor = op->GetInputDesc(0);
string data_op_name = op->GetName();
GELOGD("Data op name: %s", data_op_name.c_str());
ge::GeShape data_shape;
auto iter = omg_context_.input_dims.find(data_op_name);
if (iter != omg_context_.input_dims.end()) {
data_shape = ge::GeShape(iter->second);
GELOGD("Data op get shape from Context and update [%s] shape info", data_op_name.c_str());
GELOGD("Data op get shape from Context.");
} else {
data_shape = tensor->GetShape();
data_shape = tensor.GetShape();
GELOGD("Data op get shape from InputDesc in ge ir graph.");
}
// If user point input format, do work for all data ops; else do according to tensor_desc
auto data_format = omg_context_.format != domi::DOMI_TENSOR_ND ?
ge::TypeUtils::DomiFormatToFormat(omg_context_.format) : tensor->GetFormat();
ge::DataType data_type = tensor->GetDataType();
ge::TypeUtils::DomiFormatToFormat(omg_context_.format) : tensor.GetFormat();
ge::DataType data_type = tensor.GetDataType();
string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type);
GELOGD("Data op get data type:%s from InputDesc in ge ir graph.", data_type_str.c_str());

ge::GeTensor inputTensor;
ge::GeTensorDesc desc(data_shape, ge::Format(data_format), data_type);
inputTensor.SetTensorDesc(desc);
int64_t index = 0;
if (AttrUtils::GetInt(op, ATTR_NAME_INDEX, index)) {
AttrUtils::SetInt(desc, ATTR_NAME_INDEX, index);
} else {
GELOGE(GRAPH_PARAM_INVALID, "Get attr name idx failed!");
return GRAPH_PARAM_INVALID;
}
inputs.emplace_back(inputTensor);
inputs.push_back(inputTensor);
}
}
std::sort(inputs.begin(), inputs.end(), [](ge::GeTensor a, ge::GeTensor b) {
int64_t data_idx_a = 0;
int64_t data_idx_b = 0;
AttrUtils::GetInt(a.MutableTensorDesc(), ATTR_NAME_INDEX, data_idx_a);
AttrUtils::GetInt(b.MutableTensorDesc(), ATTR_NAME_INDEX, data_idx_b);
return data_idx_a <= data_idx_b;
});
GELOGD("CreateInputsForIRBuild, inputs size: %zu", inputs.size());
return GRAPH_SUCCESS;
}
@@ -606,7 +590,7 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m
GELOGE(GRAPH_PARAM_INVALID, "input model is illegal");
return GRAPH_PARAM_INVALID;
}
return FileSaver::SaveToFile((output_file + ".om"), reinterpret_cast<void*>(model.data.get()),
return FileSaver::SaveToFile((output_file + ".om"), reinterpret_cast<void *>(model.data.get()),
static_cast<uint32_t>(model.length));
}

@@ -621,7 +605,7 @@ graphStatus aclgrphSaveModel(const char *output_file, const ModelBufferData &mod
return GRAPH_PARAM_INVALID;
}
std::string str_output_file = output_file;
return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast<void*>(model.data.get()),
return FileSaver::SaveToFile((str_output_file + ".om"), reinterpret_cast<void *>(model.data.get()),
static_cast<uint32_t>(model.length));
}



+ 25
- 7
ge/offline/CMakeLists.txt View File

@@ -74,22 +74,22 @@ target_link_libraries(atc PRIVATE
-ldl
)

############ atc.bin ############
add_executable(atc.bin ${SRC_LIST} ${PROTO_HDRS})
############ atc_atc.bin ############
add_executable(atc_atc.bin ${SRC_LIST} ${PROTO_HDRS})

target_compile_options(atc.bin PRIVATE
target_compile_options(atc_atc.bin PRIVATE
-Werror
-O2
-Wno-deprecated-declarations
)

target_compile_definitions(atc.bin PRIVATE
target_compile_definitions(atc_atc.bin PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
COMPILE_OMG_PACKAGE
google=ascend_private
)

target_include_directories(atc.bin PRIVATE
target_include_directories(atc_atc.bin PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${GE_CODE_DIR}
${GE_CODE_DIR}/ge
@@ -115,7 +115,7 @@ target_include_directories(atc.bin PRIVATE
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain
)

target_link_libraries(atc.bin PRIVATE
target_link_libraries(atc_atc.bin PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf
ge_common
@@ -134,6 +134,11 @@ target_link_libraries(atc.bin PRIVATE
-ldl
)

set_target_properties(atc_atc.bin PROPERTIES
OUTPUT_NAME atc.bin
RUNTIME_OUTPUT_DIRECTORY atclib
)

############ fwk_atc.bin ############
add_executable(fwk_atc.bin ${SRC_LIST} ${PROTO_HDRS})

@@ -194,10 +199,23 @@ target_link_libraries(fwk_atc.bin PRIVATE
-ldl
)

set_target_properties(fwk_atc.bin PROPERTIES
OUTPUT_NAME atc.bin
RUNTIME_OUTPUT_DIRECTORY fwkacl
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(TARGETS atc atc.bin fwk_atc.bin OPTIONAL
install(TARGETS atc OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)

install(TARGETS atc_atc.bin OPTIONAL
RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/atclib
)

install(TARGETS fwk_atc.bin OPTIONAL
RUNTIME DESTINATION ${INSTALL_LIBRARY_DIR}/fwkacl
)

+ 7
- 6
ge/offline/atc View File

@@ -4,7 +4,12 @@
# Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved.
#-------------------------------------------------------------------

LOCAL_PATH=$(cd "$(dirname "$0")"; pwd)
real_path=$(readlink "$0")
if [ $? -eq 0 ]; then
LOCAL_PATH=$(cd "$(dirname "$real_path")"; pwd)
else
LOCAL_PATH=$(cd "$(dirname "$0")"; pwd)
fi
PKG_PATH=$(cd ${LOCAL_PATH}/..; pwd)
LIB_P="/lib64"
PYTHON_P="/python/site-packages"
@@ -13,8 +18,4 @@ PYTHON_PATH="${PKG_PATH}${PYTHON_P}"
export LD_LIBRARY_PATH="${LIB64_PATH}:${LD_LIBRARY_PATH}"
export PYTHONPATH="${PYTHON_PATH}:${PYTHONPATH}"

if [ -f "${PKG_PATH}/bin/atc.bin" ];then
${PKG_PATH}/bin/atc.bin/atc.bin $@
else
${PKG_PATH}/bin/atc.bin/fwk_atc.bin $@
fi
${PKG_PATH}/bin/atc.bin "$@"

+ 2
- 2
ge/offline/module.mk View File

@@ -56,7 +56,7 @@ include $(BUILD_HOST_EXECUTABLE)

include $(CLEAR_VARS)

LOCAL_MODULE := atc.bin
LOCAL_MODULE := atclib/atc.bin

LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations
LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private
@@ -109,7 +109,7 @@ include $(BUILD_HOST_EXECUTABLE)

include $(CLEAR_VARS)

LOCAL_MODULE := fwk_atc.bin
LOCAL_MODULE := fwkacl/atc.bin

LOCAL_CFLAGS += -Werror -Wno-deprecated-declarations
LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dgoogle=ascend_private


+ 20
- 0
ge/offline/single_op_parser.cc View File

@@ -27,6 +27,7 @@
#include "common/ge_inner_error_codes.h"
#include "framework/common/util.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/operator_factory_impl.h"

@@ -176,6 +177,7 @@ T GetValue(const map<string, T> &dict, string &key, T default_val) {
}

void from_json(const Json &j, SingleOpTensorDesc &desc) {
bool is_tensor_valid = true;
desc.dims = j.at(kKeyShape).get<vector<int64_t>>();
auto it = j.find(kKeyShapeRange);
if (it != j.end()) {
@@ -189,9 +191,12 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) {
string type_str = j.at(kKeyType).get<string>();
desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED);
desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED);
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(format_str);
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsDataTypeValid(type_str);
it = j.find(kKeyOriginFormat);
if (it != j.end()) {
string origin_format_str = j.at(kKeyOriginFormat).get<string>();
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(origin_format_str);
desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED);
}
auto tensor_name = j.find(kKeyName);
@@ -202,6 +207,9 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) {
if (dynamic_input_name != j.end()) {
desc.dynamic_input_name = dynamic_input_name->get<string>();
}
if (!is_tensor_valid) {
desc.SetValidFlag(is_tensor_valid);
}
}

void from_json(const Json &j, SingleOpAttr &attr) {
@@ -305,6 +313,12 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {

int index = 0;
for (auto &tensor_desc : op_desc.input_desc) {
if (!tensor_desc.GetValidFlag()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"intput", "datatype or format", std::to_string(index)});
GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index);
return false;
}
if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) ||
(tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
@@ -317,6 +331,12 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {

index = 0;
for (auto &tensor_desc : op_desc.output_desc) {
if (!tensor_desc.GetValidFlag()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"output", "datatype", std::to_string(index)});
GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index);
return false;
}
if (tensor_desc.type == DT_UNDEFINED) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"output", "datatype", std::to_string(index)});


+ 6
- 0
ge/offline/single_op_parser.h View File

@@ -28,6 +28,10 @@

namespace ge {
struct SingleOpTensorDesc {
public:
bool GetValidFlag() const { return is_valid_; }
void SetValidFlag(bool is_valid) { is_valid_ = is_valid; }
public:
std::string name;
std::vector<int64_t> dims;
std::vector<int64_t> ori_dims;
@@ -36,6 +40,8 @@ struct SingleOpTensorDesc {
ge::Format ori_format = ge::FORMAT_RESERVED;
ge::DataType type = ge::DT_UNDEFINED;
std::string dynamic_input_name;
private:
bool is_valid_ = true;
};

struct SingleOpAttr {


+ 4
- 4
ge/opskernel_manager/ops_kernel_manager.cc View File

@@ -175,8 +175,8 @@ Status OpsKernelManager::ParsePluginOptions(const map<string, string> &options,
} else if (flag == 1) {
enable_flag = true;
} else {
GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.", plugin_name.c_str(),
iter->second.c_str());
GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.",
plugin_name.c_str(), iter->second.c_str());
return GE_GRAPH_OPTIONS_INVALID;
}
} catch (std::invalid_argument &) {
@@ -188,8 +188,8 @@ Status OpsKernelManager::ParsePluginOptions(const map<string, string> &options,
iter->second.c_str());
return GE_GRAPH_OPTIONS_INVALID;
} catch (...) {
GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.", plugin_name.c_str(),
iter->second.c_str());
GELOGE(GE_GRAPH_OPTIONS_INVALID, "option_key:%s, its value %s is invalid, it must be 0 or 1.",
plugin_name.c_str(), iter->second.c_str());
return GE_GRAPH_OPTIONS_INVALID;
}
} else {


+ 2
- 1
ge/session/omg.cc View File

@@ -644,7 +644,8 @@ Status ParseOutNodes(const string &out_nodes) {
if (!domi::GetContext().user_out_nodes_top_vec.empty()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
{"--out_nodes", out_nodes, "is not all index or top_name"});
GELOGE(PARAM_INVALID, "This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str());
GELOGE(PARAM_INVALID,
"This out_nodes str must be all index or top_name, while the actual input is %s", out_nodes.c_str());
return PARAM_INVALID;
}
// stoi: The method may throw an exception: invalid_argument/out_of_range


+ 2
- 1
ge/single_op/single_op.cc View File

@@ -111,7 +111,8 @@ Status SingleOp::ValidateArgs(const std::vector<DataBuffer> &inputs, const std::

auto num_outputs = outputs.size();
if (num_outputs != output_sizes_.size()) {
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "output num mismatch. model expect %zu, but given %zu", output_sizes_.size(), outputs.size());
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "output num mismatch. model expect %zu, but given %zu",
output_sizes_.size(), outputs.size());
return ACL_ERROR_GE_PARAM_INVALID;
}



+ 2
- 1
ge/single_op/single_op_model.cc View File

@@ -268,7 +268,8 @@ Status SingleOpModel::BuildTaskList(StreamResource *stream_resource, SingleOp &s
ParseArgTable(task, single_op);
single_op.tasks_.emplace_back(task);
} else {
GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID, "Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", context.kernel_type());
GELOGE(ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID,
"Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", context.kernel_type());
return ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID;
}
} else if (task_type == RT_MODEL_TASK_KERNEL_EX) {


+ 2
- 1
ge/single_op/task/tbe_task_builder.cc View File

@@ -173,7 +173,8 @@ Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam

auto tbe_kernel = GetTbeKernel(op_desc_);
if (tbe_kernel == nullptr) {
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "OP EXT ATTR NAME TBE_KERNEL not found. op = %s", op_desc_->GetName().c_str());
GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "OP EXT ATTR NAME TBE_KERNEL not found. op = %s",
op_desc_->GetName().c_str());
return ACL_ERROR_GE_INTERNAL_ERROR;
}



+ 1
- 1
inc/framework/common/taskdown_common.h View File

@@ -21,7 +21,7 @@

namespace ge {

#define CC_FUSION_OP_MAX 32
const int CC_FUSION_OP_MAX = 32;

typedef enum tagCcStatus {
CC_STATUS_SUCCESS = 0, /**< succ */


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit d19c9c5c92f21a0335c18681dcceed44f3a54ddc
Subproject commit bd2cfdfa85a3d9dcbd7dc825f5759c7f8b3ffa9a

Loading…
Cancel
Save