Browse Source

Pre Merge pull request !344 from lixiwen1/checkAccuracy

pull/344/MERGE
lixiwen1 Gitee 4 years ago
parent
commit
da35698746
4 changed files with 37 additions and 22 deletions
  1. +14
    -4
      ge/graph/load/new_model_manager/davinci_model.cc
  2. +19
    -15
      ge/graph/passes/compile_nodes_pass.cc
  3. +3
    -3
      ge/graph/passes/flow_ctrl_pass.cc
  4. +1
    -0
      third_party/fwkacllib/inc/runtime/base.h

+ 14
- 4
ge/graph/load/new_model_manager/davinci_model.cc View File

@@ -2804,9 +2804,14 @@ void *DavinciModel::Run(DavinciModel *model) {

GELOGI("rtStreamSynchronize start.");
rt_ret = rtStreamSynchronize(model->rt_model_stream_);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false;
(void)model->ReturnResult(current_data.index, false, seq_end_flag, data_wrapper->GetOutput());
continue); // [No need to check value]
if (rt_ret == RT_ERROR_MODEL_ABORT_NORMAL) {
GELOGW("rtStreamSynchronize get result : RT_ERROR_MODEL_ABORT_NORMAL, abort model normal");
} else {
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false;
(void)model->ReturnResult(current_data.index, false, seq_end_flag, data_wrapper->GetOutput());
continue); // [No need to check value]
}

GELOGI("rtStreamSynchronize end.");
(void)ProfilingManager::Instance().StopProfiling(); // just profiling, no need to check value
}
@@ -2827,12 +2832,17 @@ void *DavinciModel::Run(DavinciModel *model) {
if (rt_ret == kEndOfSequence || rt_ret == kEndOfSequenceNew) {
seq_end_flag = true;
}
GE_IF_BOOL_EXEC(
if (rt_ret == RT_ERROR_MODEL_ABORT_NORMAL) {
GELOGW("rtStreamSynchronize get result : RT_ERROR_MODEL_ABORT_NORMAL, abort model normal");
} else {
GE_IF_BOOL_EXEC(
rt_ret != RT_ERROR_NONE, rslt_flg = false; GELOGI("seq_end_flg: %d", seq_end_flag);
(void)model->ReturnResult(current_data.index, false, seq_end_flag,
data_wrapper->GetOutput()); // [No need to check value]
CsaInteract::GetInstance().StoreInternalErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC);
continue);
}

GELOGI("rtStreamSynchronize end.");
GE_IF_BOOL_EXEC(model->is_first_execute_,
GE_TIMESTAMP_EVENT_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize"));


+ 19
- 15
ge/graph/passes/compile_nodes_pass.cc View File

@@ -35,7 +35,7 @@ const char *const kAICPUKernelLibName = "aicpu_tf_kernel";
namespace ge {
graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) {
GE_TIMESTAMP_START(CompileNodesPass);
GELOGI("[CompileNodesPass]: optimize begin.");
GELOGD("[CompileNodesPass]: optimize begin.");
if (graph == nullptr) {
return GRAPH_SUCCESS;
}
@@ -81,7 +81,7 @@ graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) {
GELOGE(result, "Compile op failed.");
return result;
}
GELOGI("[CompileNodesPass]: Optimize success.");
GELOGD("[CompileNodesPass]: Optimize success.");
GE_TIMESTAMP_EVENT_END(CompileNodesPass, "OptimizeStage2::ControlAttrOptimize::CompileNodesPass");
return GRAPH_SUCCESS;
}
@@ -111,20 +111,24 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std:
}
// begin accuracy supported check
if (!CheckAccuracySupport(kernel_info, instance, op_desc)) {
// if check accuracy support failed , try to go to aicpu engine
string aicpu_kernel_lib_name = kAICPUKernelLibName;
OpsKernelInfoStorePtr aicpu_kernel_info =
instance->OpsKernelManagerObj().GetOpsKernelInfoStore(aicpu_kernel_lib_name);
if (aicpu_kernel_info == nullptr) {
GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Get aicpu kernel info store failed.");
return ge::GE_GRAPH_PARAM_NULLPTR;
}
if (!CheckAccuracySupport(aicpu_kernel_info, instance, op_desc)) {
GELOGE(GRAPH_FAILED, "AICPU engine does not support node:%s, type:%s , get kernel lib failed.",
node->GetName().c_str(), op_desc->GetType().c_str());
return GRAPH_FAILED;
// if check accuracy support failed , try to go to other engine.
string kernel_name_origin = kernel_lib_name;
OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj();
auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
for (auto it = kernel_map.begin(); it != kernel_map.end(); ++it) {
string tmp_kernel_name = it->first;
if (tmp_kernel_name == kernel_name_origin) {
continue;
}
OpsKernelInfoStorePtr tmp_kernel_info = it->second;
if (CheckAccuracySupport(tmp_kernel_info, instance, op_desc)) {
kernel_lib_name = tmp_kernel_name;
return GRAPH_SUCCESS;
}
}
kernel_lib_name = kAICPUKernelLibName;
GELOGE(GRAPH_FAILED, "Cannot find engine support node:%s, type:%s , get kernel lib failed.",
node->GetName().c_str(), op_desc->GetType().c_str());
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}


+ 3
- 3
ge/graph/passes/flow_ctrl_pass.cc View File

@@ -357,9 +357,9 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c
return FAILED;
}
GE_CHK_STATUS_RET(SetStreamLabel(active_node, switch_node->GetName()), "set stream label failed");
GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true),
DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed"); return FAILED);
GE_CHK_STATUS_RET(SetSwitchBranchNodeLabel(active_node, switch_node->GetName()),
"set switch branch node label failed");
string model_exit_name = switch_node->GetName() + "_ModelExit";
GE_CHK_STATUS_RET(SetActiveLabelList(active_node, { model_exit_name }), "set active label list failed");



+ 1
- 0
third_party/fwkacllib/inc/runtime/base.h View File

@@ -103,6 +103,7 @@ typedef enum tagRtError {
RT_ERROR_MODEL_EXIT,
RT_ERROR_MODEL_EXIT_STREAM_UNBIND,
RT_ERROR_MODEL_EXIT_ID,
RT_ERROR_MODEL_ABORT_NORMAL,

RT_ERROR_EVENT_BASE = 0x07050000,
RT_ERROR_EVENT_NULL,


Loading…
Cancel
Save