@@ -259,8 +259,16 @@ ShapeFuture::ShapeFuture(NodeState *src_node,
 NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context)
     : node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) {
   this->op_desc_ = node_item.node->GetOpDesc();
+}
+
+Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) {
+  GE_CHECK_NOTNULL(frame_state);
+  group_ = group;
+  frame_state_ = frame_state;
   auto unique_task_context = TaskContext::Create(this, subgraph_context_);
+  GE_CHECK_NOTNULL(unique_task_context);
   task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());
+  return SUCCESS;
 }

 Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const {
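The hunk above splits construction into a trivial constructor plus a fallible Init(): a C++ constructor cannot return a Status, so the null check on TaskContext::Create and the frame-state wiring move into Init(), which callers can verify. A minimal sketch of this two-phase construction pattern, using simplified stand-in types rather than the real GE Status/TaskContext/FrameState:

#include <memory>
#include <new>

enum Status { SUCCESS = 0, INTERNAL_ERROR = 1 };

struct FrameState {};
struct TaskContext {
  static std::unique_ptr<TaskContext> Create() {
    return std::unique_ptr<TaskContext>(new (std::nothrow) TaskContext());
  }
};

class NodeState {
 public:
  NodeState() = default;  // keep the constructor trivial: it cannot report errors

  // Fallible setup goes here so the caller can check the result.
  Status Init(int group, const std::shared_ptr<FrameState> &frame_state) {
    if (frame_state == nullptr) {
      return INTERNAL_ERROR;  // stands in for GE_CHECK_NOTNULL(frame_state)
    }
    group_ = group;
    frame_state_ = frame_state;
    auto task_context = TaskContext::Create();
    if (task_context == nullptr) {
      return INTERNAL_ERROR;  // stands in for GE_CHECK_NOTNULL(unique_task_context)
    }
    task_context_ = std::shared_ptr<TaskContext>(task_context.release());
    return SUCCESS;
  }

 private:
  int group_ = -1;
  std::shared_ptr<FrameState> frame_state_;
  std::shared_ptr<TaskContext> task_context_;
};

Once Init() owns the group and frame-state wiring, the separate SetGroup/SetFrameState setters become redundant, which is why the header hunks below delete them.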
@@ -350,6 +358,7 @@ void NodeState::ResetContext(uint64_t iteration) {
   switch_index_ = -1;
   subgraph_context_->ResetContext(node_item_->node);
   auto unique_task_context = TaskContext::Create(this, subgraph_context_);
+  GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context);
   task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release());

   data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
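ResetContext returns void, so the Status-returning GE_CHECK_NOTNULL cannot be used there; the _JUST_RETURN variant bails out without a value. Purely as an illustration, such a guard typically expands to something like the following sketch; this is an assumed shape, not the actual GE macro definition:

// Hypothetical guard for void functions: on a null pointer, log (omitted here)
// and return with no value. The real GE_CHECK_NOTNULL_JUST_RETURN may differ.
#define CHECK_NOTNULL_JUST_RETURN(val) \
  do {                                 \
    if ((val) == nullptr) {            \
      return;                          \
    }                                  \
  } while (false)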
@@ -100,6 +100,8 @@ struct NodeState {
   NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context);
   ~NodeState() = default;

+  Status Init(int group, const shared_ptr<FrameState> &frame_state);
+
   OpDesc *GetOpDesc() const {
     return op_desc_.get();
   }
@@ -152,18 +154,10 @@ struct NodeState {
     return merge_index_;
   }

-  void SetGroup(int group) {
-    group_ = group;
-  }
-
   int GetGroup() const {
     return group_;
   }

-  void SetFrameState(const shared_ptr<FrameState> &frame_state) {
-    frame_state_ = frame_state;
-  }
-
   const shared_ptr<NodeTask> &GetKernelTask() const {
     return kernel_task_;
   }
@@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
     return nullptr;
   }

+  return CreateNodeState(node_item);
+}
+
+NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) {
   GELOGD("[%s] lock for write", node_item->NodeName().c_str());
   if (mmRWLockWRLock(&rw_lock_) != EN_OK) {
     REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str());
     GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str());
     return nullptr;
   }
   auto &node_state = node_states_[node_item];
-  if (node_state == nullptr) {
-    const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
-    node_state.reset(new(std::nothrow)NodeState(*node_item, this));
-    node_state->SetFrameState(GetOrCreateFrameState(*node_item));
-    node_state->SetGroup(group_);
-    (void)guard;
-  }
+  do {
+    if (node_state == nullptr) {
+      const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
+      node_state.reset(new(std::nothrow)NodeState(*node_item, this));
+      if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) {
+        GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str());
+        REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str());
+        break;
+      }
+      (void)guard;
+    }
+  } while (0);
   GELOGD("[%s] unlock for write", node_item->NodeName().c_str());
   if (mmWRLockUnLock(&rw_lock_) != EN_OK) {
     REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str());
@@ -51,6 +51,7 @@ class SubgraphContext {
   void NodeDone(const NodePtr &node);

  private:
+  NodeStatePtr CreateNodeState(const NodeItem *node_item);
   FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item);  // no lock
   friend class TaskContext;
   const GraphItem *graph_item_;