diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 42e08811..468c84e6 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -259,8 +259,16 @@ ShapeFuture::ShapeFuture(NodeState *src_node, NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) : node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) { this->op_desc_ = node_item.node->GetOpDesc(); +} + +Status NodeState::Init(int group, const shared_ptr &frame_state) { + GE_CHECK_NOTNULL(frame_state); + group_ = group; + frame_state_ = frame_state; auto unique_task_context = TaskContext::Create(this, subgraph_context_); + GE_CHECK_NOTNULL(unique_task_context); task_context_ = std::shared_ptr(unique_task_context.release()); + return SUCCESS; } Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { @@ -350,6 +358,7 @@ void NodeState::ResetContext(uint64_t iteration) { switch_index_ = -1; subgraph_context_->ResetContext(node_item_->node); auto unique_task_context = TaskContext::Create(this, subgraph_context_); + GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context); task_context_ = std::shared_ptr(unique_task_context.release()); data_scheduled_ = static_cast(node_item_->root_data_.size()); diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index 72e2b90e..85f9e4c3 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -100,6 +100,8 @@ struct NodeState { NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); ~NodeState() = default; + Status Init(int group, const shared_ptr &frame_state); + OpDesc *GetOpDesc() const { return op_desc_.get(); } @@ -152,18 +154,10 @@ struct NodeState { return merge_index_; } - void SetGroup(int group) { - group_ = group; - } - int GetGroup() const { return group_; } - void SetFrameState(const shared_ptr &frame_state) { - frame_state_ = frame_state; - } - const shared_ptr &GetKernelTask() const { return kernel_task_; } diff --git a/ge/hybrid/executor/subgraph_context.cc b/ge/hybrid/executor/subgraph_context.cc index 41ada9af..5e97a9a2 100644 --- a/ge/hybrid/executor/subgraph_context.cc +++ b/ge/hybrid/executor/subgraph_context.cc @@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { return nullptr; } + return CreateNodeState(node_item); +} + +NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) { GELOGD("[%s] lock for write", node_item->NodeName().c_str()); if (mmRWLockWRLock(&rw_lock_) != EN_OK) { REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str()); GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str()); return nullptr; } + auto &node_state = node_states_[node_item]; - if (node_state == nullptr) { - const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); - node_state.reset(new(std::nothrow)NodeState(*node_item, this)); - node_state->SetFrameState(GetOrCreateFrameState(*node_item)); - node_state->SetGroup(group_); - (void)guard; - } + do { + if (node_state == nullptr) { + const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); + node_state.reset(new(std::nothrow)NodeState(*node_item, this)); + if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str()); + REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str()); + break; + } + (void)guard; + } + } while (0); + GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); if (mmWRLockUnLock(&rw_lock_) != EN_OK) { REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str()); diff --git a/ge/hybrid/executor/subgraph_context.h b/ge/hybrid/executor/subgraph_context.h index d11d00d7..023be981 100644 --- a/ge/hybrid/executor/subgraph_context.h +++ b/ge/hybrid/executor/subgraph_context.h @@ -51,6 +51,7 @@ class SubgraphContext { void NodeDone(const NodePtr &node); private: + NodeStatePtr CreateNodeState(const NodeItem *node_item); FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock friend class TaskContext; const GraphItem *graph_item_;