From 167621141b2f5bfbddd117d20706aea8461d5940 Mon Sep 17 00:00:00 2001 From: chuxing Date: Mon, 29 Mar 2021 10:46:17 +0800 Subject: [PATCH] hccl ops with same parallel group can not be execute parallelly --- ge/hybrid/model/hybrid_model_builder.cc | 232 +++++++++++++++------ ge/hybrid/model/hybrid_model_builder.h | 9 +- ge/hybrid/model/node_item.cc | 4 + ge/hybrid/model/node_item.h | 2 + .../compiledsubgraph/known_node_executor.cc | 43 ++-- .../compiledsubgraph/known_node_executor.h | 8 +- 6 files changed, 208 insertions(+), 90 deletions(-) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index f5cb5f7e..25dabd78 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -255,9 +255,7 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n (void) AttrUtils::SetBool(new_node->op_desc, kIsFirstNode, false); (void) AttrUtils::SetBool(new_node->op_desc, kIsLastNode, false); - new_node->node_id = node_index; - new_node->op_desc->SetId(node_index); - node_index += 1; + new_node->node_id = static_cast(new_node->op_desc->GetId()); NodeExecutorManager::ExecutorType executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); new_node->is_profiling_report = (executor_type == NodeExecutorManager::ExecutorType::AICORE) || (executor_type == NodeExecutorManager::ExecutorType::AICPU_TF) || @@ -273,16 +271,16 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt // not care result, if no this attr, stand for the op does not need force infershape (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); GELOGD("node [%s] is need do infershape , flag is %d", - op_desc->GetName().c_str(), - node_item.is_need_force_infershape); + op_desc->GetName().c_str(), + node_item.is_need_force_infershape); return SUCCESS; } Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies) { - std::set dependent_input_nodes; + std::set dependent_for_shape_inference; + std::set dependent_for_execution; auto &ge_node = node_item.node; - bool is_hccl_op = - NodeExecutorManager::GetInstance().ResolveExecutorType(*ge_node) == NodeExecutorManager::ExecutorType::HCCL; + bool is_hccl_op = node_item.IsHcclOp(); // The input tensors become valid after computation is done for parent nodes of type DEPEND_COMPUTE. // Wait for these parent nodes before execution. @@ -297,29 +295,15 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s auto src_node_item = MutableNodeItem(src_node); GE_CHECK_NOTNULL(src_node_item); - if (is_hccl_op) { - GELOGD("[%s] Add input data dependent node [%s] due to engine type is HCCL", - node_item.NodeName().c_str(), - src_node_item->NodeName().c_str()); - src_node_item->has_observer = true; - node_item.dependents_for_execution.emplace_back(src_node); - node_item.has_observer = true; - for (auto &dst_node : ge_node->GetOutNodes()) { - if (dst_node == nullptr) { - continue; - } - - NodeItem *dst_node_item = nullptr; - GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(dst_node, &dst_node_item)); - dst_node_item->dependents_for_execution.emplace_back(ge_node); - } - } else if (src_node_item->shape_inference_type == DEPEND_COMPUTE) { - GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE", - node_item.NodeName().c_str(), - src_node_item->NodeName().c_str()); - + if (src_node_item->shape_inference_type == DEPEND_COMPUTE || is_hccl_op || src_node_item->IsHcclOp()) { + GELOGD("[%s](%s) Add input data dependent node [%s](%s), shape inference type = %d", + ge_node->GetName().c_str(), + ge_node->GetType().c_str(), + src_node->GetName().c_str(), + src_node->GetType().c_str(), + static_cast(src_node_item->shape_inference_type)); src_node_item->has_observer = true; - node_item.dependents_for_execution.emplace_back(src_node); + dependent_for_execution.emplace(src_node); } if (src_node_item->shape_inference_type == DEPEND_SHAPE_RANGE) { @@ -327,22 +311,17 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s node_item.NodeName().c_str(), src_node_item->NodeName().c_str()); src_node_item->has_observer = true; - dependent_input_nodes.emplace(src_node); + dependent_for_shape_inference.emplace(src_node); } } // cond or branch need to be prepared before the execution of IF or CASE if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { - const auto &in_anchor = ge_node->GetInDataAnchor(0); - GE_CHECK_NOTNULL(in_anchor); - const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_anchor); - auto src_node = peer_anchor->GetOwnerNode(); + auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input GE_CHECK_NOTNULL(src_node); auto src_node_item = MutableNodeItem(src_node); GE_CHECK_NOTNULL(src_node_item); - src_node_item->has_observer = true; - node_item.dependents_for_execution.emplace_back(src_node); + dependent_for_execution.emplace(src_node); GELOGD("[%s] Dependent added from %s for control op's cond/branch", node_item.NodeName().c_str(), src_node_item->NodeName().c_str()); @@ -366,24 +345,32 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s GE_CHECK_NOTNULL(src_node); auto src_node_item = MutableNodeItem(src_node); src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); - src_node_item->has_observer = true; - - dependent_input_nodes.emplace(src_node); + dependent_for_shape_inference.emplace(src_node); GELOGD("[%s] Dependent added from output of [%s:%d]", node_item.NodeName().c_str(), src_node_item->NodeName().c_str(), peer_out_anchor->GetIdx()); } - for (const auto &dep_node : dependent_input_nodes) { + GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item, dependent_for_shape_inference)); + for (const auto &dep_node : dependent_for_shape_inference) { + auto src_node_item = MutableNodeItem(dep_node); + GE_CHECK_NOTNULL(src_node_item); + src_node_item->has_observer = true; node_item.dependents_for_shape_inference.emplace_back(dep_node); } - GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item)); + for (const auto &dep_node : dependent_for_execution) { + auto src_node_item = MutableNodeItem(dep_node); + GE_CHECK_NOTNULL(src_node_item); + src_node_item->has_observer = true; + node_item.dependents_for_execution.emplace_back(dep_node); + } + return SUCCESS; } -Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { +Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item, std::set &dependencies) { if (node_item.fused_subgraph == nullptr) { return SUCCESS; } @@ -413,17 +400,12 @@ Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) { node_item.NodeName().c_str(), op_desc->GetName().c_str(), src_node_item->NodeName().c_str()); - src_node_item->has_observer = true; src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx()); - - auto &depends = node_item.dependents_for_shape_inference; - if (std::find(depends.begin(), depends.end(), src_node) == depends.end()) { - depends.emplace_back(src_node); - GELOGD("[%s] Dependent added from output of [%s:%d]", - node_item.NodeName().c_str(), - src_node_item->NodeName().c_str(), - peer_out_anchor->GetIdx()); - } + dependencies.emplace(src_node); + GELOGD("[%s] Dependent added from output of [%s:%d]", + node_item.NodeName().c_str(), + src_node_item->NodeName().c_str(), + peer_out_anchor->GetIdx()); } return SUCCESS; @@ -770,9 +752,23 @@ Status HybridModelBuilder::LoadGraph() { GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), root_graph->GetAllNodesSize()); - GE_DUMP(root_graph, "hybrid_merged_graph"); } + root_graph_ = root_graph; + // Reset node id by topological order across all subgraphs + int64_t index = 0; + for (const auto &node : root_graph->GetAllNodes()) { + GE_CHECK_NOTNULL(node); + auto parent_graph = node->GetOwnerComputeGraph(); + // No need to update nodes in known subgraph + if (parent_graph != nullptr && !parent_graph->GetGraphUnknownFlag()) { + continue; + } + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + op_desc->SetId(index++); + } + GE_DUMP(root_graph, "hybrid_merged_graph"); GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); GELOGD("Done loading root graph successfully."); GE_CHK_STATUS_RET(hybrid_model_.root_graph_item_->GroupNodes(), "Failed to group nodes for root graph"); @@ -810,6 +806,7 @@ Status HybridModelBuilder::LoadGraph() { } } + GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "Failed to establish dependencies for hccl ops"); GELOGI("Done loading all subgraphs successfully."); return SUCCESS; } @@ -1075,25 +1072,41 @@ Status HybridModelBuilder::InitWeights() { return SUCCESS; } +Status HybridModelBuilder::LoadTask(NodeItem &node_item) { + auto &node_ptr = node_item.node; + GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); + auto load_ret = node_item.node_executor->LoadTask(hybrid_model_, + node_ptr, + node_item.kernel_task); + if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { + GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); + return load_ret; + } + + GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str()); + return SUCCESS; +} + Status HybridModelBuilder::LoadTasks() { GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed."); + std::map> ordered_partitioned_calls; for (auto &it : hybrid_model_.node_items_) { auto &node_item = it.second; - auto &node_ptr = node_item->node; if (node_item->node_type == NETOUTPUT) { continue; } - - GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); - auto load_ret = node_item->node_executor->LoadTask(hybrid_model_, - node_ptr, - node_item->kernel_task); - if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { - GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); - return load_ret; + if (node_item->node_type == PARTITIONEDCALL) { + ordered_partitioned_calls[node_item->node_id][node_item->node_name] = node_item.get(); + continue; } + GE_CHK_STATUS_RET_NOLOG(LoadTask(*node_item)); + } - GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str()); + // HCCL operators need to be loaded in the same order across different processes + for (auto &it : ordered_partitioned_calls) { + for (auto &it2 : it.second) { + GE_CHK_STATUS_RET_NOLOG(LoadTask(*it2.second)); + } } return SUCCESS; @@ -1626,6 +1639,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem auto temp_graph = MakeShared("temp"); GE_CHECK_NOTNULL(temp_graph); auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); + wrapper_op_desc->SetId(parent_node_item->node_id); GeModelPtr ge_model = subgraph_models_[subgraph_name]; GE_CHECK_NOTNULL(ge_model); hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); @@ -2011,5 +2025,93 @@ Status HybridModelBuilder::CheckAicpuOpList() { "Launch check aicpu op type failed."); return SUCCESS; } + +Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { + const auto &node = node_item->node; + auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); + if (executor_type == NodeExecutorManager::ExecutorType::HCCL) { + std::string parallel_group; + if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { + GELOGD("[%s] Got parallel group = [%s]", node_item->NodeName().c_str(), parallel_group.c_str()); + parallel_group_to_nodes_[parallel_group].emplace(node_item); + std::set group{parallel_group}; + node_to_parallel_groups_[node_item].emplace(parallel_group); + } + } else if (executor_type == NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) { + std::set parallel_groups; + GELOGD("[%s] To collect parallel group for known-shaped subgraph", node_item->NodeName().c_str()); + for (const auto &subgraph_name : node->GetOpDesc()->GetSubgraphInstanceNames()) { + GELOGD("[%s] Start to get parallel group from subgraph: %s", + node_item->NodeName().c_str(), + subgraph_name.c_str()); + auto subgraph = root_graph_->GetSubgraph(subgraph_name); + GE_CHECK_NOTNULL(subgraph); + for (const auto &sub_node : subgraph->GetAllNodes()) { + std::string parallel_group; + if (AttrUtils::GetStr(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { + GELOGD("[%s::%s] Got parallel group = %s", + subgraph_name.c_str(), + sub_node->GetName().c_str(), + parallel_group.c_str()); + parallel_groups.emplace(parallel_group); + } + } + } + + if (!parallel_groups.empty()) { + for (const auto ¶llel_group : parallel_groups) { + parallel_group_to_nodes_[parallel_group].emplace(node_item); + GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str()); + } + node_to_parallel_groups_.emplace(node_item, std::move(parallel_groups)); + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::ParseDependentByParallelGroup() { + for (auto &it : hybrid_model_.node_items_) { + GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(it.second.get())); + } + for (const auto &it : node_to_parallel_groups_) { + auto node_item = it.first; + auto dst_executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); + for (const auto ¶llel_group : it.second) { + auto &dependent_nodes = parallel_group_to_nodes_[parallel_group]; + NodeItem *nearest_dep_node = nullptr; + int max_id = -1; + for (auto &dep_node : dependent_nodes) { + if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) { + nearest_dep_node = dep_node; + max_id = dep_node->node_id; + } + } + + if (nearest_dep_node != nullptr) { + GELOGD("[%s] Nearest node = [%s]", node_item->NodeName().c_str(), nearest_dep_node->NodeName().c_str()); + auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*nearest_dep_node->node); + if (src_engine_type == dst_executor_type) { + GELOGD("No need to add dependency for nodes with same executor type"); + continue; + } + auto &deps = node_item->dependents_for_execution; + if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) { + GELOGD("%s->%s Already has dependency, skip it", + nearest_dep_node->node->GetName().c_str(), + node_item->NodeName().c_str()); + continue; + } + nearest_dep_node->has_observer = true; + deps.emplace_back(nearest_dep_node->node); + GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]", + parallel_group.c_str(), + nearest_dep_node->NodeName().c_str(), + node_item->NodeName().c_str()); + } + } + } + return SUCCESS; +} } // namespace hybrid } // namespace ge diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 313d5ca6..a59a282a 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -57,14 +57,17 @@ class HybridModelBuilder { Status ValidateParams(); Status LoadGraph(); Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); + Status LoadTask(NodeItem &node_item); Status LoadTasks(); Status IdentifyVariableOutputs(NodeItem &node_item); Status IdentifySameInputs(NodeItem &node_item); Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); + Status CollectParallelGroups(NodeItem *node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); - Status ParseDependentForFusedSubgraph(NodeItem &node_item); + Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set &dependencies); + Status ParseDependentByParallelGroup(); Status IndexTaskDefs(); Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model); Status IndexSpecialNodes(); @@ -97,12 +100,14 @@ class HybridModelBuilder { NodeItem *MutableNodeItem(const NodePtr &node); GeRootModelPtr ge_root_model_; + ComputeGraphPtr root_graph_; std::map subgraph_models_; std::map constant_op_nodes_; + std::map> parallel_group_to_nodes_; + std::map> node_to_parallel_groups_; HybridModel &hybrid_model_; std::map>> node_ref_inputs_; - int node_index = 0; RuntimeParam &runtime_param_; VarManager *var_manager_ = nullptr; diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index 805064be..06d654cf 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -251,6 +251,10 @@ bool NodeItem::IsControlOp() const { return ge::hybrid::IsControlOp(op_desc->GetType()); } +bool NodeItem::IsHcclOp() const { + return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL; +} + std::string NodeItem::DebugString() const { std::stringstream ss; ss << "Node: "; diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 631dbd9e..474a1da4 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -67,6 +67,8 @@ struct NodeItem { bool IsControlOp() const; + bool IsHcclOp() const; + void SetToDynamic(); std::string DebugString() const; diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc index bb96c275..45882343 100644 --- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc +++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc @@ -95,13 +95,6 @@ Status KnownNodeTask::UpdateArgs(TaskContext &context) { Status KnownNodeTask::Init(TaskContext &context) { // allocate output mem GE_CHK_STATUS_RET(context.AllocateOutputs(), "known node task allocate output failed."); - - // init davinicmodel - if (!load_flag_) { - davinci_model_->InitRuntimeParams(); - GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed."); - } - // allocate mem base void *buffer = nullptr; if (davinci_model_->TotalMemSize() != 0) { @@ -129,23 +122,31 @@ Status KnownNodeTask::Init(TaskContext &context) { void *global_step = context.GetExecutionContext()->global_step; davinci_model_->SetKnownShapeGlobalStep(global_step); } - int32_t device_id = 0; - rtError_t rt_ret = rtGetDevice(&device_id); - if (rt_ret != RT_ERROR_NONE || device_id < 0) { - GELOGE(rt_ret, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); - return RT_ERROR_TO_GE_STATUS(rt_ret); - } - davinci_model_->SetDeviceId(device_id); - GE_CHK_STATUS_RET(davinci_model_->Init(), "KnownNodeExecutor::InitDavinciModel failed."); load_flag_ = true; - } else { - GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(), - davinci_model_->Id(), davinci_model_->SubModelId()), "KnownNodeTask::Init destroy aicpu kernel failed."); } + GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(), + davinci_model_->Id(), davinci_model_->SubModelId()), + "KnownNodeTask::Init destroy aicpu kernel failed."); GELOGI("[%s] KnownNodeExecutor::Init success.", context.GetNodeName()); return SUCCESS; } +Status KnownNodeTask::InitDavinciModel() { + GELOGD("[Init][Model] start"); + davinci_model_->InitRuntimeParams(); + GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed"); + int32_t device_id = 0; + GE_CHK_RT_RET(rtGetDevice(&device_id)); + davinci_model_->SetDeviceId(static_cast(device_id)); + GE_CHK_STATUS_RET(DoInitDavinciModel(), "[Init][Model] Failed to init davinci model."); + GELOGD("[Init][Model] success"); + return SUCCESS; +} + +Status KnownNodeTask::DoInitDavinciModel() { + return davinci_model_->Init(); +} + Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GELOGD("[%s] KnownNodeExecutor::PrepareTask in.", context.GetNodeName()); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorPrepareTask] Start"); @@ -182,9 +183,11 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed."); - task = MakeShared(davinci_model); - GE_CHECK_NOTNULL(task); + auto known_node_task = MakeShared(davinci_model); + GE_CHECK_NOTNULL(known_node_task); + GE_CHK_STATUS_RET_NOLOG(known_node_task->InitDavinciModel()); GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str()); + task = std::move(known_node_task); return SUCCESS; } diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h index 6e9740ad..5eed528a 100644 --- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h +++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h @@ -31,11 +31,15 @@ class KnownNodeTask : public NodeTask { : davinci_model_(davinci_model) {} - ~KnownNodeTask() {} + ~KnownNodeTask() = default; Status UpdateArgs(TaskContext &context) override; Status ExecuteAsync(TaskContext &context, std::function done_callback) override; Status Init(TaskContext &context) override; + Status InitDavinciModel(); + + protected: + virtual Status DoInitDavinciModel(); private: std::shared_ptr davinci_model_ = nullptr; bool load_flag_ = false; @@ -47,8 +51,6 @@ class KnownNodeExecutor : public NodeExecutor { Status PrepareTask(NodeTask &task, TaskContext &context) const; Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function &callback) const; ~KnownNodeExecutor() {} - private: - std::shared_ptr davinci_model_ = nullptr; }; } // namespace hybrid } // namespace ge