Browse Source

Add a condition when deleting Identity nodes

pull/1688/MERGE^2
zhou_lili 4 years ago
parent
commit
7788e433ed
3 changed files with 17 additions and 9 deletions
  1. +13
    -6
      ge/graph/load/model_manager/task_info/kernel_task_info.cc
  2. +2
    -1
      ge/graph/load/model_manager/task_info/kernel_task_info.h
  3. +2
    -2
      ge/graph/preprocess/multi_batch_options.cc

+ 13
- 6
ge/graph/load/model_manager/task_info/kernel_task_info.cc View File

@@ -129,6 +129,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci
ctx_.opIndex2[i] = context.origin_op_index(i);
}
ctx_.opCount = context.origin_op_index_size();
InitDumpFlag();
if (kernel_type_ == ccKernelType::TE) {
ctx_.opIndex = context.op_index();
uint16_t *args_offset_tmp = reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data()));
@@ -660,7 +661,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne
if (davinci_model_->IsKnownNode()) {
args_ = l2_buffer_on_ ? davinci_model_->GetCurrentHybridArgsAddr(hybrid_args_offset_)
: davinci_model_->GetCurrentArgsAddr(args_offset_);
InitDumpTask(offset);
InitDumpArgs(offset);
return SUCCESS;
}

@@ -726,7 +727,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne
return FAILED;
}
skt_dump_args_ = static_cast<char *>(args_) + offset;
InitDumpTask(offset);
InitDumpArgs(offset);

vector<void *> virtual_io_addrs; // use virtual address for zero copy key.
virtual_io_addrs.insert(virtual_io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end());
@@ -1022,7 +1023,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k

if (davinci_model_->IsKnownNode()) {
args_ = davinci_model_->GetCurrentHybridArgsAddr(hybrid_args_offset_);
InitDumpTask(sizeof(aicpu::AicpuParamHead));
InitDumpArgs(sizeof(aicpu::AicpuParamHead));
return SUCCESS;
}
const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam();
@@ -1063,7 +1064,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k
op_desc->GetName().c_str(), op_desc->GetType().c_str(), args_size_, rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}
InitDumpTask(sizeof(aicpu::AicpuParamHead));
InitDumpArgs(sizeof(aicpu::AicpuParamHead));

if (kernel_type_ == ccKernelType::CUST_AI_CPU) {
dump_flag_ |= RT_KERNEL_CUSTOM_AICPU;
@@ -1074,14 +1075,20 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k
return SUCCESS;
}

void KernelTaskInfo::InitDumpTask(uint32_t offset) {
void KernelTaskInfo::InitDumpFlag() {
if (davinci_model_->OpNeedDump(op_desc_->GetName())) {
GELOGD("Op %s need dump in task info", op_desc_->GetName().c_str());
GELOGD("Op %s init dump flag", op_desc_->GetName().c_str());
if (IsL1FusionOp(op_desc_)) {
dump_flag_ = RT_FUSION_KERNEL_DUMPFLAG;
} else {
dump_flag_ = RT_KERNEL_DUMPFLAG;
}
}
}

void KernelTaskInfo::InitDumpArgs(uint32_t offset) {
if (davinci_model_->OpNeedDump(op_desc_->GetName())) {
GELOGD("Op %s need dump in task info", op_desc_->GetName().c_str());
dump_args_ = static_cast<char *>(args_) + offset;
}
if (davinci_model_->GetOpDugReg()) {


+ 2
- 1
ge/graph/load/model_manager/task_info/kernel_task_info.h View File

@@ -128,7 +128,8 @@ class KernelTaskInfo : public TaskInfo {
Status SuperKernelDistribute();
bool IsL1FusionOp(const OpDescPtr &op_desc);
void SetIoAddrs(const OpDescPtr &op_desc);
void InitDumpTask(uint32_t offset);
void InitDumpFlag();
void InitDumpArgs(uint32_t offset);
void SetContinuousArgs(uint32_t args_size, DavinciModel *davinci_model);
void SetNoncontinuousArgs(uint32_t args_size, DavinciModel *davinci_model);
Status CopyNoncontinuousArgs(uint16_t offset);


+ 2
- 2
ge/graph/preprocess/multi_batch_options.cc View File

@@ -335,9 +335,9 @@ Status DeleteIdentityInsertByAdapter(ComputeGraphPtr &graph) {
GE_IF_BOOL_EXEC(peer_in_anchor == nullptr, continue);
auto dst_node = peer_in_anchor->GetOwnerNode();
GE_IF_BOOL_EXEC(dst_node == nullptr, continue);
if (dst_node->GetType() == IDENTITY) {
if (dst_node->GetType() == IDENTITY && dst_node->GetAllOutDataAnchors().empty()) {
GELOGI("Need to remove %s.", dst_node->GetName().c_str());
if (ge::GraphUtils::RemoveNodeWithoutRelink(graph, dst_node) != GRAPH_SUCCESS) {
if (GraphUtils::RemoveNodeWithoutRelink(graph, dst_node) != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) from graph:%s failed",
dst_node->GetName().c_str(), dst_node->GetType().c_str(), graph->GetName().c_str());
GELOGE(FAILED, "Remove Identity node %s failed.", dst_node->GetName().c_str());


Loading…
Cancel
Save