
HCCL ops with the same parallel group cannot be executed in parallel
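In short: HCCL tasks that share a parallel group must be issued in the same order on every process, so the builder now records which nodes belong to each parallel group (HCCL ops, plus known-shaped subgraphs that contain them) and serializes each member behind its nearest predecessor in topological order, skipping pairs that already run on the same executor. Below is a minimal, self-contained sketch of that rule using made-up stand-in types (FakeNode and SerializeByParallelGroup are illustrative only, not GE types); the real logic is in CollectParallelGroups and ParseDependentByParallelGroup further down in this diff.

// Illustrative sketch only (simplified stand-in types, not the GE API): serialize ops
// that share a parallel group by making each one wait for its nearest predecessor
// (by topological node id) from the same group, unless both run on the same executor.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct FakeNode {                 // hypothetical stand-in for NodeItem
  int id;                         // topological order across the whole graph
  std::string name;
  std::string executor;           // e.g. "HCCL" or "COMPILED_SUBGRAPH"
  std::set<std::string> groups;   // parallel groups the node participates in
  std::vector<int> deps;          // ids of nodes this node must wait for
};

void SerializeByParallelGroup(std::vector<FakeNode> &nodes) {
  // group name -> nodes that belong to it
  std::map<std::string, std::vector<FakeNode *>> group_to_nodes;
  for (auto &n : nodes) {
    for (const auto &g : n.groups) {
      group_to_nodes[g].push_back(&n);
    }
  }
  for (auto &n : nodes) {
    for (const auto &g : n.groups) {
      // pick the nearest predecessor in the same parallel group
      FakeNode *nearest = nullptr;
      for (auto *cand : group_to_nodes[g]) {
        if (cand->id < n.id && (nearest == nullptr || cand->id > nearest->id)) {
          nearest = cand;
        }
      }
      // skip pairs on the same executor, mirroring the same-executor check in the diff
      if (nearest != nullptr && nearest->executor != n.executor) {
        n.deps.push_back(nearest->id);
      }
    }
  }
}

int main() {
  std::vector<FakeNode> nodes = {
      {0, "allreduce_0", "HCCL", {"group_a"}, {}},
      {1, "known_subgraph_1", "COMPILED_SUBGRAPH", {"group_a"}, {}},
      {2, "allreduce_2", "HCCL", {"group_a"}, {}},
  };
  SerializeByParallelGroup(nodes);
  for (const auto &n : nodes) {
    std::cout << n.name << " waits for " << n.deps.size() << " node(s)\n";
  }
  return 0;
}

The same-executor skip mirrors the "No need to add dependency for nodes with same executor type" branch in the real code, presumably because tasks handled by a single executor are already loaded and launched in a fixed order, so only cross-executor pairs need an explicit edge.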

tags/v1.2.0
chuxing 4 years ago
commit 167621141b
6 changed files with 208 additions and 90 deletions
  1. ge/hybrid/model/hybrid_model_builder.cc (+167 -65)
  2. ge/hybrid/model/hybrid_model_builder.h (+7 -2)
  3. ge/hybrid/model/node_item.cc (+4 -0)
  4. ge/hybrid/model/node_item.h (+2 -0)
  5. ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc (+23 -20)
  6. ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h (+5 -3)

ge/hybrid/model/hybrid_model_builder.cc (+167 -65)

@@ -255,9 +255,7 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
   (void) AttrUtils::SetBool(new_node->op_desc, kIsFirstNode, false);
   (void) AttrUtils::SetBool(new_node->op_desc, kIsLastNode, false);

-  new_node->node_id = node_index;
-  new_node->op_desc->SetId(node_index);
-  node_index += 1;
+  new_node->node_id = static_cast<int>(new_node->op_desc->GetId());
   NodeExecutorManager::ExecutorType executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node);
   new_node->is_profiling_report = (executor_type == NodeExecutorManager::ExecutorType::AICORE) ||
                                   (executor_type == NodeExecutorManager::ExecutorType::AICPU_TF) ||
@@ -273,16 +271,16 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt
   // not care result, if no this attr, stand for the op does not need force infershape
   (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
   GELOGD("node [%s] is need do infershape , flag is %d",
-      op_desc->GetName().c_str(),
-      node_item.is_need_force_infershape);
+         op_desc->GetName().c_str(),
+         node_item.is_need_force_infershape);
   return SUCCESS;
 }

 Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) {
-  std::set<NodePtr> dependent_input_nodes;
+  std::set<NodePtr> dependent_for_shape_inference;
+  std::set<NodePtr> dependent_for_execution;
   auto &ge_node = node_item.node;
-  bool is_hccl_op =
-      NodeExecutorManager::GetInstance().ResolveExecutorType(*ge_node) == NodeExecutorManager::ExecutorType::HCCL;
+  bool is_hccl_op = node_item.IsHcclOp();

   // The input tensors become valid after computation is done for parent nodes of type DEPEND_COMPUTE.
   // Wait for these parent nodes before execution.
@@ -297,29 +295,15 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
     auto src_node_item = MutableNodeItem(src_node);
     GE_CHECK_NOTNULL(src_node_item);

-    if (is_hccl_op) {
-      GELOGD("[%s] Add input data dependent node [%s] due to engine type is HCCL",
-             node_item.NodeName().c_str(),
-             src_node_item->NodeName().c_str());
-      src_node_item->has_observer = true;
-      node_item.dependents_for_execution.emplace_back(src_node);
-      node_item.has_observer = true;
-      for (auto &dst_node : ge_node->GetOutNodes()) {
-        if (dst_node == nullptr) {
-          continue;
-        }
-
-        NodeItem *dst_node_item = nullptr;
-        GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(dst_node, &dst_node_item));
-        dst_node_item->dependents_for_execution.emplace_back(ge_node);
-      }
-    } else if (src_node_item->shape_inference_type == DEPEND_COMPUTE) {
-      GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE",
-             node_item.NodeName().c_str(),
-             src_node_item->NodeName().c_str());
-
+    if (src_node_item->shape_inference_type == DEPEND_COMPUTE || is_hccl_op || src_node_item->IsHcclOp()) {
+      GELOGD("[%s](%s) Add input data dependent node [%s](%s), shape inference type = %d",
+             ge_node->GetName().c_str(),
+             ge_node->GetType().c_str(),
+             src_node->GetName().c_str(),
+             src_node->GetType().c_str(),
+             static_cast<int>(src_node_item->shape_inference_type));
       src_node_item->has_observer = true;
-      node_item.dependents_for_execution.emplace_back(src_node);
+      dependent_for_execution.emplace(src_node);
     }

     if (src_node_item->shape_inference_type == DEPEND_SHAPE_RANGE) {
@@ -327,22 +311,17 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
              node_item.NodeName().c_str(),
              src_node_item->NodeName().c_str());
       src_node_item->has_observer = true;
-      dependent_input_nodes.emplace(src_node);
+      dependent_for_shape_inference.emplace(src_node);
     }
   }

   // cond or branch need to be prepared before the execution of IF or CASE
   if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) {
-    const auto &in_anchor = ge_node->GetInDataAnchor(0);
-    GE_CHECK_NOTNULL(in_anchor);
-    const auto &peer_anchor = in_anchor->GetPeerOutAnchor();
-    GE_CHECK_NOTNULL(peer_anchor);
-    auto src_node = peer_anchor->GetOwnerNode();
+    auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input
     GE_CHECK_NOTNULL(src_node);
     auto src_node_item = MutableNodeItem(src_node);
     GE_CHECK_NOTNULL(src_node_item);
-    src_node_item->has_observer = true;
-    node_item.dependents_for_execution.emplace_back(src_node);
+    dependent_for_execution.emplace(src_node);
     GELOGD("[%s] Dependent added from %s for control op's cond/branch",
            node_item.NodeName().c_str(),
            src_node_item->NodeName().c_str());
@@ -366,24 +345,32 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
     GE_CHECK_NOTNULL(src_node);
     auto src_node_item = MutableNodeItem(src_node);
     src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx());
-    src_node_item->has_observer = true;
-
-    dependent_input_nodes.emplace(src_node);
+    dependent_for_shape_inference.emplace(src_node);
     GELOGD("[%s] Dependent added from output of [%s:%d]",
            node_item.NodeName().c_str(),
            src_node_item->NodeName().c_str(),
            peer_out_anchor->GetIdx());
   }

-  for (const auto &dep_node : dependent_input_nodes) {
+  GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item, dependent_for_shape_inference));
+  for (const auto &dep_node : dependent_for_shape_inference) {
+    auto src_node_item = MutableNodeItem(dep_node);
+    GE_CHECK_NOTNULL(src_node_item);
+    src_node_item->has_observer = true;
     node_item.dependents_for_shape_inference.emplace_back(dep_node);
   }

-  GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item));
+  for (const auto &dep_node : dependent_for_execution) {
+    auto src_node_item = MutableNodeItem(dep_node);
+    GE_CHECK_NOTNULL(src_node_item);
+    src_node_item->has_observer = true;
+    node_item.dependents_for_execution.emplace_back(dep_node);
+  }
+
   return SUCCESS;
 }

-Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) {
+Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies) {
   if (node_item.fused_subgraph == nullptr) {
     return SUCCESS;
   }
@@ -413,17 +400,12 @@ Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) {
            node_item.NodeName().c_str(),
            op_desc->GetName().c_str(),
            src_node_item->NodeName().c_str());
-    src_node_item->has_observer = true;
     src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx());
-
-    auto &depends = node_item.dependents_for_shape_inference;
-    if (std::find(depends.begin(), depends.end(), src_node) == depends.end()) {
-      depends.emplace_back(src_node);
-      GELOGD("[%s] Dependent added from output of [%s:%d]",
-             node_item.NodeName().c_str(),
-             src_node_item->NodeName().c_str(),
-             peer_out_anchor->GetIdx());
-    }
+    dependencies.emplace(src_node);
+    GELOGD("[%s] Dependent added from output of [%s:%d]",
+           node_item.NodeName().c_str(),
+           src_node_item->NodeName().c_str(),
+           peer_out_anchor->GetIdx());
   }

   return SUCCESS;
@@ -770,9 +752,23 @@ Status HybridModelBuilder::LoadGraph() {
     GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
            root_graph->GetDirectNodesSize(),
            root_graph->GetAllNodesSize());
-    GE_DUMP(root_graph, "hybrid_merged_graph");
   }

+  root_graph_ = root_graph;
+  // Reset node id by topological order across all subgraphs
+  int64_t index = 0;
+  for (const auto &node : root_graph->GetAllNodes()) {
+    GE_CHECK_NOTNULL(node);
+    auto parent_graph = node->GetOwnerComputeGraph();
+    // No need to update nodes in known subgraph
+    if (parent_graph != nullptr && !parent_graph->GetGraphUnknownFlag()) {
+      continue;
+    }
+    auto op_desc = node->GetOpDesc();
+    GE_CHECK_NOTNULL(op_desc);
+    op_desc->SetId(index++);
+  }
+  GE_DUMP(root_graph, "hybrid_merged_graph");
   GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph.");
   GELOGD("Done loading root graph successfully.");
   GE_CHK_STATUS_RET(hybrid_model_.root_graph_item_->GroupNodes(), "Failed to group nodes for root graph");
@@ -810,6 +806,7 @@ Status HybridModelBuilder::LoadGraph() {
     }
   }

+  GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "Failed to establish dependencies for hccl ops");
   GELOGI("Done loading all subgraphs successfully.");
   return SUCCESS;
 }
@@ -1075,25 +1072,41 @@ Status HybridModelBuilder::InitWeights() {
   return SUCCESS;
 }

+Status HybridModelBuilder::LoadTask(NodeItem &node_item) {
+  auto &node_ptr = node_item.node;
+  GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str());
+  auto load_ret = node_item.node_executor->LoadTask(hybrid_model_,
+                                                    node_ptr,
+                                                    node_item.kernel_task);
+  if (load_ret != UNSUPPORTED && load_ret != SUCCESS) {
+    GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str());
+    return load_ret;
+  }
+
+  GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str());
+  return SUCCESS;
+}
+
 Status HybridModelBuilder::LoadTasks() {
   GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed.");
+  std::map<int, std::map<std::string, NodeItem *>> ordered_partitioned_calls;
   for (auto &it : hybrid_model_.node_items_) {
     auto &node_item = it.second;
-    auto &node_ptr = node_item->node;
     if (node_item->node_type == NETOUTPUT) {
       continue;
     }
-
-    GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str());
-    auto load_ret = node_item->node_executor->LoadTask(hybrid_model_,
-                                                       node_ptr,
-                                                       node_item->kernel_task);
-    if (load_ret != UNSUPPORTED && load_ret != SUCCESS) {
-      GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str());
-      return load_ret;
+    if (node_item->node_type == PARTITIONEDCALL) {
+      ordered_partitioned_calls[node_item->node_id][node_item->node_name] = node_item.get();
+      continue;
     }
+    GE_CHK_STATUS_RET_NOLOG(LoadTask(*node_item));
+  }

-    GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str());
+  // HCCL operators need to be loaded in the same order across different processes
+  for (auto &it : ordered_partitioned_calls) {
+    for (auto &it2 : it.second) {
+      GE_CHK_STATUS_RET_NOLOG(LoadTask(*it2.second));
+    }
   }

   return SUCCESS;
@@ -1626,6 +1639,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem
   auto temp_graph = MakeShared<ComputeGraph>("temp");
   GE_CHECK_NOTNULL(temp_graph);
   auto wrapper_node = temp_graph->AddNode(wrapper_op_desc);
+  wrapper_op_desc->SetId(parent_node_item->node_id);
   GeModelPtr ge_model = subgraph_models_[subgraph_name];
   GE_CHECK_NOTNULL(ge_model);
   hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model);
@@ -2011,5 +2025,93 @@ Status HybridModelBuilder::CheckAicpuOpList() {
                     "Launch check aicpu op type failed.");
   return SUCCESS;
 }
+
+Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
+  const auto &node = node_item->node;
+  auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node);
+  if (executor_type == NodeExecutorManager::ExecutorType::HCCL) {
+    std::string parallel_group;
+    if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) {
+      GELOGD("[%s] Got parallel group = [%s]", node_item->NodeName().c_str(), parallel_group.c_str());
+      parallel_group_to_nodes_[parallel_group].emplace(node_item);
+      std::set<std::string> group{parallel_group};
+      node_to_parallel_groups_[node_item].emplace(parallel_group);
+    }
+  } else if (executor_type == NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) {
+    std::set<std::string> parallel_groups;
+    GELOGD("[%s] To collect parallel group for known-shaped subgraph", node_item->NodeName().c_str());
+    for (const auto &subgraph_name : node->GetOpDesc()->GetSubgraphInstanceNames()) {
+      GELOGD("[%s] Start to get parallel group from subgraph: %s",
+             node_item->NodeName().c_str(),
+             subgraph_name.c_str());
+      auto subgraph = root_graph_->GetSubgraph(subgraph_name);
+      GE_CHECK_NOTNULL(subgraph);
+      for (const auto &sub_node : subgraph->GetAllNodes()) {
+        std::string parallel_group;
+        if (AttrUtils::GetStr(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) {
+          GELOGD("[%s::%s] Got parallel group = %s",
+                 subgraph_name.c_str(),
+                 sub_node->GetName().c_str(),
+                 parallel_group.c_str());
+          parallel_groups.emplace(parallel_group);
+        }
+      }
+    }
+
+    if (!parallel_groups.empty()) {
+      for (const auto &parallel_group : parallel_groups) {
+        parallel_group_to_nodes_[parallel_group].emplace(node_item);
+        GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str());
+      }
+      node_to_parallel_groups_.emplace(node_item, std::move(parallel_groups));
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status HybridModelBuilder::ParseDependentByParallelGroup() {
+  for (auto &it : hybrid_model_.node_items_) {
+    GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(it.second.get()));
+  }
+  for (const auto &it : node_to_parallel_groups_) {
+    auto node_item = it.first;
+    auto dst_executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node);
+    for (const auto &parallel_group : it.second) {
+      auto &dependent_nodes = parallel_group_to_nodes_[parallel_group];
+      NodeItem *nearest_dep_node = nullptr;
+      int max_id = -1;
+      for (auto &dep_node : dependent_nodes) {
+        if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) {
+          nearest_dep_node = dep_node;
+          max_id = dep_node->node_id;
+        }
+      }
+
+      if (nearest_dep_node != nullptr) {
+        GELOGD("[%s] Nearest node = [%s]", node_item->NodeName().c_str(), nearest_dep_node->NodeName().c_str());
+        auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*nearest_dep_node->node);
+        if (src_engine_type == dst_executor_type) {
+          GELOGD("No need to add dependency for nodes with same executor type");
+          continue;
+        }
+        auto &deps = node_item->dependents_for_execution;
+        if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) {
+          GELOGD("%s->%s Already has dependency, skip it",
+                 nearest_dep_node->node->GetName().c_str(),
+                 node_item->NodeName().c_str());
+          continue;
+        }
+        nearest_dep_node->has_observer = true;
+        deps.emplace_back(nearest_dep_node->node);
+        GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]",
+               parallel_group.c_str(),
+               nearest_dep_node->NodeName().c_str(),
+               node_item->NodeName().c_str());
+      }
+    }
+  }
+  return SUCCESS;
+}
 } // namespace hybrid
 } // namespace ge

ge/hybrid/model/hybrid_model_builder.h (+7 -2)

@@ -57,14 +57,17 @@ class HybridModelBuilder {
   Status ValidateParams();
   Status LoadGraph();
   Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model);
+  Status LoadTask(NodeItem &node_item);
   Status LoadTasks();
   Status IdentifyVariableOutputs(NodeItem &node_item);
   Status IdentifySameInputs(NodeItem &node_item);
   Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
   Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
   Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item);
+  Status CollectParallelGroups(NodeItem *node_item);
   Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies);
-  Status ParseDependentForFusedSubgraph(NodeItem &node_item);
+  Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies);
+  Status ParseDependentByParallelGroup();
   Status IndexTaskDefs();
   Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model);
   Status IndexSpecialNodes();
@@ -97,12 +100,14 @@ class HybridModelBuilder {
   NodeItem *MutableNodeItem(const NodePtr &node);

   GeRootModelPtr ge_root_model_;
+  ComputeGraphPtr root_graph_;
   std::map<std::string, GeModelPtr> subgraph_models_;
   std::map<std::string, NodePtr> constant_op_nodes_;
+  std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;
+  std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_;

   HybridModel &hybrid_model_;
   std::map<NodePtr, std::vector<std::pair<int, NodePtr>>> node_ref_inputs_;
-  int node_index = 0;

   RuntimeParam &runtime_param_;
   VarManager *var_manager_ = nullptr;


ge/hybrid/model/node_item.cc (+4 -0)

@@ -251,6 +251,10 @@ bool NodeItem::IsControlOp() const {
   return ge::hybrid::IsControlOp(op_desc->GetType());
 }

+bool NodeItem::IsHcclOp() const {
+  return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL;
+}
+
 std::string NodeItem::DebugString() const {
   std::stringstream ss;
   ss << "Node: ";


ge/hybrid/model/node_item.h (+2 -0)

@@ -67,6 +67,8 @@ struct NodeItem {

   bool IsControlOp() const;

+  bool IsHcclOp() const;
+
   void SetToDynamic();

   std::string DebugString() const;


ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc (+23 -20)

@@ -95,13 +95,6 @@ Status KnownNodeTask::UpdateArgs(TaskContext &context) {
 Status KnownNodeTask::Init(TaskContext &context) {
   // allocate output mem
   GE_CHK_STATUS_RET(context.AllocateOutputs(), "known node task allocate output failed.");
-
-  // init davinicmodel
-  if (!load_flag_) {
-    davinci_model_->InitRuntimeParams();
-    GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed.");
-  }
-
   // allocate mem base
   void *buffer = nullptr;
   if (davinci_model_->TotalMemSize() != 0) {
@@ -129,23 +122,31 @@ Status KnownNodeTask::Init(TaskContext &context) {
       void *global_step = context.GetExecutionContext()->global_step;
       davinci_model_->SetKnownShapeGlobalStep(global_step);
     }
-    int32_t device_id = 0;
-    rtError_t rt_ret = rtGetDevice(&device_id);
-    if (rt_ret != RT_ERROR_NONE || device_id < 0) {
-      GELOGE(rt_ret, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id);
-      return RT_ERROR_TO_GE_STATUS(rt_ret);
-    }
-    davinci_model_->SetDeviceId(device_id);
-    GE_CHK_STATUS_RET(davinci_model_->Init(), "KnownNodeExecutor::InitDavinciModel failed.");
     load_flag_ = true;
-  } else {
-    GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(),
-        davinci_model_->Id(), davinci_model_->SubModelId()), "KnownNodeTask::Init destroy aicpu kernel failed.");
   }
+  GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(),
+                                                                    davinci_model_->Id(), davinci_model_->SubModelId()),
+                    "KnownNodeTask::Init destroy aicpu kernel failed.");
   GELOGI("[%s] KnownNodeExecutor::Init success.", context.GetNodeName());
   return SUCCESS;
 }

+Status KnownNodeTask::InitDavinciModel() {
+  GELOGD("[Init][Model] start");
+  davinci_model_->InitRuntimeParams();
+  GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed");
+  int32_t device_id = 0;
+  GE_CHK_RT_RET(rtGetDevice(&device_id));
+  davinci_model_->SetDeviceId(static_cast<uint32_t>(device_id));
+  GE_CHK_STATUS_RET(DoInitDavinciModel(), "[Init][Model] Failed to init davinci model.");
+  GELOGD("[Init][Model] success");
+  return SUCCESS;
+}
+
+Status KnownNodeTask::DoInitDavinciModel() {
+  return davinci_model_->Init();
+}
+
 Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
   GELOGD("[%s] KnownNodeExecutor::PrepareTask in.", context.GetNodeName());
   RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorPrepareTask] Start");
@@ -182,9 +183,11 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node

   GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed.");

-  task = MakeShared<KnownNodeTask>(davinci_model);
-  GE_CHECK_NOTNULL(task);
+  auto known_node_task = MakeShared<KnownNodeTask>(davinci_model);
+  GE_CHECK_NOTNULL(known_node_task);
+  GE_CHK_STATUS_RET_NOLOG(known_node_task->InitDavinciModel());
   GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str());
+  task = std::move(known_node_task);
   return SUCCESS;
 }




ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h (+5 -3)

@@ -31,11 +31,15 @@ class KnownNodeTask : public NodeTask {
       : davinci_model_(davinci_model)
   {}

-  ~KnownNodeTask() {}
+  ~KnownNodeTask() = default;

   Status UpdateArgs(TaskContext &context) override;
   Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
   Status Init(TaskContext &context) override;
+  Status InitDavinciModel();
+
+ protected:
+  virtual Status DoInitDavinciModel();
  private:
   std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
   bool load_flag_ = false;
@@ -47,8 +51,6 @@ class KnownNodeExecutor : public NodeExecutor {
   Status PrepareTask(NodeTask &task, TaskContext &context) const;
   Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const;
   ~KnownNodeExecutor() {}
- private:
-  std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
 };
 } // namespace hybrid
 } // namespace ge

