diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index d7bdbdae..875cb396 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -663,6 +663,20 @@ namespace { } return SUCCESS; } + + Status CheckNoAicore(const ComputeGraphPtr &graph, bool &no_aicore) { + no_aicore = true; + for (const auto &node : graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (op_desc->GetOpEngineName() == kAIcoreEngine) { + no_aicore = false; + return SUCCESS; + } + } + return SUCCESS; + } } Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, @@ -745,7 +759,9 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in bool all_shape = false; (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape); - if (all_shape) { + bool no_aicore = true; + GE_CHK_STATUS_RET_NOLOG(CheckNoAicore(root_graph, no_aicore)); + if (all_shape && no_aicore) { GELOGD("Get aicpu all_shape kernel!"); vector inputs_dynamic; vector outputs_dynamic; diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index 31b51e61..840a7183 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -44,20 +44,46 @@ namespace ge { namespace { const size_t kDataOutputNum = 1; +Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { + auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); + GE_CHECK_NOTNULL(comp_graph); + for (const auto &node : comp_graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const auto &depends = op_desc->GetOpInferDepends(); + if (!depends.empty()) { + flag = true; + return SUCCESS; + } + } + return SUCCESS; +} -bool NeedHybridModel(GeModelPtr &ge_model) { +Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { + bool infer_depend_flag = false; + GE_CHK_STATUS_RET_NOLOG(IfInferDepend(ge_model, infer_depend_flag)); auto tasks = ge_model->GetModelTaskDefPtr()->task(); int32_t kernel_task_num = 0; for (int i = 0; i < tasks.size(); ++i) { auto task_type = static_cast(tasks[i].type()); if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) { - kernel_task_num++; - if (kernel_task_num > 1) { - return true; + const auto &context = task_type == RT_MODEL_TASK_KERNEL ? task_def.kernel().context() : + task_def.kernel_with_handle().context(); + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type == ccKernelType::TE) { + if (infer_depend_flag) { + flag = true; + return SUCCESS; + } + kernel_task_num++; + if (kernel_task_num > 1) { + flag = true; + return SUCCESS; + } } } } - return false; + return SUCCESS; } } // namespace @@ -504,7 +530,9 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp & auto ge_model = model_helper_.GetGeModel(); GE_CHECK_NOTNULL(ge_model); - if (NeedHybridModel(ge_model)) { + bool need_hybrid_model = false; + GE_CHK_STATUS_RET_NOLOG(NeedHybridModel(ge_model, need_hybrid_model)); + if (need_hybrid_model) { GELOGD("Build single op HybridModel."); GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized()); auto root_model = model_helper_.GetGeRootModel();